diff --git a/main.py b/main.py index f709bba..e49027b 100644 --- a/main.py +++ b/main.py @@ -150,7 +150,6 @@ def main(args): wandb.log(gfun_res) log_barplot_wandb(lang_metrics_gfun, title_affix="per language") - log_barplot_wandb(avg_metrics_gfun, title_affix="averages") if __name__ == "__main__": @@ -178,7 +177,7 @@ if __name__ == "__main__": parser.add_argument("--features", action="store_false") parser.add_argument("--aggfunc", type=str, default="mean") # transformer parameters --------------- - parser.add_argument("--epochs", type=int, default=100) + parser.add_argument("--epochs", type=int, default=5) parser.add_argument("--textual_trf_name", type=str, default="mbert") parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--eval_batch_size", type=int, default=128)