improved wandb logging

This commit is contained in:
andreapdr 2023-03-09 17:03:17 +01:00
parent 3240150542
commit 7e1ec46ebd
6 changed files with 215 additions and 90 deletions

View File

@ -1,46 +1,96 @@
from joblib import Parallel, delayed
from collections import defaultdict
from evaluation.metrics import *
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
def evaluation_metrics(y, y_):
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
raise NotImplementedError()
else:
def evaluation_metrics(y, y_, clf_type):
if clf_type == "singlelabel":
return (
accuracy_score(y, y_),
# TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"),
)
elif clf_type == "multilabel":
return (
macroF1(y, y_),
microF1(y, y_),
macroK(y, y_),
microK(y, y_),
# macroAcc(y, y_),
)
else:
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
def evaluate(
ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
):
if n_jobs == 1:
return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
return {
lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
for lang in ly_true.keys()
}
else:
langs = list(ly_true.keys())
evals = Parallel(n_jobs=n_jobs)(
delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
)
return {lang: evals[i] for i, lang in enumerate(langs)}
def log_eval(l_eval, phase="training", verbose=True):
def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if verbose:
print(f"\n[Results {phase}]")
metrics = []
for lang in l_eval.keys():
macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk])
if phase != "validation":
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
"Averages: MF1, mF1, MK, mK",
np.round(averages, 3),
"\n",
)
return averages
if clf_type == "multilabel":
for lang in l_eval.keys():
macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk])
if phase != "validation":
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
"Averages: MF1, mF1, MK, mK",
np.round(averages, 3),
"\n",
)
return averages # TODO: return a dict avg and lang specific
elif clf_type == "singlelabel":
lang_metrics = defaultdict(dict)
_metrics = [
"accuracy",
# "acc5", # "accuracy-at-5",
# "acc10", # "accuracy-at-10",
"MF1", # "macro-F1",
"mF1", # "micro-F1",
]
for lang in l_eval.keys():
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1 = l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1])
metrics.append([acc, macrof1, microf1])
for m, v in zip(_metrics, l_eval[lang]):
lang_metrics[m][lang] = v
if phase != "validation":
print(
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
)
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
# "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
"Averages: Acc, MF1, mF1",
np.round(averages, 3),
"\n",
)
avg_metrics = dict(zip(_metrics, averages))
return avg_metrics, lang_metrics

View File

@ -1,17 +1,14 @@
import os
import sys
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen
@ -25,12 +22,14 @@ class GeneralizedFunnelling:
visual_transformer,
langs,
num_labels,
classification_type,
embed_dir,
n_jobs,
batch_size,
eval_batch_size,
max_length,
lr,
textual_lr,
visual_lr,
epochs,
patience,
evaluate_step,
@ -52,6 +51,7 @@ class GeneralizedFunnelling:
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
self.clf_type = classification_type
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
@ -59,7 +59,8 @@ class GeneralizedFunnelling:
# Textual Transformer VGF params ----------
self.textual_trf_name = textual_transformer_name
self.epochs = epochs
self.lr_transformer = lr
self.txt_trf_lr = textual_lr
self.vis_trf_lr = visual_lr
self.batch_size_trf = batch_size
self.eval_batch_size_trf = eval_batch_size
self.max_length = max_length
@ -114,7 +115,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.lr_transformer,
lr=self.txt_trf_lr,
patience=self.patience,
num_heads=1,
device=self.device,
@ -148,7 +149,7 @@ class GeneralizedFunnelling:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.textual_trf_name,
lr=self.lr_transformer,
lr=self.txt_trf_lr,
epochs=self.epochs,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
@ -159,6 +160,7 @@ class GeneralizedFunnelling:
verbose=True,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(transformer_vgf)
@ -166,7 +168,7 @@ class GeneralizedFunnelling:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=self.lr_transformer,
lr=self.vis_trf_lr,
epochs=self.epochs,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
@ -174,6 +176,7 @@ class GeneralizedFunnelling:
evaluate_step=self.evaluate_step,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(visual_trasformer_vgf)
@ -182,7 +185,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.lr_transformer,
lr=self.txt_trf_lr,
patience=self.patience,
num_heads=1,
device=self.device,
@ -255,10 +258,9 @@ class GeneralizedFunnelling:
projections.append(l_posteriors)
agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg)
# converting to binary predictions
# if self.dataset_name in ["cls"]: # TODO: better way to do this
# for lang, preds in l_out.items():
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
if self.clf_type == "singlelabel":
for lang, preds in l_out.items():
l_out[lang] = predict(preds, clf_type=self.clf_type)
return l_out
def fit_transform(self, lX, lY):

View File

