logging via wandb
This commit is contained in:
parent
f274ec7615
commit
7dead90271
|
@ -182,3 +182,4 @@ scripts/
|
||||||
logger/*
|
logger/*
|
||||||
explore_data.ipynb
|
explore_data.ipynb
|
||||||
run.sh
|
run.sh
|
||||||
|
wandb
|
|
@ -47,7 +47,7 @@ class GeneralizedFunnelling:
|
||||||
self.posteriors_vgf = posterior
|
self.posteriors_vgf = posterior
|
||||||
self.wce_vgf = wce
|
self.wce_vgf = wce
|
||||||
self.multilingual_vgf = multilingual
|
self.multilingual_vgf = multilingual
|
||||||
self.trasformer_vgf = textual_transformer
|
self.textual_trasformer_vgf = textual_transformer
|
||||||
self.visual_transformer_vgf = visual_transformer
|
self.visual_transformer_vgf = visual_transformer
|
||||||
self.probabilistic = probabilistic
|
self.probabilistic = probabilistic
|
||||||
self.num_labels = num_labels
|
self.num_labels = num_labels
|
||||||
|
@ -142,7 +142,7 @@ class GeneralizedFunnelling:
|
||||||
wce_vgf = WceGen(n_jobs=self.n_jobs)
|
wce_vgf = WceGen(n_jobs=self.n_jobs)
|
||||||
self.first_tier_learners.append(wce_vgf)
|
self.first_tier_learners.append(wce_vgf)
|
||||||
|
|
||||||
if self.trasformer_vgf:
|
if self.textual_trasformer_vgf:
|
||||||
transformer_vgf = TextualTransformerGen(
|
transformer_vgf = TextualTransformerGen(
|
||||||
dataset_name=self.dataset_name,
|
dataset_name=self.dataset_name,
|
||||||
model_name=self.textaul_transformer_name,
|
model_name=self.textaul_transformer_name,
|
||||||
|
@ -198,7 +198,8 @@ class GeneralizedFunnelling:
|
||||||
self.posteriors_vgf,
|
self.posteriors_vgf,
|
||||||
self.multilingual_vgf,
|
self.multilingual_vgf,
|
||||||
self.wce_vgf,
|
self.wce_vgf,
|
||||||
self.trasformer_vgf,
|
self.textual_trasformer_vgf,
|
||||||
|
self.visual_transformer_vgf,
|
||||||
self.aggfunc,
|
self.aggfunc,
|
||||||
)
|
)
|
||||||
print(f"- model id: {self._model_id}")
|
print(f"- model id: {self._model_id}")
|
||||||
|
@ -372,7 +373,7 @@ class GeneralizedFunnelling:
|
||||||
"rb",
|
"rb",
|
||||||
) as vgf:
|
) as vgf:
|
||||||
first_tier_learners.append(pickle.load(vgf))
|
first_tier_learners.append(pickle.load(vgf))
|
||||||
if self.trasformer_vgf:
|
if self.textual_trasformer_vgf:
|
||||||
with open(
|
with open(
|
||||||
os.path.join(
|
os.path.join(
|
||||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||||
|
@ -427,7 +428,15 @@ def get_params(optimc=False):
|
||||||
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
|
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
|
||||||
|
|
||||||
|
|
||||||
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
|
def get_unique_id(
|
||||||
|
dataset_name,
|
||||||
|
posterior,
|
||||||
|
multilingual,
|
||||||
|
wce,
|
||||||
|
textual_transformer,
|
||||||
|
visual_transformer,
|
||||||
|
aggfunc,
|
||||||
|
):
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
now = datetime.now().strftime("%y%m%d")
|
now = datetime.now().strftime("%y%m%d")
|
||||||
|
@ -435,6 +444,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
|
||||||
model_id += "p" if posterior else ""
|
model_id += "p" if posterior else ""
|
||||||
model_id += "m" if multilingual else ""
|
model_id += "m" if multilingual else ""
|
||||||
model_id += "w" if wce else ""
|
model_id += "w" if wce else ""
|
||||||
model_id += "t" if transformer else ""
|
model_id += "t" if textual_transformer else ""
|
||||||
|
model_id += "v" if visual_transformer else ""
|
||||||
model_id += f"_{aggfunc}"
|
model_id += f"_{aggfunc}"
|
||||||
return f"{model_id}_{now}"
|
return f"{model_id}_{now}"
|
||||||
|
|
|
@ -12,6 +12,7 @@ from torch.optim import AdamW
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
from transformers.modeling_outputs import ModelOutput
|
from transformers.modeling_outputs import ModelOutput
|
||||||
|
|
||||||
|
import wandb
|
||||||
from evaluation.evaluate import evaluate, log_eval
|
from evaluation.evaluate import evaluate, log_eval
|
||||||
|
|
||||||
PRINT_ON_EPOCH = 1
|
PRINT_ON_EPOCH = 1
|
||||||
|
@ -114,6 +115,7 @@ class Trainer:
|
||||||
patience,
|
patience,
|
||||||
experiment_name,
|
experiment_name,
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
|
vgf_name,
|
||||||
):
|
):
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = model.to(device)
|
self.model = model.to(device)
|
||||||
|
@ -130,6 +132,7 @@ class Trainer:
|
||||||
verbose=False,
|
verbose=False,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
)
|
)
|
||||||
|
self.vgf_name = vgf_name
|
||||||
|
|
||||||
def init_optimizer(self, optimizer_name, lr):
|
def init_optimizer(self, optimizer_name, lr):
|
||||||
if optimizer_name.lower() == "adamw":
|
if optimizer_name.lower() == "adamw":
|
||||||
|
@ -138,6 +141,25 @@ class Trainer:
|
||||||
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
||||||
|
|
||||||
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
||||||
|
wandb.init(
|
||||||
|
project="gfun",
|
||||||
|
name="allhere",
|
||||||
|
# reinit=True,
|
||||||
|
config={
|
||||||
|
"vgf": self.vgf_name,
|
||||||
|
"architecture": self.model.name_or_path,
|
||||||
|
"learning_rate": self.optimizer.defaults["lr"],
|
||||||
|
"epochs": epochs,
|
||||||
|
"train batch size": train_dataloader.batch_size,
|
||||||
|
"eval batch size": eval_dataloader.batch_size,
|
||||||
|
"max len": train_dataloader.dataset.X.shape[-1],
|
||||||
|
"patience": self.earlystopping.patience,
|
||||||
|
"evaluate every": self.evaluate_steps,
|
||||||
|
"print eval every": self.print_eval,
|
||||||
|
"print train steps": self.print_steps,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"""- Training params for {self.experiment_name}:
|
f"""- Training params for {self.experiment_name}:
|
||||||
- epochs: {epochs}
|
- epochs: {epochs}
|
||||||
|
@ -150,11 +172,14 @@ class Trainer:
|
||||||
- print eval every: {self.print_eval}
|
- print eval every: {self.print_eval}
|
||||||
- print train steps: {self.print_steps}\n"""
|
- print train steps: {self.print_steps}\n"""
|
||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
self.train_epoch(train_dataloader, epoch)
|
self.train_epoch(train_dataloader, epoch)
|
||||||
if (epoch + 1) % self.evaluate_steps == 0:
|
if (epoch + 1) % self.evaluate_steps == 0:
|
||||||
print_eval = (epoch + 1) % self.print_eval == 0
|
print_eval = (epoch + 1) % self.print_eval == 0
|
||||||
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
|
metric_watcher = self.evaluate(
|
||||||
|
eval_dataloader, epoch, print_eval=print_eval
|
||||||
|
)
|
||||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
||||||
if stop:
|
if stop:
|
||||||
print(
|
print(
|
||||||
|
@ -183,9 +208,16 @@ class Trainer:
|
||||||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||||
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
||||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
f"{wandb.config['vgf']}_training_loss": loss,
|
||||||
|
# "epoch": epoch,
|
||||||
|
# f"{wandb.config['vgf']}_epoch": epoch,
|
||||||
|
}
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def evaluate(self, dataloader, print_eval=True):
|
def evaluate(self, dataloader, epoch, print_eval=True):
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
lY = defaultdict(list)
|
lY = defaultdict(list)
|
||||||
|
@ -210,6 +242,14 @@ class Trainer:
|
||||||
|
|
||||||
l_eval = evaluate(lY, lY_hat)
|
l_eval = evaluate(lY, lY_hat)
|
||||||
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
||||||
|
wandb.log(
|
||||||
|
{
|
||||||
|
f"{wandb.config['vgf']}_eval_metric": average_metrics[0],
|
||||||
|
f"{wandb.config['vgf']}_eval_loss": loss,
|
||||||
|
# "epoch": epoch,
|
||||||
|
# f"{wandb.config['vgf']}_epoch": epoch,
|
||||||
|
}
|
||||||
|
)
|
||||||
return average_metrics[0] # macro-F1
|
return average_metrics[0] # macro-F1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -130,6 +130,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
experiment_name = (
|
experiment_name = (
|
||||||
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
optimizer_name="adamW",
|
optimizer_name="adamW",
|
||||||
|
@ -141,6 +142,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
checkpoint_path="models/vgfs/transformer",
|
checkpoint_path="models/vgfs/transformer",
|
||||||
|
vgf_name="textual_trf",
|
||||||
)
|
)
|
||||||
trainer.train(
|
trainer.train(
|
||||||
train_dataloader=tra_dataloader,
|
train_dataloader=tra_dataloader,
|
||||||
|
|
|
@ -97,7 +97,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
|
experiment_name = (
|
||||||
|
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||||
|
)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
optimizer_name="adamW",
|
optimizer_name="adamW",
|
||||||
|
@ -109,6 +112,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
||||||
patience=self.patience,
|
patience=self.patience,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
checkpoint_path="models/vgfs/transformer",
|
checkpoint_path="models/vgfs/transformer",
|
||||||
|
vgf_name="visual_trf",
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.train(
|
trainer.train(
|
||||||
|
|
Loading…
Reference in New Issue