79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
from os.path import expanduser
|
|
from dataManager.gFunDataset import gFunDataset
|
|
from dataManager.multiNewsDataset import MultiNewsDataset
|
|
from dataManager.amazonDataset import AmazonDataset
|
|
|
|
|
|
def get_dataset(dataset_name, args):
    """Return the dataset registered under ``dataset_name``.

    Every currently supported corpus is loaded through ``gFunDataset`` from a
    hard-coded location under ``~/datasets``. The "multinews" and "amazon"
    loaders have not been ported to ``gFunDataset`` yet and are rejected.

    Args:
        dataset_name: one of ``"multinews"``, ``"amazon"``, ``"jrc"``,
            ``"rcv1-2"``, ``"glami"`` or ``"cls"``.
        args: options object; only ``args.nrows`` is read, and it is forwarded
            to ``gFunDataset``. Presumably an ``argparse.Namespace`` -- confirm
            against callers.

    Returns:
        A ``gFunDataset`` built with the path and the textual / visual /
        multilabel flags registered for ``dataset_name``.

    Raises:
        ValueError: ``dataset_name`` is not one of the names above.
        NotImplementedError: ``dataset_name`` is "multinews" or "amazon".
    """
    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    # TODO: convert these to gFunDataset. Their former (unreachable) loaders were
    #   MultiNewsDataset(expanduser("~/datasets/MultiNews/20110730/"),
    #                    excluded_langs=["ar", "pe", "pl", "tr", "ua"])
    #   AmazonDataset(domains=args.domains, nrows=args.nrows,
    #                 min_count=args.min_count, max_labels=args.max_labels)
    unported = ("multinews", "amazon")

    # Single source of truth for both validation and dispatch. A separate
    # allow-list previously drifted from the branches and made "jrc" unreachable.
    gfun_configs = {
        "jrc": dict(
            dataset_dir=JRC_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
        ),
        "rcv1-2": dict(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
        ),
        "glami": dict(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
        ),
        "cls": dict(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
        ),
    }

    if dataset_name in unported:
        raise NotImplementedError(
            f"dataset {dataset_name!r} has not been converted to gFunDataset yet"
        )
    # Explicit raise rather than `assert`, so validation survives `python -O`.
    if dataset_name not in gfun_configs:
        supported = sorted([*unported, *gfun_configs])
        raise ValueError(
            f"dataset not supported: {dataset_name!r} (expected one of {supported})"
        )

    return gFunDataset(**gfun_configs[dataset_name], nrows=args.nrows)
|