From f274ec7615788987236fc3f8476a1b8c54e1cad7 Mon Sep 17 00:00:00 2001
From: Andrea Pedrotti
Date: Mon, 6 Mar 2023 12:40:12 +0100
Subject: [PATCH] moved dataloader function get_dataset

Move get_dataset() out of main.py into the new module
dataManager/utils.py; main.py now only imports it.

Changes made along the way:
  * get_dataset() gains a "jrc" branch backed by JRC_DATAPATH, and "jrc"
    is listed among the supported dataset names so that the branch is
    actually reachable (the name check runs before the dispatch).
  * main() always calls dataset.training() / dataset.test(); the
    isinstance() dispatch and the dX/dY fallback are removed.
  * The device handed to GeneralizedFunnelling comes from the new
    --device command line argument (default "cuda") instead of being
    hard-coded.
  * GeneralizedFunnelling.save_first_tier_learners() drops its model_id
    parameter; the body already used self._model_id.
  * evaluation_metrics() returns the same four metrics, laid out as a
    multi-line tuple.
---
 dataManager/utils.py          | 79 ++++++++++++++++++++++++++++++
 evaluation/evaluate.py        | 13 +++--
 gfun/generalizedFunnelling.py |  2 +-
 main.py                       | 90 ++---------------------------------
 4 files changed, 94 insertions(+), 90 deletions(-)
 create mode 100644 dataManager/utils.py

diff --git a/dataManager/utils.py b/dataManager/utils.py
new file mode 100644
index 0000000..4dfa953
--- /dev/null
+++ b/dataManager/utils.py
@@ -0,0 +1,79 @@
+from os.path import expanduser
+from dataManager.gFunDataset import gFunDataset
+from dataManager.multiNewsDataset import MultiNewsDataset
+from dataManager.amazonDataset import AmazonDataset
+
+
+def get_dataset(dataset_name, args):
+    assert dataset_name in [
+        "multinews",
+        "amazon",
+        "jrc",
+        "rcv1-2",
+        "glami",
+        "cls",
+    ], "dataset not supported"
+
+    RCV_DATAPATH = expanduser(
+        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
+    )
+    JRC_DATAPATH = expanduser(
+        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
+    )
+    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
+
+    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+
+    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
+
+    if dataset_name == "multinews":
+        # TODO: convert to gFunDataset
+        raise NotImplementedError
+        dataset = MultiNewsDataset(
+            expanduser(MULTINEWS_DATAPATH),
+            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
+        )
+    elif dataset_name == "amazon":
+        # TODO: convert to gFunDataset
+        raise NotImplementedError
+        dataset = AmazonDataset(
+            domains=args.domains,
+            nrows=args.nrows,
+            min_count=args.min_count,
+            max_labels=args.max_labels,
+        )
+    elif dataset_name == "jrc":
+        dataset = gFunDataset(
+            dataset_dir=JRC_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=True,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "rcv1-2":
+        dataset = gFunDataset(
+            dataset_dir=RCV_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=True,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "glami":
+        dataset = gFunDataset(
+            dataset_dir=GLAMI_DATAPATH,
+            is_textual=True,
+            is_visual=True,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "cls":
+        dataset = gFunDataset(
+            dataset_dir=CLS_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
+    else:
+        raise NotImplementedError
+    return dataset
diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py
index 28c1649..10c2333 100644
--- a/evaluation/evaluate.py
+++ b/evaluation/evaluate.py
@@ -5,10 +5,15 @@ from evaluation.metrics import *
 
 def evaluation_metrics(y, y_):
     if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()  # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
-    else:  # the metrics I implemented assume multiclass multilabel classification as binary classifiers
-        return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
-        # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
+        raise NotImplementedError()
+    else:
+        return (
+            macroF1(y, y_),
+            microF1(y, y_),
+            macroK(y, y_),
+            microK(y, y_),
+            # macroAcc(y, y_),
+        )
 
 
 def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 52f57a3..4d4f25d 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -334,7 +334,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return
 
-    def save_first_tier_learners(self, model_id):
+    def save_first_tier_learners(self):
         for vgf in self.first_tier_learners:
             vgf.save_vgf(model_id=self._model_id)
         return self
diff --git a/main.py b/main.py
index 7198356..9b4ead6 100644
--- a/main.py
+++ b/main.py
@@ -1,13 +1,7 @@
-import pickle
 from argparse import ArgumentParser
-from os.path import expanduser
 from time import time
 
-from dataManager.amazonDataset import AmazonDataset
-from dataManager.multilingualDataset import MultilingualDataset
-from dataManager.multiNewsDataset import MultiNewsDataset
-from dataManager.glamiDataset import GlamiDataset
-from dataManager.gFunDataset import gFunDataset
+from dataManager.utils import get_dataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
 
@@ -25,85 +19,10 @@ TODO:
 """
 
 
-def get_dataset(datasetname, args):
-    assert datasetname in [
-        "multinews",
-        "amazon",
-        "rcv1-2",
-        "glami",
-        "cls",
-    ], "dataset not supported"
-
-    RCV_DATAPATH = expanduser(
-        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
-    )
-    JRC_DATAPATH = expanduser(
-        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
-    )
-    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
-
-    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-
-    if datasetname == "multinews":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = MultiNewsDataset(
-            expanduser(MULTINEWS_DATAPATH),
-            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
-        )
-    elif datasetname == "amazon":
-        # TODO: convert to gFunDataset
-        raise NotImplementedError
-        dataset = AmazonDataset(
-            domains=args.domains,
-            nrows=args.nrows,
-            min_count=args.min_count,
-            max_labels=args.max_labels,
-        )
-    elif datasetname == "rcv1-2":
-        dataset = gFunDataset(
-            dataset_dir=RCV_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=True,
-            nrows=args.nrows,
-        )
-    elif datasetname == "glami":
-        dataset = gFunDataset(
-            dataset_dir=GLAMI_DATAPATH,
-            is_textual=True,
-            is_visual=True,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
-    elif datasetname == "cls":
-        dataset = gFunDataset(
-            dataset_dir=CLS_DATAPATH,
-            is_textual=True,
-            is_visual=False,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
-    else:
-        raise NotImplementedError
-    return dataset
-
-
 def main(args):
     dataset = get_dataset(args.dataset, args)
-    if (
-        isinstance(dataset, MultilingualDataset)
-        or isinstance(dataset, MultiNewsDataset)
-        or isinstance(dataset, GlamiDataset)
-        or isinstance(dataset, gFunDataset)
-    ):
-        lX, lY = dataset.training()
-        lX_te, lY_te = dataset.test()
-    else:
-        lX = dataset.dX
-        lY = dataset.dY
+    lX, lY = dataset.training()
+    lX_te, lY_te = dataset.test()
 
     tinit = time()
 
@@ -140,7 +59,7 @@ def main(args):
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
-        device="cuda",
+        device=args.device,
         # Visual Transformer VGF params --------------
         visual_transformer=args.visual_transformer,
         visual_transformer_name=args.visual_transformer_name,
@@ -186,6 +105,7 @@ if __name__ == "__main__":
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     parser.add_argument("--meta", action="store_true")
     parser.add_argument("--nosave", action="store_true")
+    parser.add_argument("--device", type=str, default="cuda")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
     parser.add_argument("--domains", type=str, default="all")