gfun_multimodal/main.py

from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDataset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from dataManager.gFunDataset import gFunDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
- [!] add support for binary datasets (e.g. cls)
- [!] logging
- add documentation (Sphinx)
- [!] zero-shot setup
- FFNN dependent on the posterior probabilities
- re-init langs when loading VGFs?
- [!] the loss of the Attention-aggregator seems uncorrelated with Macro-F1 on the validation set!
- [!] experiment with the weight initialization of the Attention-aggregator
"""


def get_dataset(datasetname, args):
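    """Instantiate the dataset selected by `datasetname`, loading it from a
    hard-coded path under the user's home directory; `args` carries the
    dataset-specific options (nrows, domains, min_count, max_labels)."""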
    assert datasetname in [
        "multinews",
        "amazon",
        "rcv1-2",
        "glami",
        "cls",
    ], "dataset not supported"
    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
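    # NOTE: JRC_DATAPATH is currently unused (no "jrc" branch below)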
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    CLS_DATAPATH = expanduser("~/datasets/cls-acl10-processed/cls-acl10-processed.pkl")
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = MultiNewsDataset(
            expanduser(MULTINEWS_DATAPATH),
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        # TODO: convert to gFunDataset
        raise NotImplementedError
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = gFunDataset(
            dataset_dir=RCV_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=True,
            nrows=args.nrows,
        )
    elif datasetname == "glami":
        dataset = gFunDataset(
            dataset_dir=GLAMI_DATAPATH,
            is_textual=True,
            is_visual=True,
            is_multilabel=False,
            nrows=args.nrows,
        )
    elif datasetname == "cls":
        dataset = gFunDataset(
            dataset_dir=CLS_DATAPATH,
            is_textual=True,
            is_visual=False,
            is_multilabel=False,
            nrows=args.nrows,
        )
    else:
        raise NotImplementedError
    return dataset


def main(args):
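    """Train (or load) a GeneralizedFunnelling model on the selected dataset,
    then evaluate it on the held-out test set."""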
    dataset = get_dataset(args.dataset, args)

    if isinstance(
        dataset, (MultilingualDataset, MultiNewsDataset, GlamiDataset, gFunDataset)
    ):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        lX = dataset.dX
        lY = dataset.dY
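        # NB: lX_te / lY_te are not set on this path; the evaluation below
        # assumes one of the dataset classes handled above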

    tinit = time()
    if args.load_trained is None:
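        # VGFs (view-generating functions) are enabled via the CLI flags;
        # at least one is required when training a new model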
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.textual_transformer,
                args.visual_transformer,
            ]
        ), "At least one VGF must be enabled"

    gfun = GeneralizedFunnelling(
        # Dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Textual Transformer VGF params ------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # Visual Transformer VGF params -------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_transformer_name,
        # (batch_size, epochs, lr, patience, evaluate_step, and device are
        # shared with the textual transformer VGF above)
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )
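
    # fit() trains the enabled first-tier VGFs and the meta-level aggregator
    # (cf. gfun.save(save_first_tier=True, save_meta=True) below)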
    # gfun.get_config()
    gfun.fit(lX, lY)

    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    # print("- Computing evaluation on training set")
    # preds = gfun.transform(lX)
    # train_eval = evaluate(lY, preds)
    # log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds)
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
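    # NB: action="store_false", so passing --features *disables* the
    # probabilistic features (they default to True)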
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
    # Transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
    # Visual Transformer parameters --------
    parser.add_argument("--visual_transformer_name", type=str, default="vit")

    args = parser.parse_args()
    main(args)