logging via wandb
This commit is contained in:
parent
f274ec7615
commit
7dead90271
|
@ -182,3 +182,4 @@ scripts/
|
|||
logger/*
|
||||
explore_data.ipynb
|
||||
run.sh
|
||||
wandb
|
|
@ -47,7 +47,7 @@ class GeneralizedFunnelling:
|
|||
self.posteriors_vgf = posterior
|
||||
self.wce_vgf = wce
|
||||
self.multilingual_vgf = multilingual
|
||||
self.trasformer_vgf = textual_transformer
|
||||
self.textual_trasformer_vgf = textual_transformer
|
||||
self.visual_transformer_vgf = visual_transformer
|
||||
self.probabilistic = probabilistic
|
||||
self.num_labels = num_labels
|
||||
|
@ -142,7 +142,7 @@ class GeneralizedFunnelling:
|
|||
wce_vgf = WceGen(n_jobs=self.n_jobs)
|
||||
self.first_tier_learners.append(wce_vgf)
|
||||
|
||||
if self.trasformer_vgf:
|
||||
if self.textual_trasformer_vgf:
|
||||
transformer_vgf = TextualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name=self.textaul_transformer_name,
|
||||
|
@ -198,7 +198,8 @@ class GeneralizedFunnelling:
|
|||
self.posteriors_vgf,
|
||||
self.multilingual_vgf,
|
||||
self.wce_vgf,
|
||||
self.trasformer_vgf,
|
||||
self.textual_trasformer_vgf,
|
||||
self.visual_transformer_vgf,
|
||||
self.aggfunc,
|
||||
)
|
||||
print(f"- model id: {self._model_id}")
|
||||
|
@ -372,7 +373,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
if self.trasformer_vgf:
|
||||
if self.textual_trasformer_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||
|
@ -427,7 +428,15 @@ def get_params(optimc=False):
|
|||
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
|
||||
|
||||
|
||||
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
|
||||
def get_unique_id(
|
||||
dataset_name,
|
||||
posterior,
|
||||
multilingual,
|
||||
wce,
|
||||
textual_transformer,
|
||||
visual_transformer,
|
||||
aggfunc,
|
||||
):
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now().strftime("%y%m%d")
|
||||
|
@ -435,6 +444,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
|
|||
model_id += "p" if posterior else ""
|
||||
model_id += "m" if multilingual else ""
|
||||
model_id += "w" if wce else ""
|
||||
model_id += "t" if transformer else ""
|
||||
model_id += "t" if textual_transformer else ""
|
||||
model_id += "v" if visual_transformer else ""
|
||||
model_id += f"_{aggfunc}"
|
||||
return f"{model_id}_{now}"
|
||||
|
|
|
@ -12,6 +12,7 @@ from torch.optim import AdamW
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
import wandb
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
|
||||
PRINT_ON_EPOCH = 1
|
||||
|
@ -114,6 +115,7 @@ class Trainer:
|
|||
patience,
|
||||
experiment_name,
|
||||
checkpoint_path,
|
||||
vgf_name,
|
||||
):
|
||||
self.device = device
|
||||
self.model = model.to(device)
|
||||
|
@ -130,6 +132,7 @@ class Trainer:
|
|||
verbose=False,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
self.vgf_name = vgf_name
|
||||
|
||||
def init_optimizer(self, optimizer_name, lr):
|
||||
if optimizer_name.lower() == "adamw":
|
||||
|
@ -138,6 +141,25 @@ class Trainer:
|
|||
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
||||
|
||||
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
||||
wandb.init(
|
||||
project="gfun",
|
||||
name="allhere",
|
||||
# reinit=True,
|
||||
config={
|
||||
"vgf": self.vgf_name,
|
||||
"architecture": self.model.name_or_path,
|
||||
"learning_rate": self.optimizer.defaults["lr"],
|
||||
"epochs": epochs,
|
||||
"train batch size": train_dataloader.batch_size,
|
||||
"eval batch size": eval_dataloader.batch_size,
|
||||
"max len": train_dataloader.dataset.X.shape[-1],
|
||||
"patience": self.earlystopping.patience,
|
||||
"evaluate every": self.evaluate_steps,
|
||||
"print eval every": self.print_eval,
|
||||
"print train steps": self.print_steps,
|
||||
},
|
||||
)
|
||||
|
||||
print(
|
||||
f"""- Training params for {self.experiment_name}:
|
||||
- epochs: {epochs}
|
||||
|
@ -150,11 +172,14 @@ class Trainer:
|
|||
- print eval every: {self.print_eval}
|
||||
- print train steps: {self.print_steps}\n"""
|
||||
)
|
||||
|
||||
for epoch in range(epochs):
|
||||
self.train_epoch(train_dataloader, epoch)
|
||||
if (epoch + 1) % self.evaluate_steps == 0:
|
||||
print_eval = (epoch + 1) % self.print_eval == 0
|
||||
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
|
||||
metric_watcher = self.evaluate(
|
||||
eval_dataloader, epoch, print_eval=print_eval
|
||||
)
|
||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
||||
if stop:
|
||||
print(
|
||||
|
@ -183,9 +208,16 @@ class Trainer:
|
|||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
wandb.log(
|
||||
{
|
||||
f"{wandb.config['vgf']}_training_loss": loss,
|
||||
# "epoch": epoch,
|
||||
# f"{wandb.config['vgf']}_epoch": epoch,
|
||||
}
|
||||
)
|
||||
return self
|
||||
|
||||
def evaluate(self, dataloader, print_eval=True):
|
||||
def evaluate(self, dataloader, epoch, print_eval=True):
|
||||
self.model.eval()
|
||||
|
||||
lY = defaultdict(list)
|
||||
|
@ -210,6 +242,14 @@ class Trainer:
|
|||
|
||||
l_eval = evaluate(lY, lY_hat)
|
||||
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
||||
wandb.log(
|
||||
{
|
||||
f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
|
||||
f"{wandb.config['vgf']}_eval_loss": loss,
|
||||
# "epoch": epoch,
|
||||
# f"{wandb.config['vgf']}_epoch": epoch,
|
||||
}
|
||||
)
|
||||
return average_metrics[0] # macro-F1
|
||||
|
||||
|
||||
|
|
|
@ -130,6 +130,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
experiment_name = (
|
||||
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=self.model,
|
||||
optimizer_name="adamW",
|
||||
|
@ -141,6 +142,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
vgf_name="textual_trf",
|
||||
)
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
|
|
|
@ -97,7 +97,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
shuffle=False,
|
||||
)
|
||||
|
||||
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
|
||||
experiment_name = (
|
||||
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=self.model,
|
||||
optimizer_name="adamW",
|
||||
|
@ -109,6 +112,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
vgf_name="visual_trf",
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
|
|
Loading…
Reference in New Issue