import os

import wandb

# Pin the run to the first GPU. This is set before the project imports below,
# which presumably load the deep-learning framework (TODO confirm); CUDA only
# reads this variable when it is first initialised, so keep it above them.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from argparse import ArgumentParser
from time import time

from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling

# TODO:
# - Transformers VGFs:
#     - save/load for MT5ForSequenceClassification
#     - freeze params method
#     - log on step rather than epoch?
# - General:
#     [!] zero-shot setup
#     - CLS dataset is loading only "books" domain data
#     - log on wandb also the other VGF results + final results
#     - documents should be trimmed to the same length (for SVMs we are
#       using way too long tokens)
# - Attention Aggregator:
#     - experiment with weight init of Attention-aggregator
#     - FFNN posterior-probabilities' dependent
# - Docs:
#     - add Sphinx documentation


def get_config_name(args):
    """Build a short run name encoding which VGFs are enabled.

    Each enabled view-generating function contributes one tag, and the tags
    are joined with "+": "P" (posteriors), "W" (WCE), "M" (multilingual),
    "TT_<name>" (textual transformer) and "VT_<name>" (visual transformer).

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Returns:
        The joined tags, e.g. ``"P+M+TT_mbert"``; an empty string when no VGF
        is enabled.
    """
    parts = []
    if args.posteriors:
        parts.append("P")
    if args.wce:
        parts.append("W")
    if args.multilingual:
        parts.append("M")
    if args.textual_transformer:
        parts.append(f"TT_{args.textual_trf_name}")
    if args.visual_transformer:
        parts.append(f"VT_{args.visual_trf_name}")
    return "+".join(parts)


def main(args):
    """Train (or load) a gFun model, evaluate it on the test split, log to W&B.

    Steps:
        1. Load the dataset named by ``args.dataset``.
        2. Build ``GeneralizedFunnelling`` from the CLI options.
        3. Start a W&B run and fit the model.
        4. Save the model when it was freshly trained and ``--nosave`` is off.
        5. Evaluate on the test split and log bar charts to W&B.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Raises:
        ValueError: if no trained model is loaded (``args.load_trained`` is
            None) and no view-generating function is enabled.
    """
    dataset = get_dataset(args.dataset, args)
    lX, lY = dataset.training()
    lX_te, lY_te = dataset.test()

    tinit = time()

    # Only checked when training from scratch; with --load_trained the check
    # is skipped.
    if args.load_trained is None and not any(
        [
            args.posteriors,
            args.wce,
            args.multilingual,
            args.textual_transformer,
            args.visual_transformer,
        ]
    ):
        # Raise rather than assert: asserts are stripped under `python -O`.
        raise ValueError("At least one of VGF must be True")

    gfun = GeneralizedFunnelling(
        # dataset params ----------------------
        dataset_name=args.dataset,
        langs=dataset.langs(),
        num_labels=dataset.num_labels(),
        classification_type=args.clf_type,
        # Posterior VGF params ----------------
        posterior=args.posteriors,
        # Multilingual VGF params -------------
        multilingual=args.multilingual,
        embed_dir="~/resources/muse_embeddings",
        # WCE VGF params ----------------------
        wce=args.wce,
        # Transformer VGF params (training options are shared with the
        # visual transformer) -----------------
        textual_transformer=args.textual_transformer,
        textual_transformer_name=args.textual_trf_name,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        epochs=args.epochs,
        textual_lr=args.textual_lr,
        visual_lr=args.visual_lr,
        max_length=args.max_length,
        patience=args.patience,
        evaluate_step=args.evaluate_step,
        device=args.device,
        # Visual Transformer VGF params -------
        visual_transformer=args.visual_transformer,
        visual_transformer_name=args.visual_trf_name,
        # General params ----------------------
        probabilistic=args.features,
        aggfunc=args.aggfunc,
        optimc=args.optimc,
        load_trained=args.load_trained,
        load_meta=args.meta,
        n_jobs=args.n_jobs,
    )

    # Record every CLI option as the run's config so runs are comparable.
    wandb.init(
        project="gfun",
        name=f"gFun-{get_config_name(args)}",
        config=vars(args),
    )

    gfun.fit(lX, lY)

    # Only persist freshly trained models; a loaded model is already on disk.
    if args.load_trained is None and not args.nosave:
        gfun.save(save_first_tier=True, save_meta=True)

    timetr = time()
    print(f"- training completed in {timetr - tinit:.2f} seconds")

    gfun_preds = gfun.transform(lX_te)

    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)

    avg_metrics_gfun, lang_metrics_gfun = log_eval(
        test_eval, phase="test", clf_type=args.clf_type
    )

    timeval = time()
    print(f"- testing completed in {timeval - timetr:.2f} seconds")

    log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
    log_barplot_wandb(avg_metrics_gfun, title_affix="averages")


