moved dataloader function get_dataset
This commit is contained in:
parent
77227bbe13
commit
f274ec7615
|
@ -0,0 +1,78 @@
|
|||
from os.path import expanduser
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.amazonDataset import AmazonDataset
|
||||
|
||||
|
||||
def get_dataset(dataset_name, args):
|
||||
assert dataset_name in [
|
||||
"multinews",
|
||||
"amazon",
|
||||
"rcv1-2",
|
||||
"glami",
|
||||
"cls",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
)
|
||||
JRC_DATAPATH = expanduser(
|
||||
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||
)
|
||||
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||
|
||||
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
||||
if dataset_name == "multinews":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
dataset = MultiNewsDataset(
|
||||
expanduser(MULTINEWS_DATAPATH),
|
||||
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
|
||||
)
|
||||
elif dataset_name == "amazon":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
dataset = AmazonDataset(
|
||||
domains=args.domains,
|
||||
nrows=args.nrows,
|
||||
min_count=args.min_count,
|
||||
max_labels=args.max_labels,
|
||||
)
|
||||
elif dataset_name == "jrc":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=JRC_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=True,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "rcv1-2":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=RCV_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=True,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "glami":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=GLAMI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
|
@ -5,10 +5,15 @@ from evaluation.metrics import *
|
|||
|
||||
def evaluation_metrics(y, y_):
|
||||
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
|
||||
raise NotImplementedError() # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
|
||||
else: # the metrics I implemented assume multiclass multilabel classification as binary classifiers
|
||||
return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
|
||||
# return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
return (
|
||||
macroF1(y, y_),
|
||||
microF1(y, y_),
|
||||
macroK(y, y_),
|
||||
microK(y, y_),
|
||||
# macroAcc(y, y_),
|
||||
)
|
||||
|
||||
|
||||
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
||||
|
|
|
@ -334,7 +334,7 @@ class GeneralizedFunnelling:
|
|||
pickle.dump(self.metaclassifier, f)
|
||||
return
|
||||
|
||||
def save_first_tier_learners(self, model_id):
|
||||
def save_first_tier_learners(self):
|
||||
for vgf in self.first_tier_learners:
|
||||
vgf.save_vgf(model_id=self._model_id)
|
||||
return self
|
||||
|
|
86
main.py
86
main.py
|
@ -1,13 +1,7 @@
|
|||
import pickle
|
||||
from argparse import ArgumentParser
|
||||
from os.path import expanduser
|
||||
from time import time
|
||||
|
||||
from dataManager.amazonDataset import AmazonDataset
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.glamiDataset import GlamiDataset
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
from dataManager.utils import get_dataset
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
|
||||
|
@ -25,85 +19,10 @@ TODO:
|
|||
"""
|
||||
|
||||
|
||||
def get_dataset(datasetname, args):
|
||||
assert datasetname in [
|
||||
"multinews",
|
||||
"amazon",
|
||||
"rcv1-2",
|
||||
"glami",
|
||||
"cls",
|
||||
], "dataset not supported"
|
||||
|
||||
RCV_DATAPATH = expanduser(
|
||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||
)
|
||||
JRC_DATAPATH = expanduser(
|
||||
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||
)
|
||||
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||
|
||||
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
|
||||
if datasetname == "multinews":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
dataset = MultiNewsDataset(
|
||||
expanduser(MULTINEWS_DATAPATH),
|
||||
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
|
||||
)
|
||||
elif datasetname == "amazon":
|
||||
# TODO: convert to gFunDataset
|
||||
raise NotImplementedError
|
||||
dataset = AmazonDataset(
|
||||
domains=args.domains,
|
||||
nrows=args.nrows,
|
||||
min_count=args.min_count,
|
||||
max_labels=args.max_labels,
|
||||
)
|
||||
elif datasetname == "rcv1-2":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=RCV_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=True,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif datasetname == "glami":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=GLAMI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
elif datasetname == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=False,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return dataset
|
||||
|
||||
|
||||
def main(args):
|
||||
dataset = get_dataset(args.dataset, args)
|
||||
if (
|
||||
isinstance(dataset, MultilingualDataset)
|
||||
or isinstance(dataset, MultiNewsDataset)
|
||||
or isinstance(dataset, GlamiDataset)
|
||||
or isinstance(dataset, gFunDataset)
|
||||
):
|
||||
lX, lY = dataset.training()
|
||||
lX_te, lY_te = dataset.test()
|
||||
else:
|
||||
lX = dataset.dX
|
||||
lY = dataset.dY
|
||||
|
||||
tinit = time()
|
||||
|
||||
|
@ -140,7 +59,7 @@ def main(args):
|
|||
max_length=args.max_length,
|
||||
patience=args.patience,
|
||||
evaluate_step=args.evaluate_step,
|
||||
device="cuda",
|
||||
device=args.device,
|
||||
# Visual Transformer VGF params --------------
|
||||
visual_transformer=args.visual_transformer,
|
||||
visual_transformer_name=args.visual_transformer_name,
|
||||
|
@ -186,6 +105,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
||||
parser.add_argument("--meta", action="store_true")
|
||||
parser.add_argument("--nosave", action="store_true")
|
||||
parser.add_argument("--device", type=str, default="cuda")
|
||||
# Dataset parameters -------------------
|
||||
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
|
||||
parser.add_argument("--domains", type=str, default="all")
|
||||
|
|
Loading…
Reference in New Issue