# gfun_multimodal/main.py


import pickle
from argparse import ArgumentParser
from os.path import expanduser
from time import time
from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- a cleaner way to save the model? Save each VGF independently (together with
its standardizer and feature2posteriors). What about the metaclassifier and the vectorizers?
- add Sphinx documentation
- zero-shot setup
"""


def main(args):
# Loading dataset ------------------------
RCV_DATAPATH = expanduser(
"~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
)
    # dataset = MultiNewsDataset(expanduser(args.dataset_path))
    # dataset = AmazonDataset(domains=args.domains, nrows=args.nrows, min_count=args.min_count, max_labels=args.max_labels)
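    # Quick-experiment setup: restrict the dataset to three languages and cap
    # each one at maxn=100 documents via reduce_data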
dataset = (
MultilingualDataset(dataset_name="rcv1-2")
.load(RCV_DATAPATH)
.reduce_data(langs=["en", "it", "fr"], maxn=100)
)
if isinstance(dataset, MultilingualDataset):
lX, lY = dataset.training()
lX_te, lY_te = dataset.test()
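        # By the MultilingualDataset convention, lX and lY appear to be dicts
        # keyed by language code (e.g. "en") holding documents and labels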
    else:
        # MultiNewsDataset / AmazonDataset expose their data directly as
        # attributes; note that no test split is set on this path
        lX = dataset.dX
        lY = dataset.dY
# ----------------------------------------
tinit = time()
if not args.load_trained:
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.transformer,
            ]
        ), "At least one VGF must be enabled"
gfun = GeneralizedFunnelling(
posterior=args.posteriors,
multilingual=args.multilingual,
wce=args.wce,
transformer=args.transformer,
langs=dataset.langs(),
embed_dir="~/resources/muse_embeddings",
n_jobs=args.n_jobs,
max_length=args.max_length,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
patience=args.patience,
evaluate_step=args.evaluate_step,
transformer_name=args.transformer_name,
device="cuda",
optimc=args.optimc,
load_trained=args.load_trained,
)
gfun.fit(lX, lY)
    # if not args.load_trained:
    #     gfun.save()
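    # Stopgap for the saving TODO at the top of this file: the whole fitted
    # object can be serialized with the (currently unused) pickle import.
    # The path below is illustrative:
    # with open("gfun_model.pkl", "wb") as f:
    #     pickle.dump(gfun, f)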
preds = gfun.transform(lX)
train_eval = evaluate(lY, preds)
log_eval(train_eval, phase="train")
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
test_eval = evaluate(lY_te, gfun.transform(lX_te))
log_eval(test_eval, phase="test")
timeval = time()
print(f"- testing completed in {timeval - timetr:.2f} seconds")


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("-l", "--load_trained", action="store_true")
# Dataset parameters -------------------
parser.add_argument("--domains", type=str, default="all")
parser.add_argument("--nrows", type=int, default=10000)
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
# gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true")
parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--transformer", action="store_true")
parser.add_argument("--n_jobs", type=int, default=1)
parser.add_argument("--optimc", action="store_true")
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
args = parser.parse_args()
main(args)
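
# Example invocation (flags are defined above; values are illustrative):
#   python main.py --posteriors --wce --multilingual --n_jobs 4 --optimc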