gfun_multimodal/main.py

from argparse import ArgumentParser
from os.path import expanduser
from time import time

from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- a cleaner way to save the model? each VGF saved independently (together with
standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
- add documentations sphinx
- zero-shot setup
"""


def main(args):
    # Loading dataset ------------------------
    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    # dataset = MultiNewsDataset(expanduser(args.dataset_path))
    # dataset = AmazonDataset(
    #     domains=args.domains,
    #     nrows=args.nrows,
    #     min_count=args.min_count,
    #     max_labels=args.max_labels,
    # )
    dataset = (
        MultilingualDataset(dataset_name="rcv1-2")
        .load(RCV_DATAPATH)
        .reduce_data(langs=["en", "it", "fr"], maxn=100)
    )
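    # reduce_data caps the corpus at 100 documents per language, which keeps
    # this script quick to run; drop the call to train on the full RCV1/RCV2
    # split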
    if isinstance(dataset, MultilingualDataset):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        # non-MultilingualDataset loaders expose their data via dX / dY
        # attributes; note that no test split is defined along this path
        lX, lY = dataset.dX, dataset.dY
    # ----------------------------------------

    tinit = time()

    if not args.load_trained:
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.transformer,
            ]
        ), "At least one VGF must be enabled"
    gfun = GeneralizedFunnelling(
        posterior=args.posteriors,
        multilingual=args.multilingual,
        wce=args.wce,
        transformer=args.transformer,
        langs=dataset.langs(),
        embed_dir="~/resources/muse_embeddings",
        n_jobs=args.n_jobs,
        max_length=args.max_length,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        transformer_name=args.transformer_name,
        device="cuda",
        optimc=args.optimc,
        load_trained=args.load_trained,
    )
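
    # fit() trains the enabled VGFs and the metaclassifier; when
    # --load_trained is set, the constructor is expected to have restored a
    # previously trained model instead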
    gfun.fit(lX, lY)

    # if not args.load_trained:
    #     gfun.save()

    preds = gfun.transform(lX)
    train_eval = evaluate(lY, preds)
    log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    test_eval = evaluate(lY_te, gfun.transform(lX_te))
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=10000)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=1)
    parser.add_argument("--optimc", action="store_true")
    # Transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)

    args = parser.parse_args()
    main(args)
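
# Example invocation (hypothetical settings; each of -p/-m/-w/-t enables one
# VGF):
#   python main.py -p -m -w --n_jobs 4 --optimc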