@ -9,6 +9,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput
@ -22,6 +23,21 @@ def _normalize(lX, l2=True):
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
def verbosity_eval(epoch, print_eval):
if (epoch + 1) % print_eval == 0 and epoch != 0:
return True
else:
return False
def format_langkey_wandb(lang_dict):
log_dict = {}
for metric, l_dict in lang_dict.items():
for lang, value in l_dict.items():
log_dict[f"language metric/{metric}/{lang}"] = value
return log_dict
def XdotM(X, M, sif):
E = X.dot(M)
if sif:
@ -58,18 +74,23 @@ def compute_pc(X, npc=1):
return svd.components_
def predict(logits, classification_type="multilabel"):
def predict(logits, clf_type="multilabel"):
"""
Converts soft precictions to hard predictions [0,1]
"""
if classification_type == "multilabel":
if clf_type == "multilabel":
prediction = torch.sigmoid(logits) > 0.5
elif classification_type == "singlelabel":
prediction = torch.argmax(logits, dim=1).view(-1, 1)
return prediction.detach().cpu().numpy()
elif clf_type == "singlelabel":
if type(logits) != torch.Tensor:
logits = torch.tensor(logits)
prediction = torch.softmax(logits, dim=1)
prediction = prediction.detach().cpu().numpy()
_argmaxs = prediction.argmax(axis=1)
prediction = np.eye(prediction.shape[1])[_argmaxs]
return prediction
else:
print("unknown classification type")
return prediction.detach().cpu().numpy()
raise NotImplementedError()
class TfidfVectorizerMultilingual:
@ -115,36 +136,54 @@ class Trainer:
patience,
experiment_name,
checkpoint_path,
classification_type,
vgf_name,
n_jobs,
scheduler_name=None,
):
self.device = device
self.model = model.to(device)
self.optimizer = self.init_optimizer(optimizer_name, lr)
self.optimizer, self.scheduler = self.init_optimizer(
optimizer_name, lr, scheduler_name
)
self.evaluate_steps = evaluate_step
self.loss_fn = loss_fn.to(device)
self.print_steps = print_steps
self.experiment_name = experiment_name
self.patience = patience
self.print_eval = evaluate_step
self.print_eval = 10
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path=checkpoint_path,
verbose=False,
experiment_name=experiment_name,
)
self.clf_type = classification_type
self.vgf_name = vgf_name
self.n_jobs = n_jobs
self.monitored_metric = (
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
) # TODO: make this configurable
def init_optimizer(self, optimizer_name, lr):
def init_optimizer(self, optimizer_name, lr, scheduler_name):
if optimizer_name.lower() == "adamw":
return AdamW(self.model.parameters(), lr=lr)
optim = AdamW(self.model.parameters(), lr=lr)
else:
raise ValueError(f"Optimizer {optimizer_name} not supported")
if scheduler_name is None:
scheduler = None
elif scheduler_name == "ReduceLROnPlateau":
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
else:
raise ValueError(f"Scheduler {scheduler_name} not supported")
return optim, scheduler
def get_config(self, train_dataloader, eval_dataloader, epochs):
return {
"model name": self.model.name_or_path,
"epochs": epochs,
"learning rate": self.optimizer.defaults["lr"],
"scheduler": "TODO", # TODO: add scheduler name
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],
@ -152,6 +191,7 @@ class Trainer:
"evaluate every": self.evaluate_steps,
"print eval every": self.print_eval,
"print train steps": self.print_steps,
"classification type": self.clf_type,
}
def train(self, train_dataloader, eval_dataloader, epochs=10):
@ -168,23 +208,23 @@ class Trainer:
for epoch in range(epochs):
train_loss = self.train_epoch(train_dataloader, epoch)
wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % self.print_eval == 0
if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
print_eval = verbosity_eval(epoch, self.print_eval)
with torch.no_grad():
eval_loss, metric_watcher = self.evaluate(
eval_dataloader, epoch, print_eval=print_eval
eval_loss, avg_metrics, lang_metrics = self.evaluate(
eval_dataloader,
print_eval=print_eval,
n_jobs=self.n_jobs,
)
wandb_logger.log(
{
f"{self.vgf_name}_eval_loss": eval_loss,
f"{self.vgf_name}_eval_metric": metric_watcher,
}
{"loss/val": eval_loss, **format_langkey_wandb(lang_metrics)},
commit=False,
)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
stop = self.earlystopping(
avg_metrics[self.monitored_metric], self.model, epoch + 1
)
if stop:
print(
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
@ -194,6 +234,16 @@ class Trainer:
)
break
if self.scheduler is not None:
self.scheduler.step(avg_metrics[self.monitored_metric])
wandb_logger.log(
{
"loss/train": train_loss,
"learning rate": self.optimizer.param_groups[0]["lr"],
}
)
print(f"- last swipe on eval set")
self.train_epoch(eval_dataloader, epoch=-1)
self.earlystopping.save_model(self.model)
@ -201,6 +251,7 @@ class Trainer:
def train_epoch(self, dataloader, epoch):
self.model.train()
epoch_losses = []
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
@ -210,38 +261,47 @@ class Trainer:
loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward()
self.optimizer.step()
epoch_losses.append(loss.item())
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return loss.item()
print(
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
)
return np.mean(epoch_losses)
def evaluate(self, dataloader, epoch, print_eval=True):
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
self.model.eval()
eval_losses = []
lY = defaultdict(list)
lY_hat = defaultdict(list)
lY_true = defaultdict(list)
lY_pred = defaultdict(list)
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, ModelOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
y_pred = self.model(x.to(self.device))
if isinstance(y_pred, ModelOutput):
loss = self.loss_fn(y_pred.logits, y.to(self.device))
predictions = predict(y_pred.logits, clf_type=self.clf_type)
else:
loss = self.loss_fn(y_hat, y.to(self.device))
predictions = predict(y_hat, classification_type="multilabel")
loss = self.loss_fn(y_pred, y.to(self.device))
predictions = predict(y_pred, clf_type=self.clf_type)
eval_losses.append(loss.item())
for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy())
lY_hat[l].append(_pred)
lY_true[l].append(_true.detach().cpu().numpy())
lY_pred[l].append(_pred)
for lang in lY:
lY[lang] = np.vstack(lY[lang])
lY_hat[lang] = np.vstack(lY_hat[lang])
for lang in lY_true:
lY_true[lang] = np.vstack(lY_true[lang])
lY_pred[lang] = np.vstack(lY_pred[lang])
l_eval = evaluate(lY, lY_hat)
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
return loss.item(), average_metrics[0] # macro-F1
avg_metrics, lang_metrics = log_eval(
l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
)
return np.mean(eval_losses), avg_metrics, lang_metrics
class EarlyStopping:
@ -279,7 +339,7 @@ class EarlyStopping:
print(
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
)
if self.counter >= self.patience:
if self.counter >= self.patience and self.patience != -1:
print(f"- earlystopping: Early stopping at epoch {epoch}")
return True

