from os.path import expanduser

from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset


def get_dataset(dataset_name, args):
    """Build the dataset object selected by ``dataset_name``.

    Args:
        dataset_name: One of ``"multinews"``, ``"amazon"``, ``"jrc"``,
            ``"rcv1-2"``, ``"glami"`` or ``"cls"``.
        args: Object exposing the options read below as attributes.
            ``args.nrows`` is read for every ``gFunDataset``; the (currently
            disabled) amazon branch would additionally read ``args.domains``,
            ``args.min_count`` and ``args.max_labels``.
            Presumably an ``argparse.Namespace`` -- confirm against callers.

    Returns:
        A ``gFunDataset`` configured for the requested dataset.

    Raises:
        AssertionError: If ``dataset_name`` is not one of the supported names
            (only when assertions are enabled).
        NotImplementedError: For ``"multinews"`` and ``"amazon"``, which have
            not been converted to ``gFunDataset`` yet, and as a fallback for
            unknown names when assertions are disabled.
    """
    # "jrc" must be listed here: it has its own branch below, and omitting it
    # made that branch unreachable because the assertion rejected the name.
    assert dataset_name in [
        "multinews",
        "amazon",
        "jrc",
        "rcv1-2",
        "glami",
        "cls",
    ], "dataset not supported"

    # Dataset locations, resolved relative to the current user's home directory.
    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if dataset_name == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # NOTE: unreachable until the TODO above is resolved; kept as a
        # reference for how the legacy dataset is constructed.
        dataset = MultiNewsDataset(
            expanduser(MULTINEWS_DATAPATH),
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif dataset_name == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        # NOTE: unreachable until the TODO above is resolved; kept as a
        # reference for how the legacy dataset is constructed.
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif dataset_name == "jrc":
        dataset = gFunDataset(
            dataset_dir=JRC_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif dataset_name == "rcv1-2":
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif dataset_name == "glami":
        # The only configuration here that enables the visual modality.
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
    elif dataset_name == "cls":
        dataset = gFunDataset(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        # Only reachable when assertions are disabled (e.g. ``python -O``).
        raise NotImplementedError

    return dataset