from argparse import ArgumentParser
from os.path import expanduser
from time import time

from dataManager.amazonDataset import AmazonDataset
from dataManager.multilingualDatset import MultilingualDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
    - add Sphinx documentation
    - zero-shot setup
    - set probabilistic behaviour in Transformer parent class
    - pooling / attention aggregation
    - test split in MultiNews dataset
"""


def get_dataset(datasetname, args):
    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"

    RCV_DATAPATH = expanduser(
        "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
    )
    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")

    if datasetname == "multinews":
        dataset = MultiNewsDataset(
            MULTINEWS_DATAPATH,
            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
        )
    elif datasetname == "amazon":
        dataset = AmazonDataset(
            domains=args.domains,
            nrows=args.nrows,
            min_count=args.min_count,
            max_labels=args.max_labels,
        )
    elif datasetname == "rcv1-2":
        dataset = (
            MultilingualDataset(dataset_name="rcv1-2")
            .load(RCV_DATAPATH)
            .reduce_data(langs=["en", "it", "fr"], maxn=500)
        )
    else:
        raise NotImplementedError
    return dataset


def main(args):
    dataset = get_dataset(args.dataset, args)

    if isinstance(dataset, (MultilingualDataset, MultiNewsDataset)):
        lX, lY = dataset.training()
        # lX_te, lY_te = dataset.test()
        print("[NB: for debug purposes, training set is also used as test set]\n")
        lX_te, lY_te = dataset.training()
    else:
        # AmazonDataset exposes its data as per-language dictionaries;
        # as above, the training data is also reused as test data
        lX, lY = dataset.dX, dataset.dY
        lX_te, lY_te = dataset.dX, dataset.dY

    tinit = time()

    if args.load_trained is None:
        assert any(
            [
                args.posteriors,
                args.multilingual,
                args.wce,
                args.transformer,
            ]
        ), "At least one VGF must be enabled"

    gfun = GeneralizedFunnelling(
        dataset_name=args.dataset,
        posterior=args.posteriors,
        multilingual=args.multilingual,
        wce=args.wce,
        transformer=args.transformer,
        langs=dataset.langs(),
        embed_dir="~/resources/muse_embeddings",
        n_jobs=args.n_jobs,
        max_length=args.max_length,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        transformer_name=args.transformer_name,
        device="cuda",
        optimc=args.optimc,
        load_trained=args.load_trained,
    )

    # gfun.get_config()
    gfun.fit(lX, lY)

    # save the model only when it was trained from scratch,
    # not when it was restored via --load_trained
    if args.load_trained is None:
        gfun.save()

    preds = gfun.transform(lX)
    train_eval = evaluate(lY, preds)
    log_eval(train_eval, phase="train")

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    test_eval = evaluate(lY_te, gfun.transform(lX_te))
    log_eval(test_eval, phase="test")

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="multinews")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=10000)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
parser.add_argument("-t", "--transformer", action="store_true") parser.add_argument("--n_jobs", type=int, default=1) parser.add_argument("--optimc", action="store_true") # transformer parameters --------------- parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--epochs", type=int, default=10) parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--max_length", type=int, default=512) parser.add_argument("--patience", type=int, default=5) parser.add_argument("--evaluate_step", type=int, default=10) args = parser.parse_args() main(args)