moved dataloader function get_dataset

This commit is contained in:
Andrea Pedrotti 2023-03-06 12:40:12 +01:00
parent 77227bbe13
commit f274ec7615
4 changed files with 93 additions and 90 deletions

78
dataManager/utils.py Normal file
View File

@ -0,0 +1,78 @@
from os.path import expanduser
from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
def get_dataset(dataset_name, args):
assert dataset_name in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
], "dataset not supported"
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
if dataset_name == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif dataset_name == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
min_count=args.min_count,
max_labels=args.max_labels,
)
elif dataset_name == "jrc":
dataset = gFunDataset(
dataset_dir=JRC_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "rcv1-2":
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif dataset_name == "glami":
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
elif dataset_name == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else:
raise NotImplementedError
return dataset

View File

@ -5,10 +5,15 @@ from evaluation.metrics import *
def evaluation_metrics(y, y_):
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
raise NotImplementedError() # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
else: # the metrics I implemented assume multiclass multilabel classification as binary classifiers
return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
# return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
raise NotImplementedError()
else:
return (
macroF1(y, y_),
microF1(y, y_),
macroK(y, y_),
microK(y, y_),
# macroAcc(y, y_),
)
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):

View File

@ -334,7 +334,7 @@ class GeneralizedFunnelling:
pickle.dump(self.metaclassifier, f)
return
def save_first_tier_learners(self, model_id):
def save_first_tier_learners(self):
for vgf in self.first_tier_learners:
vgf.save_vgf(model_id=self._model_id)
return self

86
main.py
View File

@ -1,13 +1,7 @@
import pickle
from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
@ -25,85 +19,10 @@ TODO:
"""
def get_dataset(datasetname, args):
assert datasetname in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
], "dataset not supported"
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
if datasetname == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif datasetname == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
min_count=args.min_count,
max_labels=args.max_labels,
)
elif datasetname == "rcv1-2":
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif datasetname == "glami":
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
elif datasetname == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else:
raise NotImplementedError
return dataset
def main(args):
dataset = get_dataset(args.dataset, args)
if (
isinstance(dataset, MultilingualDataset)
or isinstance(dataset, MultiNewsDataset)
or isinstance(dataset, GlamiDataset)
or isinstance(dataset, gFunDataset)
):
lX, lY = dataset.training()
lX_te, lY_te = dataset.test()
else:
lX = dataset.dX
lY = dataset.dY
tinit = time()
@ -140,7 +59,7 @@ def main(args):
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
device="cuda",
device=args.device,
# Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_transformer_name,
@ -186,6 +105,7 @@ if __name__ == "__main__":
parser.add_argument("-l", "--load_trained", type=str, default=None)
parser.add_argument("--meta", action="store_true")
parser.add_argument("--nosave", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
# Dataset parameters -------------------
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
parser.add_argument("--domains", type=str, default="all")