From 56faaf26152581790ffdc6f74b644d71ee495e3e Mon Sep 17 00:00:00 2001 From: andreapdr Date: Wed, 15 Mar 2023 16:35:49 +0100 Subject: [PATCH] changed wandb logging to a global level to keep track of all the VGFs and overall gFun --- gfun/vgfs/commons.py | 40 ++++++++++------------ main.py | 80 ++++++++++++++++++++++++++++++++++---------- 2 files changed, 80 insertions(+), 40 deletions(-) diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index 1787c0b..617b2fd 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -30,18 +30,18 @@ def verbosity_eval(epoch, print_eval): return False -def format_langkey_wandb(lang_dict): +def format_langkey_wandb(lang_dict, vgf_name): log_dict = {} for metric, l_dict in lang_dict.items(): for lang, value in l_dict.items(): - log_dict[f"language metric/{metric}/{lang}"] = value + log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value return log_dict -def format_average_wandb(avg_dict): +def format_average_wandb(avg_dict, vgf_name): log_dict = {} for metric, value in avg_dict.items(): - log_dict[f"average metric/{metric}"] = value + log_dict[f"{vgf_name}/average metric/{metric}"] = value return log_dict @@ -213,14 +213,6 @@ class Trainer: for k, v in _config.items(): print(f"\t{k}: {v}") - wandb_logger = wandb.init( - project="gfun", - entity="andreapdr", - name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}", - config=_config, - reinit=True, - ) - for epoch in range(epochs): train_loss = self.train_epoch(train_dataloader, epoch) @@ -233,11 +225,11 @@ class Trainer: n_jobs=self.n_jobs, ) - wandb_logger.log( + wandb.log( { - "loss/val": eval_loss, - **format_langkey_wandb(lang_metrics), - **format_average_wandb(avg_metrics), + f"{self.vgf_name}/loss/val": eval_loss, + **format_langkey_wandb(lang_metrics, self.vgf_name), + **format_average_wandb(avg_metrics, self.vgf_name), }, commit=False, ) @@ -260,10 +252,12 @@ class Trainer: if self.scheduler is not None: self.scheduler.step(avg_metrics[self.monitored_metric]) - wandb_logger.log( + wandb.log( { - "loss/train": train_loss, - "learning rate": self.optimizer.param_groups[0]["lr"], + f"{self.vgf_name}/loss/train": train_loss, + f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][ + "lr" + ], } ) @@ -274,7 +268,7 @@ class Trainer: def train_epoch(self, dataloader, epoch): self.model.train() - epoch_losses = [] + batch_losses = [] for b_idx, (x, y, lang) in enumerate(dataloader): self.optimizer.zero_grad() y_hat = self.model(x.to(self.device)) @@ -284,13 +278,13 @@ class Trainer: loss = self.loss_fn(y_hat, y.to(self.device)) loss.backward() self.optimizer.step() - epoch_losses.append(loss.item()) + batch_losses.append(loss.item()) # TODO: is this still on gpu? if (epoch + 1) % PRINT_ON_EPOCH == 0: if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0: print( - f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}" + f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}" ) - return np.mean(epoch_losses) + return np.mean(batch_losses) def evaluate(self, dataloader, print_eval=True, n_jobs=-1): self.model.eval() diff --git a/main.py b/main.py index b968327..f272490 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import os +import wandb os.environ["CUDA_VISIBLE_DEVICES"] = "0" @@ -11,19 +12,38 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling """ TODO: - - [!] add support for mT5 - - [!] log on wandb also the other VGF results + final results - - [!] CLS dataset is loading only "books" domain data - - [!] documents should be trimmed to the same length (?) - - [!] overall gfun results logger - - add documentations sphinx - - [!] zero-shot setup - - FFNN posterior-probabilities' dependent - - re-init langs when loading VGFs? - - [!] experiment with weight init of Attention-aggregator + - Transformers VGFs: + - save/load for MT5ForSqeuenceClassification + - freeze params method + - log on step rather than epoch? + - General: + [!] zero-shot setup + - CLS dataset is loading only "books" domain data + - log on wandb also the other VGF results + final results + - documents should be trimmed to the same length (for SVMs we are using way too long tokens) + - Attention Aggregator: + - experiment with weight init of Attention-aggregator + - FFNN posterior-probabilities' dependent + - Docs: + - add documentations sphinx """ +def get_config_name(args): + config_name = "" + if args.posteriors: + config_name += "P+" + if args.wce: + config_name += "W+" + if args.multilingual: + config_name += "M+" + if args.textual_transformer: + config_name += f"TT_{args.textual_trf_name}+" + if args.visual_transformer: + config_name += f"VT_{args.visual_trf_name}+" + return config_name.rstrip("+") + + def main(args): dataset = get_dataset(args.dataset, args) lX, lY = dataset.training() @@ -86,27 +106,53 @@ def main(args): n_jobs=args.n_jobs, ) - # gfun.get_config() + wandb.init( + project="gfun", name=f"gFun-{get_config_name(args)}" + ) # TODO: Add config to log gfun.fit(lX, lY) if args.load_trained is None and not args.nosave: gfun.save(save_first_tier=True, save_meta=True) - # print("- Computing evaluation on training set") - # preds = gfun.transform(lX) - # train_eval = evaluate(lY, preds) - # log_eval(train_eval, phase="train") - timetr = time() print(f"- training completed in {timetr - tinit:.2f} seconds") gfun_preds = gfun.transform(lX_te) test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs) - log_eval(test_eval, phase="test", clf_type=args.clf_type) + avg_metrics_gfun, lang_metrics_gfun = log_eval( + test_eval, phase="test", clf_type=args.clf_type + ) timeval = time() print(f"- testing completed in {timeval - timetr:.2f} seconds") + def log_barplot_wandb(gfun_res, title_affix="per langauge"): + if title_affix == "per language": + for metric, lang_values in gfun_res.items(): + data = [[lang, v] for lang, v in lang_values.items()] + table = wandb.Table(data=data, columns=["lang", f"{metric}"]) + wandb.log( + { + f"gFun/language {metric}": wandb.plot.bar( + table, "lang", metric, title=f"{metric} {title_affix}" + ) + } + ) + else: + data = [[metric, value] for metric, value in gfun_res.items()] + table = wandb.Table(data=data, columns=["metric", "value"]) + wandb.log( + { + f"gFun/average metric": wandb.plot.bar( + table, "metric", "value", title=f"metric {title_affix}" + ) + } + ) + wandb.log(gfun_res) + + log_barplot_wandb(lang_metrics_gfun, title_affix="per language") + log_barplot_wandb(avg_metrics_gfun, title_affix="averages") + if __name__ == "__main__": parser = ArgumentParser()