moved dataloader function get_dataset
This commit is contained in:
parent
77227bbe13
commit
f274ec7615
|
@ -0,0 +1,78 @@
|
||||||
|
from os.path import expanduser
|
||||||
|
from dataManager.gFunDataset import gFunDataset
|
||||||
|
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||||
|
from dataManager.amazonDataset import AmazonDataset
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset(dataset_name, args):
|
||||||
|
assert dataset_name in [
|
||||||
|
"multinews",
|
||||||
|
"amazon",
|
||||||
|
"rcv1-2",
|
||||||
|
"glami",
|
||||||
|
"cls",
|
||||||
|
], "dataset not supported"
|
||||||
|
|
||||||
|
RCV_DATAPATH = expanduser(
|
||||||
|
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
||||||
|
)
|
||||||
|
JRC_DATAPATH = expanduser(
|
||||||
|
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
||||||
|
)
|
||||||
|
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
||||||
|
|
||||||
|
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
||||||
|
|
||||||
|
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||||
|
|
||||||
|
if dataset_name == "multinews":
|
||||||
|
# TODO: convert to gFunDataset
|
||||||
|
raise NotImplementedError
|
||||||
|
dataset = MultiNewsDataset(
|
||||||
|
expanduser(MULTINEWS_DATAPATH),
|
||||||
|
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
|
||||||
|
)
|
||||||
|
elif dataset_name == "amazon":
|
||||||
|
# TODO: convert to gFunDataset
|
||||||
|
raise NotImplementedError
|
||||||
|
dataset = AmazonDataset(
|
||||||
|
domains=args.domains,
|
||||||
|
nrows=args.nrows,
|
||||||
|
min_count=args.min_count,
|
||||||
|
max_labels=args.max_labels,
|
||||||
|
)
|
||||||
|
elif dataset_name == "jrc":
|
||||||
|
dataset = gFunDataset(
|
||||||
|
dataset_dir=JRC_DATAPATH,
|
||||||
|
is_textual=True,
|
||||||
|
is_visual=False,
|
||||||
|
is_multilabel=True,
|
||||||
|
nrows=args.nrows,
|
||||||
|
)
|
||||||
|
elif dataset_name == "rcv1-2":
|
||||||
|
dataset = gFunDataset(
|
||||||
|
dataset_dir=RCV_DATAPATH,
|
||||||
|
is_textual=True,
|
||||||
|
is_visual=False,
|
||||||
|
is_multilabel=True,
|
||||||
|
nrows=args.nrows,
|
||||||
|
)
|
||||||
|
elif dataset_name == "glami":
|
||||||
|
dataset = gFunDataset(
|
||||||
|
dataset_dir=GLAMI_DATAPATH,
|
||||||
|
is_textual=True,
|
||||||
|
is_visual=True,
|
||||||
|
is_multilabel=False,
|
||||||
|
nrows=args.nrows,
|
||||||
|
)
|
||||||
|
elif dataset_name == "cls":
|
||||||
|
dataset = gFunDataset(
|
||||||
|
dataset_dir=CLS_DATAPATH,
|
||||||
|
is_textual=True,
|
||||||
|
is_visual=False,
|
||||||
|
is_multilabel=False,
|
||||||
|
nrows=args.nrows,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return dataset
|
|
@ -5,10 +5,15 @@ from evaluation.metrics import *
|
||||||
|
|
||||||
def evaluation_metrics(y, y_):
|
def evaluation_metrics(y, y_):
|
||||||
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
|
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
|
||||||
raise NotImplementedError() # return f1_score(y,y_,average='macro'), f1_score(y,y_,average='micro')
|
raise NotImplementedError()
|
||||||
else: # the metrics I implemented assume multiclass multilabel classification as binary classifiers
|
else:
|
||||||
return macroF1(y, y_), microF1(y, y_), macroK(y, y_), microK(y, y_)
|
return (
|
||||||
# return macroF1(y, y_), microF1(y, y_), macroK(y, y_), macroAcc(y, y_)
|
macroF1(y, y_),
|
||||||
|
microF1(y, y_),
|
||||||
|
macroK(y, y_),
|
||||||
|
microK(y, y_),
|
||||||
|
# macroAcc(y, y_),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
||||||
|
|
|
@ -334,7 +334,7 @@ class GeneralizedFunnelling:
|
||||||
pickle.dump(self.metaclassifier, f)
|
pickle.dump(self.metaclassifier, f)
|
||||||
return
|
return
|
||||||
|
|
||||||
def save_first_tier_learners(self, model_id):
|
def save_first_tier_learners(self):
|
||||||
for vgf in self.first_tier_learners:
|
for vgf in self.first_tier_learners:
|
||||||
vgf.save_vgf(model_id=self._model_id)
|
vgf.save_vgf(model_id=self._model_id)
|
||||||
return self
|
return self
|
||||||
|
|
86
main.py
86
main.py
|
@ -1,13 +1,7 @@
|
||||||
import pickle
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from os.path import expanduser
|
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from dataManager.amazonDataset import AmazonDataset
|
from dataManager.utils import get_dataset
|
||||||
from dataManager.multilingualDataset import MultilingualDataset
|
|
||||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
|
||||||
from dataManager.glamiDataset import GlamiDataset
|
|
||||||
from dataManager.gFunDataset import gFunDataset
|
|
||||||
from evaluation.evaluate import evaluate, log_eval
|
from evaluation.evaluate import evaluate, log_eval
|
||||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||||
|
|
||||||
|
@ -25,85 +19,10 @@ TODO:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(datasetname, args):
|
|
||||||
assert datasetname in [
|
|
||||||
"multinews",
|
|
||||||
"amazon",
|
|
||||||
"rcv1-2",
|
|
||||||
"glami",
|
|
||||||
"cls",
|
|
||||||
], "dataset not supported"
|
|
||||||
|
|
||||||
RCV_DATAPATH = expanduser(
|
|
||||||
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
|
|
||||||
)
|
|
||||||
JRC_DATAPATH = expanduser(
|
|
||||||
"~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
|
|
||||||
)
|
|
||||||
CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
|
|
||||||
|
|
||||||
MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
|
|
||||||
|
|
||||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
|
||||||
|
|
||||||
if datasetname == "multinews":
|
|
||||||
# TODO: convert to gFunDataset
|
|
||||||
raise NotImplementedError
|
|
||||||
dataset = MultiNewsDataset(
|
|
||||||
expanduser(MULTINEWS_DATAPATH),
|
|
||||||
excluded_langs=["ar", "pe", "pl", "tr", "ua"],
|
|
||||||
)
|
|
||||||
elif datasetname == "amazon":
|
|
||||||
# TODO: convert to gFunDataset
|
|
||||||
raise NotImplementedError
|
|
||||||
dataset = AmazonDataset(
|
|
||||||
domains=args.domains,
|
|
||||||
nrows=args.nrows,
|
|
||||||
min_count=args.min_count,
|
|
||||||
max_labels=args.max_labels,
|
|
||||||
)
|
|
||||||
elif datasetname == "rcv1-2":
|
|
||||||
dataset = gFunDataset(
|
|
||||||
dataset_dir=RCV_DATAPATH,
|
|
||||||
is_textual=True,
|
|
||||||
is_visual=False,
|
|
||||||
is_multilabel=True,
|
|
||||||
nrows=args.nrows,
|
|
||||||
)
|
|
||||||
elif datasetname == "glami":
|
|
||||||
dataset = gFunDataset(
|
|
||||||
dataset_dir=GLAMI_DATAPATH,
|
|
||||||
is_textual=True,
|
|
||||||
is_visual=True,
|
|
||||||
is_multilabel=False,
|
|
||||||
nrows=args.nrows,
|
|
||||||
)
|
|
||||||
elif datasetname == "cls":
|
|
||||||
dataset = gFunDataset(
|
|
||||||
dataset_dir=CLS_DATAPATH,
|
|
||||||
is_textual=True,
|
|
||||||
is_visual=False,
|
|
||||||
is_multilabel=False,
|
|
||||||
nrows=args.nrows,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
dataset = get_dataset(args.dataset, args)
|
dataset = get_dataset(args.dataset, args)
|
||||||
if (
|
|
||||||
isinstance(dataset, MultilingualDataset)
|
|
||||||
or isinstance(dataset, MultiNewsDataset)
|
|
||||||
or isinstance(dataset, GlamiDataset)
|
|
||||||
or isinstance(dataset, gFunDataset)
|
|
||||||
):
|
|
||||||
lX, lY = dataset.training()
|
lX, lY = dataset.training()
|
||||||
lX_te, lY_te = dataset.test()
|
lX_te, lY_te = dataset.test()
|
||||||
else:
|
|
||||||
lX = dataset.dX
|
|
||||||
lY = dataset.dY
|
|
||||||
|
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
|
@ -140,7 +59,7 @@ def main(args):
|
||||||
max_length=args.max_length,
|
max_length=args.max_length,
|
||||||
patience=args.patience,
|
patience=args.patience,
|
||||||
evaluate_step=args.evaluate_step,
|
evaluate_step=args.evaluate_step,
|
||||||
device="cuda",
|
device=args.device,
|
||||||
# Visual Transformer VGF params --------------
|
# Visual Transformer VGF params --------------
|
||||||
visual_transformer=args.visual_transformer,
|
visual_transformer=args.visual_transformer,
|
||||||
visual_transformer_name=args.visual_transformer_name,
|
visual_transformer_name=args.visual_transformer_name,
|
||||||
|
@ -186,6 +105,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
parser.add_argument("-l", "--load_trained", type=str, default=None)
|
||||||
parser.add_argument("--meta", action="store_true")
|
parser.add_argument("--meta", action="store_true")
|
||||||
parser.add_argument("--nosave", action="store_true")
|
parser.add_argument("--nosave", action="store_true")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda")
|
||||||
# Dataset parameters -------------------
|
# Dataset parameters -------------------
|
||||||
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
|
parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
|
||||||
parser.add_argument("--domains", type=str, default="all")
|
parser.add_argument("--domains", type=str, default="all")
|
||||||
|
|
Loading…
Reference in New Issue