from os.path import expanduser, join

from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset


def load_from_pickle(path, dataset_name, nrows):
    """Load a previously pickled dataset from *path*.

    The file name is derived from the dataset name and the row limit
    (``{dataset_name}_{nrows}.pkl``), matching what ``save_as_pickle``
    produced on an earlier run.

    Args:
        path: Directory containing the pickle file.
        dataset_name: Base name of the dataset (e.g. ``"glami"``).
        nrows: Row limit used when the pickle was created.

    Returns:
        The unpickled dataset object (presumably a gFunDataset — it must
        expose ``show_dimension()``).

    Raises:
        FileNotFoundError: If no pickle exists for this name/nrows pair.
    """
    import pickle

    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
    # NOTE: pickle.load executes arbitrary code from the file — only load
    # pickles this project created itself.
    with open(filepath, "rb") as f:
        loaded = pickle.load(f)
    print(f"- Loaded dataset from {filepath}")
    loaded.show_dimension()
    return loaded


def get_dataset(dataset_name, args):
    """Build (or load from cache) the dataset identified by *dataset_name*.

    Args:
        dataset_name: One of ``"multinews"``, ``"amazon"``, ``"jrc"``,
            ``"rcv1-2"``, ``"glami"``, ``"cls"``, ``"webis"``.
        args: Namespace-like object; ``args.nrows`` is always read, and the
            ``glami`` branch reads ``args.save_dataset`` (the ``amazon``
            branch would read ``domains``/``min_count``/``max_labels``).

    Returns:
        A dataset object (a ``gFunDataset`` for the implemented branches).

    Raises:
        AssertionError: If *dataset_name* is not in the supported list.
        NotImplementedError: For the ``multinews``/``amazon`` branches,
            which are not yet ported to ``gFunDataset``.
    """
    # BUG FIX: "jrc" was missing from this list even though a fully
    # implemented branch for it exists below, so requesting the JRC dataset
    # always failed the assertion. (Caveat: asserts are stripped under -O;
    # the trailing else still raises NotImplementedError in that case.)
    assert dataset_name in [
        "multinews",
        "amazon",
        "jrc",
        "rcv1-2",
        "glami",
        "cls",
        "webis",
    ], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
    WEBIS_CLS = expanduser(
        "~/datasets/cls-acl10-unprocessed/cls-acl10-unprocessed-all.pkl"
    )

    if dataset_name == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # Unreachable until the TODO above is resolved; kept to document the
        # intended construction. (Redundant inner expanduser() removed —
        # MULTINEWS_DATAPATH is already expanded.)
        dataset = MultiNewsDataset(
            MULTINEWS_DATAPATH,
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif dataset_name == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # Unreachable until the TODO above is resolved; kept to document the
        # intended construction.
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif dataset_name == "jrc":
        dataset = gFunDataset(
            dataset_dir=JRC_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif dataset_name == "rcv1-2":
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif dataset_name == "glami":
        # Strict `is False` check preserved from the original: any value
        # other than the literal False (e.g. None) takes the rebuild path.
        if args.save_dataset is False:
            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
        else:
            dataset = gFunDataset(
                dataset_dir=GLAMI_DATAPATH,
                is_textual=True,
                is_visual=True,
                is_multilabel=False,
                nrows=args.nrows,
            )
            dataset.save_as_pickle(GLAMI_DATAPATH)
    elif dataset_name == "cls":
        dataset = gFunDataset(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    elif dataset_name == "webis":
        dataset = gFunDataset(
            dataset_dir=WEBIS_CLS,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        # Fallback for names that pass the assert under -O but have no branch.
        raise NotImplementedError
    return dataset