From 7dead90271d3c90c21311a7f25eb90f20900345e Mon Sep 17 00:00:00 2001
From: andreapdr
Date: Tue, 7 Mar 2023 14:20:56 +0100
Subject: [PATCH] logging via wandb

---
 .gitignore                         |  3 +-
 gfun/generalizedFunnelling.py      | 22 +++++++++++----
 gfun/vgfs/commons.py               | 44 ++++++++++++++++++++++++++++--
 gfun/vgfs/textualTransformerGen.py |  2 ++
 gfun/vgfs/visualTransformerGen.py  |  6 +++-
 5 files changed, 67 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index 59aee03..6651515 100644
--- a/.gitignore
+++ b/.gitignore
@@ -181,4 +181,5 @@ models/*
 scripts/
 logger/*
 explore_data.ipynb
-run.sh
\ No newline at end of file
+run.sh
+wandb
\ No newline at end of file
diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 4d4f25d..88009b4 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -47,7 +47,7 @@ class GeneralizedFunnelling:
         self.posteriors_vgf = posterior
         self.wce_vgf = wce
         self.multilingual_vgf = multilingual
-        self.trasformer_vgf = textual_transformer
+        self.textual_trasformer_vgf = textual_transformer
         self.visual_transformer_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
@@ -142,7 +142,7 @@ class GeneralizedFunnelling:
             wce_vgf = WceGen(n_jobs=self.n_jobs)
             self.first_tier_learners.append(wce_vgf)

-        if self.trasformer_vgf:
+        if self.textual_trasformer_vgf:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name=self.textaul_transformer_name,
@@ -198,7 +198,8 @@ class GeneralizedFunnelling:
             self.posteriors_vgf,
             self.multilingual_vgf,
             self.wce_vgf,
-            self.trasformer_vgf,
+            self.textual_trasformer_vgf,
+            self.visual_transformer_vgf,
             self.aggfunc,
         )
         print(f"- model id: {self._model_id}")
@@ -372,7 +373,7 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
-        if self.trasformer_vgf:
+        if self.textual_trasformer_vgf:
             with open(
                 os.path.join(
                     "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
@@ -427,7 +428,15 @@ def get_params(optimc=False):
     return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


-def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
+def get_unique_id(
+    dataset_name,
+    posterior,
+    multilingual,
+    wce,
+    textual_transformer,
+    visual_transformer,
+    aggfunc,
+):
     from datetime import datetime

     now = datetime.now().strftime("%y%m%d")
@@ -435,6 +444,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
     model_id += "p" if posterior else ""
     model_id += "m" if multilingual else ""
     model_id += "w" if wce else ""
-    model_id += "t" if transformer else ""
+    model_id += "t" if textual_transformer else ""
+    model_id += "v" if visual_transformer else ""
     model_id += f"_{aggfunc}"
     return f"{model_id}_{now}"
diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py
index 92a481a..f46eef6 100644
--- a/gfun/vgfs/commons.py
+++ b/gfun/vgfs/commons.py
@@ -12,6 +12,7 @@ from torch.optim import AdamW
 from torch.utils.data import DataLoader, Dataset
 from transformers.modeling_outputs import ModelOutput

+import wandb
 from evaluation.evaluate import evaluate, log_eval

 PRINT_ON_EPOCH = 1
@@ -114,6 +115,7 @@ class Trainer:
         patience,
         experiment_name,
         checkpoint_path,
+        vgf_name,
     ):
         self.device = device
         self.model = model.to(device)
@@ -130,6 +132,7 @@ class Trainer:
             verbose=False,
             experiment_name=experiment_name,
         )
+        self.vgf_name = vgf_name

     def init_optimizer(self, optimizer_name, lr):
         if optimizer_name.lower() == "adamw":
@@ -138,6 +141,25 @@ class Trainer:
             raise ValueError(f"Optimizer {optimizer_name} not supported")

     def train(self, train_dataloader, eval_dataloader, epochs=10):
+        wandb.init(
+            project="gfun",
+            name="allhere",
+            # reinit=True,
+            config={
+                "vgf": self.vgf_name,
+                "architecture": self.model.name_or_path,
+                "learning_rate": self.optimizer.defaults["lr"],
+                "epochs": epochs,
+                "train batch size": train_dataloader.batch_size,
+                "eval batch size": eval_dataloader.batch_size,
+                "max len": train_dataloader.dataset.X.shape[-1],
+                "patience": self.earlystopping.patience,
+                "evaluate every": self.evaluate_steps,
+                "print eval every": self.print_eval,
+                "print train steps": self.print_steps,
+            },
+        )
+
         print(
             f"""- Training params for {self.experiment_name}:
             - epochs: {epochs}
@@ -150,11 +172,14 @@ class Trainer:
             - print eval every: {self.print_eval}
             - print train steps: {self.print_steps}\n"""
         )
+
         for epoch in range(epochs):
             self.train_epoch(train_dataloader, epoch)
             if (epoch + 1) % self.evaluate_steps == 0:
                 print_eval = (epoch + 1) % self.print_eval == 0
-                metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
+                metric_watcher = self.evaluate(
+                    eval_dataloader, epoch, print_eval=print_eval
+                )
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
                     print(
@@ -183,9 +208,16 @@ class Trainer:
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
+                    wandb.log(
+                        {
+                            f"{wandb.config['vgf']}_training_loss": loss,
+                            # "epoch": epoch,
+                            # f"{wandb.config['vgf']}_epoch": epoch,
+                        }
+                    )
         return self

-    def evaluate(self, dataloader, print_eval=True):
+    def evaluate(self, dataloader, epoch, print_eval=True):
         self.model.eval()

         lY = defaultdict(list)
@@ -210,6 +242,14 @@ class Trainer:

         l_eval = evaluate(lY, lY_hat)
         average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
+        wandb.log(
+            {
+                f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
+                f"{wandb.config['vgf']}_eval_loss": loss,
+                # "epoch": epoch,
+                # f"{wandb.config['vgf']}_epoch": epoch,
+            }
+        )
         return average_metrics[0]  # macro-F1


diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 8a525c6..1c91b93 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -130,6 +130,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         experiment_name = (
             f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
         )
+
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
@@ -141,6 +142,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             patience=self.patience,
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
+            vgf_name="textual_trf",
         )
         trainer.train(
             train_dataloader=tra_dataloader,
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index 7692c97..f8b6e6e 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -97,7 +97,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
+
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
@@ -109,6 +112,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=self.patience,
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
+            vgf_name="visual_trf",
         )
         trainer.train(