logging via wandb

Andrea Pedrotti 2023-03-07 17:34:25 +01:00
parent 6b7917ca47
commit 84dd1f093e
3 changed files with 52 additions and 56 deletions

@@ -28,6 +28,7 @@ class GeneralizedFunnelling:
         embed_dir,
         n_jobs,
         batch_size,
+        eval_batch_size,
         max_length,
         lr,
         epochs,
@@ -59,7 +60,8 @@ class GeneralizedFunnelling:
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
         self.lr_transformer = lr
-        self.batch_size_transformer = batch_size
+        self.batch_size_trf = batch_size
+        self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
         self.early_stopping = True
         self.patience = patience
@@ -148,7 +150,8 @@ class GeneralizedFunnelling:
                 model_name=self.textual_trf_name,
                 lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=self.batch_size_transformer,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 max_length=self.max_length,
                 print_steps=50,
                 probabilistic=self.probabilistic,
@@ -163,10 +166,10 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=1e-5,  # self.lr_visual_transformer,
+                lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=32,  # self.batch_size_visual_transformer,
-                # batch_size_eval=128,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 probabilistic=self.probabilistic,
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
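Net effect of this file: `eval_batch_size` is plumbed through the `GeneralizedFunnelling` constructor, and the visual transformer VGF now reuses the shared learning rate and (train/eval) batch sizes instead of hard-coded values. A minimal sketch of that plumbing, using simplified, hypothetical classes (not the repository's actual API):

# Hypothetical, simplified sketch -- class and attribute names here are illustrative only.
class DummyTransformerGen:
    def __init__(self, lr, epochs, batch_size, batch_size_eval):
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size            # used for the training DataLoader
        self.batch_size_eval = batch_size_eval  # used for the evaluation DataLoader

class DummyFunnelling:
    def __init__(self, lr, epochs, batch_size, eval_batch_size):
        # Both generators now share the same hyperparameters, as in the patched code.
        self.textual_vgf = DummyTransformerGen(lr, epochs, batch_size, eval_batch_size)
        self.visual_vgf = DummyTransformerGen(lr, epochs, batch_size, eval_batch_size)

gf = DummyFunnelling(lr=1e-5, epochs=100, batch_size=32, eval_batch_size=128)
assert gf.visual_vgf.batch_size_eval == 128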

@@ -140,46 +140,50 @@ class Trainer:
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
 
-    def train(self, train_dataloader, eval_dataloader, epochs=10):
-        wandb.init(
-            project="gfun",
-            name="allhere",
-            # reinit=True,
-            config={
-                "vgf": self.vgf_name,
-                "architecture": self.model.name_or_path,
-                "learning_rate": self.optimizer.defaults["lr"],
-                "epochs": epochs,
-                "train batch size": train_dataloader.batch_size,
-                "eval batch size": eval_dataloader.batch_size,
-                "max len": train_dataloader.dataset.X.shape[-1],
-                "patience": self.earlystopping.patience,
-                "evaluate every": self.evaluate_steps,
-                "print eval every": self.print_eval,
-                "print train steps": self.print_steps,
-            },
-        )
-        print(
-            f"""- Training params for {self.experiment_name}:
-            - epochs: {epochs}
-            - learning rate: {self.optimizer.defaults['lr']}
-            - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}
-            - patience: {self.earlystopping.patience}
-            - evaluate every: {self.evaluate_steps}
-            - print eval every: {self.print_eval}
-            - print train steps: {self.print_steps}\n"""
-        )
+    def get_config(self, train_dataloader, eval_dataloader, epochs):
+        return {
+            "model name": self.model.name_or_path,
+            "epochs": epochs,
+            "learning rate": self.optimizer.defaults["lr"],
+            "train batch size": train_dataloader.batch_size,
+            "eval batch size": eval_dataloader.batch_size,
+            "max len": train_dataloader.dataset.X.shape[-1],
+            "patience": self.earlystopping.patience,
+            "evaluate every": self.evaluate_steps,
+            "print eval every": self.print_eval,
+            "print train steps": self.print_steps,
+        }
+
+    def train(self, train_dataloader, eval_dataloader, epochs=10):
+        _config = self.get_config(train_dataloader, eval_dataloader, epochs)
+
+        print(f"- Training params for {self.experiment_name}:")
+        for k, v in _config.items():
+            print(f"\t{k}: {v}")
+
+        wandb_logger = wandb.init(
+            project="gfun", entity="andreapdr", config=_config, reinit=True
+        )
+
         for epoch in range(epochs):
-            self.train_epoch(train_dataloader, epoch)
+            train_loss = self.train_epoch(train_dataloader, epoch)
+            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
             if (epoch + 1) % self.evaluate_steps == 0:
                 print_eval = (epoch + 1) % self.print_eval == 0
-                metric_watcher = self.evaluate(
-                    eval_dataloader, epoch, print_eval=print_eval
-                )
+                with torch.no_grad():
+                    eval_loss, metric_watcher = self.evaluate(
+                        eval_dataloader, epoch, print_eval=print_eval
+                    )
+                wandb_logger.log(
+                    {
+                        f"{self.vgf_name}_eval_loss": eval_loss,
+                        f"{self.vgf_name}_eval_metric": metric_watcher,
+                    }
+                )
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
                     print(
@@ -189,8 +193,9 @@ class Trainer:
                         self.device
                     )
                     break
         print(f"- last swipe on eval set")
-        self.train_epoch(eval_dataloader, epoch=0)
+        self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
         return self.model
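The `with torch.no_grad():` wrapper added around evaluation, together with the separate (larger) `eval_batch_size` introduced above, exploits the fact that no gradients are needed at validation time, so activations are not retained and a bigger batch fits in memory. A small self-contained illustration of this pattern (toy model and data, not the project's Trainer):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data and model, purely for illustration.
X, y = torch.randn(512, 16), torch.randint(0, 2, (512,)).float()
model = nn.Sequential(nn.Linear(16, 1))
loss_fn = nn.BCEWithLogitsLoss()

# Smaller batches for training (gradients and optimizer state are kept in memory)...
train_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
# ...and larger batches for evaluation, where no gradients are stored.
eval_loader = DataLoader(TensorDataset(X, y), batch_size=128, shuffle=False)

model.eval()
with torch.no_grad():  # disables autograd bookkeeping during the forward pass
    eval_loss = sum(
        loss_fn(model(xb).squeeze(-1), yb).item() for xb, yb in eval_loader
    ) / len(eval_loader)
print(f"eval loss: {eval_loss:.4f}")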
@@ -208,14 +213,7 @@ class Trainer:
                 if (epoch + 1) % PRINT_ON_EPOCH == 0:
                     if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                         print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        wandb.log(
-            {
-                f"{wandb.config['vgf']}_training_loss": loss,
-                # "epoch": epoch,
-                # f"{wandb.config['vgf']}_epoch": epoch,
-            }
-        )
-        return self
+        return loss.item()
 
     def evaluate(self, dataloader, epoch, print_eval=True):
         self.model.eval()
@@ -242,15 +240,8 @@ class Trainer:
         l_eval = evaluate(lY, lY_hat)
         average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-        wandb.log(
-            {
-                f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
-                f"{wandb.config['vgf']}_eval_loss": loss,
-                # "epoch": epoch,
-                # f"{wandb.config['vgf']}_epoch": epoch,
-            }
-        )
-        return average_metrics[0]  # macro-F1
+
+        return loss.item(), average_metrics[0]  # macro-F1
 
 
 class EarlyStopping:
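The refactor replaces the module-level `wandb.log` calls (which depended on `wandb.config['vgf']`) with a run handle returned by `wandb.init`, so each VGF's Trainer logs to its own run under VGF-prefixed keys, and `train_epoch`/`evaluate` now return plain losses instead of logging themselves. A minimal sketch of that pattern, with placeholder metric values (it assumes a configured wandb account):

import wandb

vgf_name = "textual_trf"  # placeholder; the Trainer uses self.vgf_name

run = wandb.init(
    project="gfun",                                # project name as in the diff
    config={"epochs": 2, "learning rate": 1e-5},   # analogous to get_config()
    reinit=True,                                   # allow one run per VGF in the same process
)
for epoch in range(2):
    train_loss = 0.5 / (epoch + 1)  # dummy value standing in for train_epoch()
    run.log({f"{vgf_name}_train_loss": train_loss})
run.finish()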

@@ -54,6 +54,7 @@ def main(args):
         textual_transformer=args.textual_transformer,
         textual_transformer_name=args.transformer_name,
         batch_size=args.batch_size,
+        eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
         lr=args.lr,
         max_length=args.max_length,
@@ -125,6 +126,7 @@ if __name__ == "__main__":
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
     parser.add_argument("--lr", type=float, default=1e-5)
     parser.add_argument("--max_length", type=int, default=128)