def log_barplot_wandb(gfun_res, title_affix="per language"):
    """Log gFun evaluation results to the active W&B run as bar charts.

    Args:
        gfun_res: in per-language mode, a mapping ``metric -> {lang: value}``
            (one chart per metric); otherwise a flat mapping
            ``metric -> value`` (one chart for all metrics, and the raw dict is
            also logged as scalars).
        title_affix: suffix appended to the chart titles. The exact value
            ``"per language"`` selects per-language mode. The default used to
            be misspelled, so it never selected that mode.
    """
    if title_affix == "per language":
        for metric, lang_values in gfun_res.items():
            column = f"{metric}"
            data = [[lang, v] for lang, v in lang_values.items()]
            table = wandb.Table(data=data, columns=["lang", column])
            wandb.log(
                {
                    f"gFun/language {metric}": wandb.plot.bar(
                        table, "lang", column, title=f"{metric} {title_affix}"
                    )
                }
            )
    else:
        data = [[metric, value] for metric, value in gfun_res.items()]
        table = wandb.Table(data=data, columns=["metric", "value"])
        wandb.log(
            {
                "gFun/average metric": wandb.plot.bar(
                    table, "metric", "value", title=f"metric {title_affix}"
                )
            }
        )
        wandb.log(gfun_res)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-l", "--load_trained", type=str, default=None)
    parser.add_argument("--meta", action="store_true")
    parser.add_argument("--nosave", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")

    # Dataset parameters -------------------
    parser.add_argument("-d", "--dataset", type=str, default="rcv1-2")
    parser.add_argument("--domains", type=str, default="all")
    parser.add_argument("--nrows", type=int, default=None)
    parser.add_argument("--min_count", type=int, default=10)
    parser.add_argument("--max_labels", type=int, default=50)
    parser.add_argument("--clf_type", type=str, default="multilabel")
    parser.add_argument("--save_dataset", action="store_true")

    # gFUN parameters ----------------------
    parser.add_argument("-p", "--posteriors", action="store_true")
    parser.add_argument("-m", "--multilingual", action="store_true")
    parser.add_argument("-w", "--wce", action="store_true")
    parser.add_argument("-t", "--textual_transformer", action="store_true")
    parser.add_argument("-v", "--visual_transformer", action="store_true")
    parser.add_argument("--n_jobs", type=int, default=-1)
    parser.add_argument("--optimc", action="store_true")
    # store_false: args.features (passed as `probabilistic`) is True unless
    # --features is given on the command line.
    parser.add_argument("--features", action="store_false")
    parser.add_argument("--aggfunc", type=str, default="mean")

    # transformer parameters ---------------
    parser.add_argument("--textual_trf_name", type=str, default="mbert")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--eval_batch_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--textual_lr", type=float, default=1e-5)
    parser.add_argument("--visual_lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)

    # Visual Transformer parameters --------
    parser.add_argument("--visual_trf_name", type=str, default="vit")

    args = parser.parse_args()

    main(args)