moved dataloader function get_dataset

Andrea Pedrotti 2023-03-06 12:40:12 +01:00
parent 77227bbe13
commit f274ec7615
4 changed files with 93 additions and 90 deletions

dataManager/utils.py (new file)

@@ -0,0 +1,78 @@
+from os.path import expanduser
+
+from dataManager.gFunDataset import gFunDataset
+from dataManager.multiNewsDataset import MultiNewsDataset
+from dataManager.amazonDataset import AmazonDataset
+
+
+def get_dataset(dataset_name, args):
+    assert dataset_name in [
+        "multinews",
+        "amazon",
+        "rcv1-2",
+        "glami",
+        "cls",
+    ], "dataset not supported"
+
+    RCV_DATAPATH = expanduser(
+        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
+    )
+    JRC_DATAPATH = expanduser(
+        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
+    )
+    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
+    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
+
+    if dataset_name == "multinews":
+        # TODO: convert to gFunDataset
+        raise NotImplementedError
+        dataset = MultiNewsDataset(
+            expanduser(MULTINEWS_DATAPATH),
+            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
+        )
+    elif dataset_name == "amazon":
+        # TODO: convert to gFunDataset
+        raise NotImplementedError
+        dataset = AmazonDataset(
+            domains=args.domains,
+            nrows=args.nrows,
+            min_count=args.min_count,
+            max_labels=args.max_labels,
+        )
+    elif dataset_name == "jrc":
+        dataset = gFunDataset(
+            dataset_dir=JRC_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=True,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "rcv1-2":
+        dataset = gFunDataset(
+            dataset_dir=RCV_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=True,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "glami":
+        dataset = gFunDataset(
+            dataset_dir=GLAMI_DATAPATH,
+            is_textual=True,
+            is_visual=True,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
+    elif dataset_name == "cls":
+        dataset = gFunDataset(
+            dataset_dir=CLS_DATAPATH,
+            is_textual=True,
+            is_visual=False,
+            is_multilabel=False,
+            nrows=args.nrows,
+        )
+    else:
+        raise NotImplementedError
+    return dataset
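Note: the moved helper gains a "jrc" branch relative to the old main.py copy, but "jrc" is missing from the assert whitelist, so that branch is unreachable as committed; the multinews/amazon branches still raise NotImplementedError before their constructors run. A minimal call-site sketch (hypothetical; the Namespace stands in for the parsed CLI arguments, and the "rcv1-2" branch reads only args.nrows):

    from argparse import Namespace
    from dataManager.utils import get_dataset

    args = Namespace(nrows=None)  # only field the "rcv1-2" branch reads
    dataset = get_dataset("rcv1-2", args)
    lX, lY = dataset.training()    # training split, as consumed in main.py
    lX_te, lY_te = dataset.test()  # test split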

evaluation/evaluate.py

@@ -5,10 +5,15 @@ from evaluation.metrics import *
 def evaluation_metrics(y, y_):
     if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()  # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
-    else:  # the metrics I implemented assume multiclass multilabel classification as binary classifiers
-        return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
-        # return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
+        raise NotImplementedError()
+    else:
+        return (
+            macroF1(y, y_),
+            microF1(y, y_),
+            macroK(y, y_),
+            microK(y, y_),
+            # macroAcc(y, y_),
+        )


 def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
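This hunk only reflows the multilabel return into a parenthesized tuple; the metric order is unchanged. A hedged consumption sketch (assuming the helpers imported via evaluation.metrics accept binary indicator matrices):

    import numpy as np
    from evaluation.evaluate import evaluation_metrics

    y_true = np.array([[1, 0, 1], [0, 1, 0]])  # multilabel indicator matrix
    y_pred = np.array([[1, 0, 0], [0, 1, 0]])
    macro_f1, micro_f1, macro_k, micro_k = evaluation_metrics(y_true, y_pred)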

gfun/generalizedFunnelling.py

@@ -334,7 +334,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return

-    def save_first_tier_learners(self, model_id):
+    def save_first_tier_learners(self):
         for vgf in self.first_tier_learners:
             vgf.save_vgf(model_id=self._model_id)
         return self
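The dropped model_id parameter was dead: the body already saves under self._model_id. Call sites simply lose the argument (hypothetical usage):

    # Assuming `gfun` is a fitted GeneralizedFunnelling instance:
    gfun.save_first_tier_learners()  # was: gfun.save_first_tier_learners(model_id=...)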

main.py

@@ -1,13 +1,7 @@
-import pickle
 from argparse import ArgumentParser
-from os.path import expanduser
 from time import time

-from dataManager.amazonDataset import AmazonDataset
-from dataManager.multilingualDataset import MultilingualDataset
-from dataManager.multiNewsDataset import MultiNewsDataset
-from dataManager.glamiDataset import GlamiDataset
-from dataManager.gFunDataset import gFunDataset
+from dataManager.utils import get_dataset
 from evaluation.evaluate import evaluate, log_eval
 from gfun.generalizedFunnelling import GeneralizedFunnelling
@@ -25,85 +19,10 @@ TODO:
""" """
def get_dataset(datasetname, args):
assert datasetname in [
"multinews",
"amazon",
"rcv1-2",
"glami",
"cls",
], "dataset not supported"
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
JRC_DATAPATH = expanduser(
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
)
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
if datasetname == "multinews":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = MultiNewsDataset(
expanduser(MULTINEWS_DATAPATH),
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
)
elif datasetname == "amazon":
# TODO: convert to gFunDataset
raise NotImplementedError
dataset = AmazonDataset(
domains=args.domains,
nrows=args.nrows,
min_count=args.min_count,
max_labels=args.max_labels,
)
elif datasetname == "rcv1-2":
dataset = gFunDataset(
dataset_dir=RCV_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=True,
nrows=args.nrows,
)
elif datasetname == "glami":
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
elif datasetname == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,
is_textual=True,
is_visual=False,
is_multilabel=False,
nrows=args.nrows,
)
else:
raise NotImplementedError
return dataset
def main(args): def main(args):
dataset = get_dataset(args.dataset, args) dataset = get_dataset(args.dataset, args)
if ( lX, lY = dataset.training()
isinstance(dataset, MultilingualDataset) lX_te, lY_te = dataset.test()
or isinstance(dataset, MultiNewsDataset)
or isinstance(dataset, GlamiDataset)
or isinstance(dataset, gFunDataset)
):
lX, lY = dataset.training()
lX_te, lY_te = dataset.test()
else:
lX = dataset.dX
lY = dataset.dY
tinit = time() tinit = time()
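The isinstance dispatch can go because everything get_dataset now returns exposes the same split accessors. Sketched as a typing Protocol (the name is ours, not in the repo; it only documents the assumed interface):

    from typing import Protocol

    class SupportsSplits(Protocol):  # hypothetical; what main() relies on
        def training(self): ...
        def test(self): ...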
@@ -140,7 +59,7 @@ def main(args):
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
-        device="cuda",
+        device=args.device,
         # Visual Transformer VGF params --------------
         visual_transformer=args.visual_transformer,
         visual_transformer_name=args.visual_transformer_name,
@@ -186,6 +105,7 @@ if __name__ == "__main__":
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     parser.add_argument("--meta", action="store_true")
     parser.add_argument("--nosave", action="store_true")
+    parser.add_argument("--device", type=str, default="cuda")
     # Dataset parameters -------------------
     parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
     parser.add_argument("--domains", type=str, default="all")