From 84dd1f093e6e685b2c43776fc3efd49666f0e1c6 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Tue, 7 Mar 2023 17:34:25 +0100 Subject: [PATCH] logging via wandb --- gfun/generalizedFunnelling.py | 13 +++-- gfun/vgfs/commons.py | 93 ++++++++++++++++------------------- main.py | 2 + 3 files changed, 52 insertions(+), 56 deletions(-) diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py index 5044583..3899107 100644 --- a/gfun/generalizedFunnelling.py +++ b/gfun/generalizedFunnelling.py @@ -28,6 +28,7 @@ class GeneralizedFunnelling: embed_dir, n_jobs, batch_size, + eval_batch_size, max_length, lr, epochs, @@ -59,7 +60,8 @@ class GeneralizedFunnelling: self.textual_trf_name = textual_transformer_name self.epochs = epochs self.lr_transformer = lr - self.batch_size_transformer = batch_size + self.batch_size_trf = batch_size + self.eval_batch_size_trf = eval_batch_size self.max_length = max_length self.early_stopping = True self.patience = patience @@ -148,7 +150,8 @@ class GeneralizedFunnelling: model_name=self.textual_trf_name, lr=self.lr_transformer, epochs=self.epochs, - batch_size=self.batch_size_transformer, + batch_size=self.batch_size_trf, + batch_size_eval=self.eval_batch_size_trf, max_length=self.max_length, print_steps=50, probabilistic=self.probabilistic, @@ -163,10 +166,10 @@ class GeneralizedFunnelling: visual_trasformer_vgf = VisualTransformerGen( dataset_name=self.dataset_name, model_name="vit", - lr=1e-5, # self.lr_visual_transformer, + lr=self.lr_transformer, epochs=self.epochs, - batch_size=32, # self.batch_size_visual_transformer, - # batch_size_eval=128, + batch_size=self.batch_size_trf, + batch_size_eval=self.eval_batch_size_trf, probabilistic=self.probabilistic, evaluate_step=self.evaluate_step, patience=self.patience, diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index f46eef6..effbf9d 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -140,46 +140,50 @@ class Trainer: else: raise ValueError(f"Optimizer {optimizer_name} not supported") - def train(self, train_dataloader, eval_dataloader, epochs=10): - wandb.init( - project="gfun", - name="allhere", - # reinit=True, - config={ - "vgf": self.vgf_name, - "architecture": self.model.name_or_path, - "learning_rate": self.optimizer.defaults["lr"], - "epochs": epochs, - "train batch size": train_dataloader.batch_size, - "eval batch size": eval_dataloader.batch_size, - "max len": train_dataloader.dataset.X.shape[-1], - "patience": self.earlystopping.patience, - "evaluate every": self.evaluate_steps, - "print eval every": self.print_eval, - "print train steps": self.print_steps, - }, - ) + def get_config(self, train_dataloader, eval_dataloader, epochs): + return { + "model name": self.model.name_or_path, + "epochs": epochs, + "learning rate": self.optimizer.defaults["lr"], + "train batch size": train_dataloader.batch_size, + "eval batch size": eval_dataloader.batch_size, + "max len": train_dataloader.dataset.X.shape[-1], + "patience": self.earlystopping.patience, + "evaluate every": self.evaluate_steps, + "print eval every": self.print_eval, + "print train steps": self.print_steps, + } - print( - f"""- Training params for {self.experiment_name}: - - epochs: {epochs} - - learning rate: {self.optimizer.defaults['lr']} - - train batch size: {train_dataloader.batch_size} - - eval batch size: {eval_dataloader.batch_size} - - max len: {train_dataloader.dataset.X.shape[-1]} - - patience: {self.earlystopping.patience} - - evaluate every: {self.evaluate_steps} - - print eval every: {self.print_eval} 
- - print train steps: {self.print_steps}\n""" + def train(self, train_dataloader, eval_dataloader, epochs=10): + _config = self.get_config(train_dataloader, eval_dataloader, epochs) + + print(f"- Training params for {self.experiment_name}:") + for k, v in _config.items(): + print(f"\t{k}: {v}") + + wandb_logger = wandb.init( + project="gfun", entity="andreapdr", config=_config, reinit=True ) for epoch in range(epochs): - self.train_epoch(train_dataloader, epoch) + train_loss = self.train_epoch(train_dataloader, epoch) + + wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss}) + if (epoch + 1) % self.evaluate_steps == 0: print_eval = (epoch + 1) % self.print_eval == 0 - metric_watcher = self.evaluate( - eval_dataloader, epoch, print_eval=print_eval + with torch.no_grad(): + eval_loss, metric_watcher = self.evaluate( + eval_dataloader, epoch, print_eval=print_eval + ) + + wandb_logger.log( + { + f"{self.vgf_name}_eval_loss": eval_loss, + f"{self.vgf_name}_eval_metric": metric_watcher, + } ) + stop = self.earlystopping(metric_watcher, self.model, epoch + 1) if stop: print( @@ -189,8 +193,9 @@ class Trainer: self.device ) break + print(f"- last swipe on eval set") - self.train_epoch(eval_dataloader, epoch=0) + self.train_epoch(eval_dataloader, epoch=-1) self.earlystopping.save_model(self.model) return self.model @@ -208,14 +213,7 @@ class Trainer: if (epoch + 1) % PRINT_ON_EPOCH == 0: if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0: print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}") - wandb.log( - { - f"{wandb.config['vgf']}_training_loss": loss, - # "epoch": epoch, - # f"{wandb.config['vgf']}_epoch": epoch, - } - ) - return self + return loss.item() def evaluate(self, dataloader, epoch, print_eval=True): self.model.eval() @@ -242,15 +240,8 @@ class Trainer: l_eval = evaluate(lY, lY_hat) average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval) - wandb.log( - { - f"{wandb.config['vgf']}_eval_metric": average_metrics[0], - f"{wandb.config['vgf']}_eval_loss": loss, - # "epoch": epoch, - # f"{wandb.config['vgf']}_epoch": epoch, - } - ) - return average_metrics[0] # macro-F1 + + return loss.item(), average_metrics[0] # macro-F1 class EarlyStopping: diff --git a/main.py b/main.py index 9b4ead6..e312c3d 100644 --- a/main.py +++ b/main.py @@ -54,6 +54,7 @@ def main(args): textual_transformer=args.textual_transformer, textual_transformer_name=args.transformer_name, batch_size=args.batch_size, + eval_batch_size=args.eval_batch_size, epochs=args.epochs, lr=args.lr, max_length=args.max_length, @@ -125,6 +126,7 @@ if __name__ == "__main__": # transformer parameters --------------- parser.add_argument("--transformer_name", type=str, default="mbert") parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--eval_batch_size", type=int, default=128) parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--max_length", type=int, default=128)
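
Note (not part of the commit): below is a minimal standalone sketch of the per-VGF wandb logging pattern this patch introduces in Trainer.train (gfun/vgfs/commons.py): one run per view-generating function via wandb.init(..., reinit=True), a train-loss log every epoch, and eval loss/metric logs on evaluation steps. The train_with_wandb function and the dummy_* helpers are illustrative stand-ins rather than repository code, and mode="offline" is used here only so the sketch runs without a wandb account; the patch itself logs online to the "gfun" project under the author's entity.

    # Sketch of the wandb logging flow adopted in Trainer.train (illustrative names).
    import wandb

    def train_with_wandb(vgf_name, config, epochs=10, evaluate_steps=1):
        # One run per VGF; reinit=True lets several VGFs log within one process,
        # mirroring how each view-generating function opens its own run.
        run = wandb.init(project="gfun", config=config, reinit=True, mode="offline")

        for epoch in range(epochs):
            train_loss = dummy_train_epoch(epoch)            # stand-in for self.train_epoch(...)
            run.log({f"{vgf_name}_train_loss": train_loss})

            if (epoch + 1) % evaluate_steps == 0:
                eval_loss, eval_metric = dummy_evaluate(epoch)  # stand-in for self.evaluate(...)
                run.log({
                    f"{vgf_name}_eval_loss": eval_loss,
                    f"{vgf_name}_eval_metric": eval_metric,     # e.g. macro-F1 watched by early stopping
                })

        run.finish()

    def dummy_train_epoch(epoch):
        # Placeholder loss that decreases over epochs.
        return 1.0 / (epoch + 1)

    def dummy_evaluate(epoch):
        # Placeholder (eval_loss, eval_metric) pair.
        return 0.5 / (epoch + 1), min(0.5 + 0.05 * epoch, 1.0)

    if __name__ == "__main__":
        train_with_wandb("textual_trf", config={"epochs": 10, "learning rate": 1e-5})

With the accompanying main.py change, evaluation batches can be sized independently of training batches, e.g. python main.py --batch_size 32 --eval_batch_size 128 together with the remaining dataset/VGF flags, which this patch does not touch.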