import os
from argparse import ArgumentParser
from time import time

import pandas as pd

from csvlogger import CsvLogger
from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

"""
TODO:
    - Docs:
        - add documentations sphinx
"""


def get_config_name(args):
    """Build a short run identifier from the enabled view-generating functions (VGFs).

    Each active VGF contributes a tag (P = posteriors, W = WCE, M = multilingual,
    TT_<transformer name> = textual transformer); tags are joined with '+'.
    """
    config_name = ""
    if args.posteriors:
        config_name += "P+"
    if args.wce:
        config_name += "W+"
    if args.multilingual:
        config_name += "M+"
    if args.textual_transformer:
        config_name += f"TT_{args.textual_trf_name}+"
    return config_name.rstrip("+")


def main(args):
    """Train (or load) a GeneralizedFunnelling model, evaluate on the test set,
    log metrics (optionally to wandb + a CSV log) and store per-document predictions.
    """
    dataset = get_dataset(args.datadir, args)
    lX, lY = dataset.training(merge_validation=True)
    lX_te, lY_te = dataset.test()
    tinit = time()

    if args.load_trained is None:
        # When training from scratch, at least one VGF must be enabled.
        # (Fixed: the original listed args.multilingual twice in this check.)
        assert any(
            [
                args.posteriors,
                args.wce,
                args.multilingual,
                args.textual_transformer,
            ]
        ), "At least one of VGF must be True"

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=dataset.name,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type=args.clf_type,
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params --------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.textual_trf_name,
        # trained_text_trf="hf_models/mbert-zeroshot-rai/checkpoint-1350",
        trained_text_trf="hf_models/mbert-fewshot-rai-full/checkpoint-5150",
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        textual_lr=args.textual_lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device=args.device,
        # General params ---------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    config = gfun.get_config()

    if args.wandb:
        import wandb

        wandb.init(
            project="gfun", name=f"gFun-{get_config_name(args)}", config=config
        )

    gfun.fit(lX, lY)

    # NOTE(review): saving used to be guarded by
    # `if args.load_trained is None and not args.nosave:` (see commented-out
    # condition in history) — confirm the unconditional save is intended.
    print("saving model")
    gfun.save(save_first_tier=True, save_meta=True)

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)
    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
    avg_metrics_gfun, lang_metrics_gfun = log_eval(
        test_eval, phase="test", clf_type=args.clf_type
    )

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")

    def log_barplot_wandb(gfun_res, title_affix="per language"):
        """Log either per-language bar plots (one per metric) or a single
        averaged-metrics bar plot to wandb, then log the raw values."""
        # Fixed: the default used to be the misspelled "per langauge", which
        # made the per-language branch unreachable under the default argument.
        if title_affix == "per language":
            for metric, lang_values in gfun_res.items():
                data = [[lang, v] for lang, v in lang_values.items()]
                table = wandb.Table(data=data, columns=["lang", f"{metric}"])
                wandb.log(
                    {
                        f"gFun/language {metric}": wandb.plot.bar(
                            table, "lang", metric, title=f"{metric} {title_affix}"
                        )
                    }
                )
        else:
            data = [[metric, value] for metric, value in gfun_res.items()]
            table = wandb.Table(data=data, columns=["metric", "value"])
            wandb.log(
                {
                    "gFun/average metric": wandb.plot.bar(
                        table, "metric", "value", title=f"metric {title_affix}"
                    )
                }
            )
        wandb.log(gfun_res)

    if args.wandb:
        log_barplot_wandb(lang_metrics_gfun, title_affix="per language")

    config["gFun"]["timing"] = f"{timeval - tinit:.2f}"
    CsvLogger(outfile="results/gfun.log.csv").log_lang_results(
        lang_metrics_gfun, config, notes=""
    )
    save_preds(
        gfun_preds,
        lY_te,
        config=config["gFun"]["simple_id"],
        dataset=config["gFun"]["dataset"],
    )


def save_preds(preds, targets, config="unk", dataset="unk"):
    """Store per-document argmax predictions (and targets, when available) as CSV.

    preds/targets: dicts mapping language -> 2-D score array (anything exposing
    `.argmax(axis=1)`); when `targets` is None the label column is filled with "na".
    Output goes to results/preds/preds.gfun.<config>.<dataset>.csv.
    """
    # exist_ok=True: without it the original crashed with FileExistsError on
    # any run after the first.
    os.makedirs("results/preds", exist_ok=True)
    df = pd.DataFrame()
    langs = sorted(preds.keys())
    _preds = []
    _targets = []
    _langs = []
    for lang in langs:
        _preds.extend(preds[lang].argmax(axis=1).tolist())
        if targets is None:
            _targets.extend(["na" for i in range(len(preds[lang]))])
        else:
            _targets.extend(targets[lang].argmax(axis=1).tolist())
        _langs.extend([lang for i in range(len(preds[lang]))])
    df["langs"] = _langs
    df["labels"] = _targets
    df["preds"] = _preds
    print(f"- storing predictions in 'results/preds/preds.gfun.{config}.{dataset}.csv'")
    df.to_csv(f"results/preds/preds.gfun.{config}.{dataset}.csv", index=False)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--tr_langs", nargs="+", default=None)
    parser.add_argument("--te_langs", nargs="+", default=None)
    # Dataset parameters -------------------
    parser.add_argument(
        "-d",
        "--datadir",
        type=str,
        default=None,
        help="dir to dataset. It should contain both a train.csv and a test.csv file",
    )
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    parser.add_argument("--clf_type", type=str, default="multilabel")
    parser.add_argument("--save_dataset", action="store_true")
    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    # NOTE(review): store_false means passing --features *disables*
    # probabilistic features (default is True) — confirm the name is intended.
    parser.add_argument("--features", action="store_false")
    parser.add_argument("--aggfunc", type=str, default="mean")
    # transformer parameters ---------------
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--textual_trf_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=128)
    parser.add_argument("--textual_lr", type=float, default=1e-4)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
    parser.add_argument(
        "--reduced", action="store_true", help="run on reduced set of documents"
    )
    # logging
    parser.add_argument("--wandb", action="store_true")
    args = parser.parse_args()
    main(args)