logging via wandb

commit 84dd1f093e
parent 6b7917ca47
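In short, this commit replaces the ad-hoc wandb.log calls (keyed off wandb.config["vgf"]) with a single wandb run created in Trainer.train, using the hyperparameters collected by the new get_config helper, and threads an eval_batch_size option through GeneralizedFunnelling and main.py. A minimal sketch of the resulting logging pattern, not the repository's code (the loop, values, and key names are illustrative; assumes the wandb package is installed, with offline mode to avoid needing an account):

    import wandb

    # Hyperparameters that a get_config()-style helper would normally collect.
    config = {"epochs": 2, "learning rate": 1e-5, "train batch size": 32}

    # One run per training session; mode="offline" skips the login requirement.
    run = wandb.init(project="gfun", config=config, reinit=True, mode="offline")

    for epoch in range(config["epochs"]):
        train_loss = 1.0 / (epoch + 1)  # stand-in for self.train_epoch(...)
        run.log({"textual_trf_train_loss": train_loss})  # mirrors f"{vgf_name}_train_loss"

    run.finish()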
@@ -28,6 +28,7 @@ class GeneralizedFunnelling:
         embed_dir,
         n_jobs,
         batch_size,
+        eval_batch_size,
         max_length,
         lr,
         epochs,
@@ -59,7 +60,8 @@ class GeneralizedFunnelling:
         self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
         self.lr_transformer = lr
-        self.batch_size_transformer = batch_size
+        self.batch_size_trf = batch_size
+        self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
         self.early_stopping = True
         self.patience = patience
@@ -148,7 +150,8 @@ class GeneralizedFunnelling:
                 model_name=self.textual_trf_name,
                 lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=self.batch_size_transformer,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 max_length=self.max_length,
                 print_steps=50,
                 probabilistic=self.probabilistic,
@@ -163,10 +166,10 @@ class GeneralizedFunnelling:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=1e-5,  # self.lr_visual_transformer,
+                lr=self.lr_transformer,
                 epochs=self.epochs,
-                batch_size=32,  # self.batch_size_visual_transformer,
-                # batch_size_eval=128,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 probabilistic=self.probabilistic,
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
@@ -140,46 +140,50 @@ class Trainer:
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
 
-    def train(self, train_dataloader, eval_dataloader, epochs=10):
-        wandb.init(
-            project="gfun",
-            name="allhere",
-            # reinit=True,
-            config={
-                "vgf": self.vgf_name,
-                "architecture": self.model.name_or_path,
-                "learning_rate": self.optimizer.defaults["lr"],
-                "epochs": epochs,
-                "train batch size": train_dataloader.batch_size,
-                "eval batch size": eval_dataloader.batch_size,
-                "max len": train_dataloader.dataset.X.shape[-1],
-                "patience": self.earlystopping.patience,
-                "evaluate every": self.evaluate_steps,
-                "print eval every": self.print_eval,
-                "print train steps": self.print_steps,
-            },
-        )
+    def get_config(self, train_dataloader, eval_dataloader, epochs):
+        return {
+            "model name": self.model.name_or_path,
+            "epochs": epochs,
+            "learning rate": self.optimizer.defaults["lr"],
+            "train batch size": train_dataloader.batch_size,
+            "eval batch size": eval_dataloader.batch_size,
+            "max len": train_dataloader.dataset.X.shape[-1],
+            "patience": self.earlystopping.patience,
+            "evaluate every": self.evaluate_steps,
+            "print eval every": self.print_eval,
+            "print train steps": self.print_steps,
+        }
 
-        print(
-            f"""- Training params for {self.experiment_name}:
-            - epochs: {epochs}
-            - learning rate: {self.optimizer.defaults['lr']}
-            - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}
-            - patience: {self.earlystopping.patience}
-            - evaluate every: {self.evaluate_steps}
-            - print eval every: {self.print_eval}
-            - print train steps: {self.print_steps}\n"""
+    def train(self, train_dataloader, eval_dataloader, epochs=10):
+        _config = self.get_config(train_dataloader, eval_dataloader, epochs)
+
+        print(f"- Training params for {self.experiment_name}:")
+        for k, v in _config.items():
+            print(f"\t{k}: {v}")
+
+        wandb_logger = wandb.init(
+            project="gfun", entity="andreapdr", config=_config, reinit=True
         )
 
         for epoch in range(epochs):
-            self.train_epoch(train_dataloader, epoch)
+            train_loss = self.train_epoch(train_dataloader, epoch)
+
+            wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
+
             if (epoch + 1) % self.evaluate_steps == 0:
                 print_eval = (epoch + 1) % self.print_eval == 0
-                metric_watcher = self.evaluate(
-                    eval_dataloader, epoch, print_eval=print_eval
+                with torch.no_grad():
+                    eval_loss, metric_watcher = self.evaluate(
+                        eval_dataloader, epoch, print_eval=print_eval
+                    )
+
+                wandb_logger.log(
+                    {
+                        f"{self.vgf_name}_eval_loss": eval_loss,
+                        f"{self.vgf_name}_eval_metric": metric_watcher,
+                    }
                 )
+
                 stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
                 if stop:
                     print(
@@ -189,8 +193,9 @@ class Trainer:
                         self.device
                     )
                     break
 
        print(f"- last swipe on eval set")
-        self.train_epoch(eval_dataloader, epoch=0)
+        self.train_epoch(eval_dataloader, epoch=-1)
         self.earlystopping.save_model(self.model)
         return self.model
+
@@ -208,14 +213,7 @@ class Trainer:
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
                     print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        wandb.log(
-            {
-                f"{wandb.config['vgf']}_training_loss": loss,
-                # "epoch": epoch,
-                # f"{wandb.config['vgf']}_epoch": epoch,
-            }
-        )
-        return self
+        return loss.item()
 
     def evaluate(self, dataloader, epoch, print_eval=True):
         self.model.eval()
@@ -242,15 +240,8 @@ class Trainer:
 
         l_eval = evaluate(lY, lY_hat)
         average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-        wandb.log(
-            {
-                f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
-                f"{wandb.config['vgf']}_eval_loss": loss,
-                # "epoch": epoch,
-                # f"{wandb.config['vgf']}_epoch": epoch,
-            }
-        )
-        return average_metrics[0]  # macro-F1
+
+        return loss.item(), average_metrics[0]  # macro-F1
 
 
 class EarlyStopping:

main.py (+2)
@@ -54,6 +54,7 @@ def main(args):
         textual_transformer=args.textual_transformer,
         textual_transformer_name=args.transformer_name,
         batch_size=args.batch_size,
+        eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
         lr=args.lr,
         max_length=args.max_length,
@@ -125,6 +126,7 @@ if __name__ == "__main__":
     # transformer parameters ---------------
     parser.add_argument("--transformer_name", type=str, default="mbert")
     parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--eval_batch_size", type=int, default=128)
     parser.add_argument("--epochs", type=int, default=100)
     parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--max_length", type=int, default=128)