import pickle
from argparse import ArgumentParser
from os.path import expanduser
from time import time

from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.glamiDataset import GlamiDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
    - add documentation (sphinx)
    - zero-shot setup
    - load pre-trained VGFs while retaining the ability to train new ones
      (e.g. set self.fitted = True on the loaded VGFs, or something similar)
    - test split in MultiNews dataset
    - when we load a model and change its config (e.g. change the aggregation
      function and re-train the meta-classifier), store it as a new model (save it)
"""


def get_dataset(datasetname, args):
    assert datasetname in [
        "multinews",
        "amazon",
        "rcv1-2",
        "glami",
    ], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    JRC_DATAPATH = expanduser(
        "~/datasets/jrc/jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle"
    )
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    if datasetname == "multinews":
        dataset = MultiNewsDataset(
            expanduser(MULTINEWS_DATAPATH),
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = MultilingualDataset(dataset_name="rcv1-2").load(RCV_DATAPATH)
        if args.nrows is not None:
            dataset.reduce_data(langs=["en", "it", "fr"], maxn=args.nrows)
    elif datasetname == "glami":
        dataset = GlamiDataset(dataset_dir=GLAMI_DATAPATH, nrows=args.nrows)
        dataset.build_dataset()
    else:
        raise NotImplementedError
    return dataset


def main(args):
    dataset = get_dataset(args.dataset, args)

    # These datasets expose language-indexed training/test splits;
    # the remaining ones expose the raw dX/dY dictionaries directly
    if isinstance(dataset, (MultilingualDataset, MultiNewsDataset, GlamiDataset)):
        lX, lY = dataset.training()
        lX_te, lY_te = dataset.test()
    else:
        lX = dataset.dX
        lY = dataset.dY

    tinit = time()

    if args.load_trained is None:
        # when training from scratch, at least one VGF has to be selected
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.transformer,
            ]
        ), "At least one VGF must be enabled"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        transformer=args.transformer,
        transformer_name=args.transformer_name,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device="cuda",
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    # gfun.get_config()
    gfun.fit(lX, lY)

    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    preds = gfun.transform(lX)
    # train_eval = evaluate(lY, preds)
    # log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    test_eval = evaluate(lY_te, gfun.transform(lX_te))
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="multinews")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=1)
    parser.add_argument("--optimc", action="store_true")
    parser.add_argument("--features", action="store_false")
    parser.add_argument("--aggfunc", type=str, default="mean")
    # transformer parameters ---------------
    parser.add_argument("--transformer_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)

    args = parser.parse_args()
    main(args)
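
# Example invocation (illustrative only: the script name "main.py" is assumed, and it
# presumes the rcv1-2 pickle exists at RCV_DATAPATH and that a CUDA device is available):
#
#   python main.py -d rcv1-2 -p -m -w --aggfunc mean --n_jobs 4
#
# This trains a GeneralizedFunnelling model with the posterior, multilingual (MUSE)
# and WCE VGFs on rcv1-2, then evaluates it on the held-out test set.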