From 65407f51fa9d5c37d29a7c1cbfc80df137fd3846 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Wed, 15 Mar 2023 11:47:17 +0100 Subject: [PATCH] update trainer to handle mT5 --- gfun/vgfs/commons.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gfun/vgfs/commons.py b/gfun/vgfs/commons.py index e6527cd..1787c0b 100644 --- a/gfun/vgfs/commons.py +++ b/gfun/vgfs/commons.py @@ -188,7 +188,9 @@ class Trainer: def get_config(self, train_dataloader, eval_dataloader, epochs): return { - "model name": self.model.name_or_path, + "model name": self.model.name_or_path + if not hasattr(self.model, "mt5encoder") + else self.model.mt5encoder.name_or_path, "epochs": epochs, "learning rate": self.optimizer.defaults["lr"], "scheduler": self.scheduler_name, # TODO: add scheduler params @@ -212,7 +214,11 @@ class Trainer: print(f"\t{k}: {v}") wandb_logger = wandb.init( - project="gfun", entity="andreapdr", config=_config, reinit=True + project="gfun", + entity="andreapdr", + name=f"{_config['model name']} lr: {_config['learning rate']} scheduler: {_config['scheduler']}", + config=_config, + reinit=True, ) for epoch in range(epochs):