View File

@ -7,11 +7,9 @@ from collections import defaultdict
import numpy as np
import torch
import transformers
# from sklearn.model_selection import train_test_split
# from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
@ -19,9 +17,6 @@ from gfun.vgfs.viewGen import ViewGen
transformers.logging.set_verbosity_error()
# TODO: add support to loggers
class TextualTransformerGen(ViewGen, TransformerGen):
def __init__(
self,
@ -39,6 +34,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=10,
verbose=False,
patience=5,
classification_type="multilabel",
):
super().__init__(
self._validate_model_name(model_name),
@ -56,6 +52,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
verbose,
patience,
)
self.clf_type = classification_type
self.fitted = False
print(
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -143,6 +140,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
vgf_name="textual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
scheduler_name="ReduceLROnPlateau",
)
trainer.train(
train_dataloader=tra_dataloader,

View File

@ -27,6 +27,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
device="cpu",
probabilistic=False,
patience=5,
classification_type="multilabel",
):
super().__init__(
model_name,
@ -40,6 +41,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=patience,
probabilistic=probabilistic,
)
self.clf_type = classification_type
self.fitted = False
print(
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -113,6 +115,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
vgf_name="visual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
)
trainer.train(

17
main.py
View File

@ -1,3 +1,7 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser
from time import time
@ -7,6 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] LR scheduler
- [!] CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] overall gfun results logger
@ -42,6 +47,7 @@ def main(args):
dataset_name=args.dataset,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type=args.clf_type,
# Posterior VGF params ----------------
posterior=args.posteriors,
# Multilingual VGF params -------------
@ -55,7 +61,8 @@ def main(args):
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
lr=args.lr,
textual_lr=args.textual_lr,
visual_lr=args.visual_lr,
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
@ -93,8 +100,8 @@ def main(args):
print(f"- training completed in {timetr - tinit:.2f} seconds")
gfun_preds = gfun.transform(lX_te)
test_eval = evaluate(lY_te, gfun_preds)
log_eval(test_eval, phase="test")
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
log_eval(test_eval, phase="test", clf_type=args.clf_type)
timeval = time()
print(f"- testing completed in {timeval - timetr:.2f} seconds")
@ -112,6 +119,7 @@ if __name__ == "__main__":
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
parser.add_argument("--clf_type", type=str, default="multilabel")
# gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true")
@ -127,7 +135,8 @@ if __name__ == "__main__":
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--textual_lr", type=float, default=1e-5)
parser.add_argument("--visual_lr", type=float, default=1e-5)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)