2023-02-08 14:51:56 +01:00
|
|
|
import pickle
|
2023-02-07 18:40:17 +01:00
|
|
|
from argparse import ArgumentParser
|
2023-02-08 14:51:56 +01:00
|
|
|
from os.path import expanduser
|
|
|
|
from time import time
|
2023-02-07 18:40:17 +01:00
|
|
|
|
|
|
|
from dataManager.amazonDataset import AmazonDataset
|
|
|
|
from dataManager.multilingualDatset import MultilingualDataset
|
2023-02-08 14:51:56 +01:00
|
|
|
from dataManager.multiNewsDataset import MultiNewsDataset
|
2023-02-13 18:29:54 +01:00
|
|
|
from dataManager.glamiDataset import GlamiDataset
|
2023-02-07 18:40:17 +01:00
|
|
|
from evaluation.evaluate import evaluate, log_eval
|
2023-02-08 14:51:56 +01:00
|
|
|
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
2023-02-07 18:40:17 +01:00
|
|
|
|
2023-02-08 14:51:56 +01:00
|
|
|
"""
|
|
|
|
TODO:
|
|
|
|
- add Sphinx documentation
|
|
|
|
- zero-shot setup
|
2023-02-13 15:01:50 +01:00
|
|
|
- load pre-trained VGFs while retaining the ability to train new ones (e.g., set self.fitted = True on the loaded VGFs, or something similar)
|
2023-02-10 11:37:32 +01:00
|
|
|
- test split in MultiNews dataset
|
2023-02-13 15:01:50 +01:00
|
|
|
- when we load a model and change its config (e.g., change the aggregation function or re-train meta), we should save it as a new model
|
2023-02-08 14:51:56 +01:00
|
|
|
"""
|
2023-02-07 18:40:17 +01:00
|
|
|
|
|
|
|
|
2023-02-09 18:42:27 +01:00
|
|
|
def get_dataset(datasetname, args=None):
    """Build and return the dataset object selected by ``datasetname``.

    Args:
        datasetname: one of ``"multinews"``, ``"amazon"``, ``"rcv1-2"`` or
            ``"glami"``.
        args: parsed command-line namespace. It must provide ``nrows`` (used by
            amazon, rcv1-2 and glami) and ``domains``, ``min_count`` and
            ``max_labels`` (used by amazon). When omitted, the module-level
            ``args`` created under ``if __name__ == "__main__"`` is used, which
            is what earlier versions of this function relied on implicitly.

    Returns:
        The constructed dataset instance. For rcv1-2 it is optionally reduced
        to ``args.nrows`` documents per language for en/it/fr. For glami,
        ``build_dataset()`` has already been called.

    Raises:
        AssertionError: if ``datasetname`` is not supported.
        ValueError: if the chosen dataset needs ``args`` but none are available.
        NotImplementedError: unsupported name when asserts are disabled (``-O``).
    """
    assert datasetname in [
        "multinews",
        "amazon",
        "rcv1-2",
        "glami",
    ], "dataset not supported"

    if args is None:
        # Backward compatibility: fall back to the script-level ``args``.
        args = globals().get("args")
    if args is None and datasetname != "multinews":
        # Every dataset except multinews reads options from ``args``.
        raise ValueError(
            f"dataset '{datasetname}' needs parsed command-line args "
            "(nrows, and domains/min_count/max_labels for amazon)"
        )

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        dataset = MultiNewsDataset(
            MULTINEWS_DATAPATH,
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
        if args.nrows is not None:
            dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
    elif datasetname == "glami":
        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
        dataset.build_dataset()
    else:
        # Only reachable when asserts are stripped (python -O).
        raise NotImplementedError

    return dataset
|
2023-02-07 18:40:17 +01:00
|
|
|
|
2023-02-09 18:42:27 +01:00
|
|
|
|
|
|
|
def main(args):
    """Train (or load) a gFun model on the selected dataset and evaluate it.

    Steps:
    1. Build the dataset with ``get_dataset(args.dataset)``.
    2. Fit ``GeneralizedFunnelling`` on the training data.
    3. Save the model, if it was freshly trained and ``--nosave`` was not given.
    4. Evaluate on the test split, when the dataset provides one.

    Args:
        args: parsed command-line namespace built in the ``__main__`` section.
    """
    dataset = get_dataset(args.dataset)

    if isinstance(dataset, (MultilingualDataset, MultiNewsDataset, GlamiDataset)):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        # The other dataset type get_dataset can return (AmazonDataset) is
        # consumed through dX/dY and no test() split is used here. Mark the test
        # data as missing instead of leaving lX_te/lY_te unbound, which
        # previously crashed with NameError after training.
        lX, lY = dataset.dX, dataset.dY
        lX_te, lY_te = None, None

    tinit = time()

    if args.load_trained is None:
        # A freshly trained model needs at least one VGF enabled.
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.transformer,
            ]
        ), "At least one of VGF must be True"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        transformer=args.transformer,
        transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    gfun.fit(lX, lY)

    # Persist only freshly trained models, unless the user opted out.
    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    if lX_te is None:
        print("- no test split available: skipping evaluation")
        return

    test_eval = evaluate(lY_te, gfun.transform(lX_te))
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument(
        "-l",
        "--load_trained",
        type=str,
        default=None,
        help="previously trained model to load (passed as load_trained); "
        "when set, the model is not re-saved",
    )
    parser.add_argument(
        "--meta",
        action="store_true",
        help="also load the meta level of the trained model (load_meta)",
    )
    parser.add_argument(
        "--nosave", action="store_true", help="do not save the trained model"
    )
    # Dataset parameters -------------------
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default="multinews",
        # Kept in sync with the names accepted by get_dataset().
        choices=["multinews", "amazon", "rcv1-2", "glami"],
    )
    parser.add_argument(
        "--domains", type=str, default="all", help="amazon only: domains to load"
    )
    parser.add_argument(
        "--nrows",
        type=int,
        default=None,
        help="limit on the number of documents loaded (amazon, rcv1-2, glami)",
    )
    parser.add_argument(
        "--min_count", type=int, default=10, help="amazon only: min_count"
    )
    parser.add_argument(
        "--max_labels", type=int, default=50, help="amazon only: max_labels"
    )
    # gFUN parameters ----------------------
    parser.add_argument(
        "-p", "--posteriors", action="store_true", help="enable the posterior VGF"
    )
    parser.add_argument(
        "-m",
        "--multilingual",
        action="store_true",
        help="enable the multilingual VGF",
    )
    parser.add_argument("-w", "--wce", action="store_true", help="enable the WCE VGF")
    parser.add_argument(
        "-t", "--transformer", action="store_true", help="enable the transformer VGF"
    )
    parser.add_argument("--n_jobs", type=int, default=1)
    parser.add_argument("--optimc", action="store_true")
    # NOTE: store_false -> passing --features sets probabilistic=False.
    parser.add_argument(
        "--features",
        action="store_false",
        help="pass to disable probabilistic outputs (probabilistic=False)",
    )
    parser.add_argument("--aggfunc", type=str, default="mean")
    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)

    args = parser.parse_args()

    main(args)
|