improved wandb logging
This commit is contained in:
parent
3240150542
commit
7e1ec46ebd
|
@ -1,46 +1,96 @@
|
|||
from joblib import Parallel, delayed
|
||||
from collections import defaultdict
|
||||
|
||||
from evaluation.metrics import *
|
||||
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
|
||||
|
||||
|
||||
def evaluation_metrics(y, y_):
|
||||
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
def evaluation_metrics(y, y_, clf_type):
|
||||
if clf_type == "singlelabel":
|
||||
return (
|
||||
accuracy_score(y, y_),
|
||||
# TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
|
||||
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
|
||||
f1_score(y, y_, average="macro", zero_division=1),
|
||||
f1_score(y, y_, average="micro"),
|
||||
)
|
||||
elif clf_type == "multilabel":
|
||||
return (
|
||||
macroF1(y, y_),
|
||||
microF1(y, y_),
|
||||
macroK(y, y_),
|
||||
microK(y, y_),
|
||||
# macroAcc(y, y_),
|
||||
)
|
||||
else:
|
||||
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
|
||||
|
||||
|
||||
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
||||
def evaluate(
|
||||
ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
|
||||
):
|
||||
if n_jobs == 1:
|
||||
return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
|
||||
return {
|
||||
lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
|
||||
for lang in ly_true.keys()
|
||||
}
|
||||
else:
|
||||
langs = list(ly_true.keys())
|
||||
evals = Parallel(n_jobs=n_jobs)(
|
||||
delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
|
||||
delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
|
||||
)
|
||||
return {lang: evals[i] for i, lang in enumerate(langs)}
|
||||
|
||||
|
||||
def log_eval(l_eval, phase="training", verbose=True):
|
||||
def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
|
||||
if verbose:
|
||||
print(f"\n[Results {phase}]")
|
||||
metrics = []
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, macrok, microk])
|
||||
if phase != "validation":
|
||||
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
return averages
|
||||
|
||||
if clf_type == "multilabel":
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, macrok, microk])
|
||||
if phase != "validation":
|
||||
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
return averages # TODO: return a dict avg and lang specific
|
||||
|
||||
elif clf_type == "singlelabel":
|
||||
lang_metrics = defaultdict(dict)
|
||||
_metrics = [
|
||||
"accuracy",
|
||||
# "acc5", # "accuracy-at-5",
|
||||
# "acc10", # "accuracy-at-10",
|
||||
"MF1", # "macro-F1",
|
||||
"mF1", # "micro-F1",
|
||||
]
|
||||
for lang in l_eval.keys():
|
||||
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
|
||||
acc, macrof1, microf1 = l_eval[lang]
|
||||
# metrics.append([acc, top5, top10, macrof1, microf1])
|
||||
metrics.append([acc, macrof1, microf1])
|
||||
|
||||
for m, v in zip(_metrics, l_eval[lang]):
|
||||
lang_metrics[m][lang] = v
|
||||
|
||||
if phase != "validation":
|
||||
print(
|
||||
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
)
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
# "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
|
||||
"Averages: Acc, MF1, mF1",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
avg_metrics = dict(zip(_metrics, averages))
|
||||
return avg_metrics, lang_metrics
|
||||
|
|
|
@ -1,17 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
|
||||
|
||||
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
|
||||
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
|
||||
from gfun.vgfs.multilingualGen import MultilingualGen
|
||||
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
|
||||
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
|
||||
from gfun.vgfs.vanillaFun import VanillaFunGen
|
||||
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
|
||||
from gfun.vgfs.wceGen import WceGen
|
||||
|
||||
|
||||
|
@ -25,12 +22,14 @@ class GeneralizedFunnelling:
|
|||
visual_transformer,
|
||||
langs,
|
||||
num_labels,
|
||||
classification_type,
|
||||
embed_dir,
|
||||
n_jobs,
|
||||
batch_size,
|
||||
eval_batch_size,
|
||||
max_length,
|
||||
lr,
|
||||
textual_lr,
|
||||
visual_lr,
|
||||
epochs,
|
||||
patience,
|
||||
evaluate_step,
|
||||
|
@ -52,6 +51,7 @@ class GeneralizedFunnelling:
|
|||
self.visual_trf_vgf = visual_transformer
|
||||
self.probabilistic = probabilistic
|
||||
self.num_labels = num_labels
|
||||
self.clf_type = classification_type
|
||||
# ------------------------
|
||||
self.langs = langs
|
||||
self.embed_dir = embed_dir
|
||||
|
@ -59,7 +59,8 @@ class GeneralizedFunnelling:
|
|||
# Textual Transformer VGF params ----------
|
||||
self.textual_trf_name = textual_transformer_name
|
||||
self.epochs = epochs
|
||||
self.lr_transformer = lr
|
||||
self.txt_trf_lr = textual_lr
|
||||
self.vis_trf_lr = visual_lr
|
||||
self.batch_size_trf = batch_size
|
||||
self.eval_batch_size_trf = eval_batch_size
|
||||
self.max_length = max_length
|
||||
|
@ -114,7 +115,7 @@ class GeneralizedFunnelling:
|
|||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||
out_dim=self.num_labels,
|
||||
lr=self.lr_transformer,
|
||||
lr=self.txt_trf_lr,
|
||||
patience=self.patience,
|
||||
num_heads=1,
|
||||
device=self.device,
|
||||
|
@ -148,7 +149,7 @@ class GeneralizedFunnelling:
|
|||
transformer_vgf = TextualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name=self.textual_trf_name,
|
||||
lr=self.lr_transformer,
|
||||
lr=self.txt_trf_lr,
|
||||
epochs=self.epochs,
|
||||
batch_size=self.batch_size_trf,
|
||||
batch_size_eval=self.eval_batch_size_trf,
|
||||
|
@ -159,6 +160,7 @@ class GeneralizedFunnelling:
|
|||
verbose=True,
|
||||
patience=self.patience,
|
||||
device=self.device,
|
||||
classification_type=self.clf_type,
|
||||
)
|
||||
self.first_tier_learners.append(transformer_vgf)
|
||||
|
||||
|
@ -166,7 +168,7 @@ class GeneralizedFunnelling:
|
|||
visual_trasformer_vgf = VisualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name="vit",
|
||||
lr=self.lr_transformer,
|
||||
lr=self.vis_trf_lr,
|
||||
epochs=self.epochs,
|
||||
batch_size=self.batch_size_trf,
|
||||
batch_size_eval=self.eval_batch_size_trf,
|
||||
|
@ -174,6 +176,7 @@ class GeneralizedFunnelling:
|
|||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
device=self.device,
|
||||
classification_type=self.clf_type,
|
||||
)
|
||||
self.first_tier_learners.append(visual_trasformer_vgf)
|
||||
|
||||
|
@ -182,7 +185,7 @@ class GeneralizedFunnelling:
|
|||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||
out_dim=self.num_labels,
|
||||
lr=self.lr_transformer,
|
||||
lr=self.txt_trf_lr,
|
||||
patience=self.patience,
|
||||
num_heads=1,
|
||||
device=self.device,
|
||||
|
@ -255,10 +258,9 @@ class GeneralizedFunnelling:
|
|||
projections.append(l_posteriors)
|
||||
agg = self.aggregate(projections)
|
||||
l_out = self.metaclassifier.predict_proba(agg)
|
||||
# converting to binary predictions
|
||||
# if self.dataset_name in ["cls"]: # TODO: better way to do this
|
||||
# for lang, preds in l_out.items():
|
||||
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
|
||||
if self.clf_type == "singlelabel":
|
||||
for lang, preds in l_out.items():
|
||||
l_out[lang] = predict(preds, clf_type=self.clf_type)
|
||||
return l_out
|
||||
|
||||
def fit_transform(self, lX, lY):
|
||||
|
|
|
@ -9,6 +9,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
|
|||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import normalize
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
|
@ -22,6 +23,21 @@ def _normalize(lX, l2=True):
|
|||
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
|
||||
|
||||
|
||||
def verbosity_eval(epoch, print_eval):
|
||||
if (epoch + 1) % print_eval == 0 and epoch != 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def format_langkey_wandb(lang_dict):
|
||||
log_dict = {}
|
||||
for metric, l_dict in lang_dict.items():
|
||||
for lang, value in l_dict.items():
|
||||
log_dict[f"language metric/{metric}/{lang}"] = value
|
||||
return log_dict
|
||||
|
||||
|
||||
def XdotM(X, M, sif):
|
||||
E = X.dot(M)
|
||||
if sif:
|
||||
|
@ -58,18 +74,23 @@ def compute_pc(X, npc=1):
|
|||
return svd.components_
|
||||
|
||||
|
||||
def predict(logits, classification_type="multilabel"):
|
||||
def predict(logits, clf_type="multilabel"):
|
||||
"""
|
||||
Converts soft precictions to hard predictions [0,1]
|
||||
"""
|
||||
if classification_type == "multilabel":
|
||||
if clf_type == "multilabel":
|
||||
prediction = torch.sigmoid(logits) > 0.5
|
||||
elif classification_type == "singlelabel":
|
||||
prediction = torch.argmax(logits, dim=1).view(-1, 1)
|
||||
return prediction.detach().cpu().numpy()
|
||||
elif clf_type == "singlelabel":
|
||||
if type(logits) != torch.Tensor:
|
||||
logits = torch.tensor(logits)
|
||||
prediction = torch.softmax(logits, dim=1)
|
||||
prediction = prediction.detach().cpu().numpy()
|
||||
_argmaxs = prediction.argmax(axis=1)
|
||||
prediction = np.eye(prediction.shape[1])[_argmaxs]
|
||||
return prediction
|
||||
else:
|
||||
print("unknown classification type")
|
||||
|
||||
return prediction.detach().cpu().numpy()
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TfidfVectorizerMultilingual:
|
||||
|
@ -115,36 +136,54 @@ class Trainer:
|
|||
patience,
|
||||
experiment_name,
|
||||
checkpoint_path,
|
||||
classification_type,
|
||||
vgf_name,
|
||||
n_jobs,
|
||||
scheduler_name=None,
|
||||
):
|
||||
self.device = device
|
||||
self.model = model.to(device)
|
||||
self.optimizer = self.init_optimizer(optimizer_name, lr)
|
||||
self.optimizer, self.scheduler = self.init_optimizer(
|
||||
optimizer_name, lr, scheduler_name
|
||||
)
|
||||
self.evaluate_steps = evaluate_step
|
||||
self.loss_fn = loss_fn.to(device)
|
||||
self.print_steps = print_steps
|
||||
self.experiment_name = experiment_name
|
||||
self.patience = patience
|
||||
self.print_eval = evaluate_step
|
||||
self.print_eval = 10
|
||||
self.earlystopping = EarlyStopping(
|
||||
patience=patience,
|
||||
checkpoint_path=checkpoint_path,
|
||||
verbose=False,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.vgf_name = vgf_name
|
||||
self.n_jobs = n_jobs
|
||||
self.monitored_metric = (
|
||||
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
|
||||
) # TODO: make this configurable
|
||||
|
||||
def init_optimizer(self, optimizer_name, lr):
|
||||
def init_optimizer(self, optimizer_name, lr, scheduler_name):
|
||||
if optimizer_name.lower() == "adamw":
|
||||
return AdamW(self.model.parameters(), lr=lr)
|
||||
optim = AdamW(self.model.parameters(), lr=lr)
|
||||
else:
|
||||
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
||||
if scheduler_name is None:
|
||||
scheduler = None
|
||||
elif scheduler_name == "ReduceLROnPlateau":
|
||||
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
|
||||
else:
|
||||
raise ValueError(f"Scheduler {scheduler_name} not supported")
|
||||
return optim, scheduler
|
||||
|
||||
def get_config(self, train_dataloader, eval_dataloader, epochs):
|
||||
return {
|
||||
"model name": self.model.name_or_path,
|
||||
"epochs": epochs,
|
||||
"learning rate": self.optimizer.defaults["lr"],
|
||||
"scheduler": "TODO", # TODO: add scheduler name
|
||||
"train batch size": train_dataloader.batch_size,
|
||||
"eval batch size": eval_dataloader.batch_size,
|
||||
"max len": train_dataloader.dataset.X.shape[-1],
|
||||
|
@ -152,6 +191,7 @@ class Trainer:
|
|||
"evaluate every": self.evaluate_steps,
|
||||
"print eval every": self.print_eval,
|
||||
"print train steps": self.print_steps,
|
||||
"classification type": self.clf_type,
|
||||
}
|
||||
|
||||
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
||||
|
@ -168,23 +208,23 @@ class Trainer:
|
|||
for epoch in range(epochs):
|
||||
train_loss = self.train_epoch(train_dataloader, epoch)
|
||||
|
||||
wandb_logger.log({f"{self.vgf_name}_train_loss": train_loss})
|
||||
|
||||
if (epoch + 1) % self.evaluate_steps == 0:
|
||||
print_eval = (epoch + 1) % self.print_eval == 0
|
||||
if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
|
||||
print_eval = verbosity_eval(epoch, self.print_eval)
|
||||
with torch.no_grad():
|
||||
eval_loss, metric_watcher = self.evaluate(
|
||||
eval_dataloader, epoch, print_eval=print_eval
|
||||
eval_loss, avg_metrics, lang_metrics = self.evaluate(
|
||||
eval_dataloader,
|
||||
print_eval=print_eval,
|
||||
n_jobs=self.n_jobs,
|
||||
)
|
||||
|
||||
wandb_logger.log(
|
||||
{
|
||||
f"{self.vgf_name}_eval_loss": eval_loss,
|
||||
f"{self.vgf_name}_eval_metric": metric_watcher,
|
||||
}
|
||||
{"loss/val": eval_loss, **format_langkey_wandb(lang_metrics)},
|
||||
commit=False,
|
||||
)
|
||||
|
||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
||||
stop = self.earlystopping(
|
||||
avg_metrics[self.monitored_metric], self.model, epoch + 1
|
||||
)
|
||||
if stop:
|
||||
print(
|
||||
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
|
||||
|
@ -194,6 +234,16 @@ class Trainer:
|
|||
)
|
||||
break
|
||||
|
||||
if self.scheduler is not None:
|
||||
self.scheduler.step(avg_metrics[self.monitored_metric])
|
||||
|
||||
wandb_logger.log(
|
||||
{
|
||||
"loss/train": train_loss,
|
||||
"learning rate": self.optimizer.param_groups[0]["lr"],
|
||||
}
|
||||
)
|
||||
|
||||
print(f"- last swipe on eval set")
|
||||
self.train_epoch(eval_dataloader, epoch=-1)
|
||||
self.earlystopping.save_model(self.model)
|
||||
|
@ -201,6 +251,7 @@ class Trainer:
|
|||
|
||||
def train_epoch(self, dataloader, epoch):
|
||||
self.model.train()
|
||||
epoch_losses = []
|
||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
self.optimizer.zero_grad()
|
||||
y_hat = self.model(x.to(self.device))
|
||||
|
@ -210,38 +261,47 @@ class Trainer:
|
|||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
epoch_losses.append(loss.item())
|
||||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
return loss.item()
|
||||
print(
|
||||
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(epoch_losses):.4f}"
|
||||
)
|
||||
return np.mean(epoch_losses)
|
||||
|
||||
def evaluate(self, dataloader, epoch, print_eval=True):
|
||||
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
|
||||
self.model.eval()
|
||||
eval_losses = []
|
||||
|
||||
lY = defaultdict(list)
|
||||
lY_hat = defaultdict(list)
|
||||
lY_true = defaultdict(list)
|
||||
lY_pred = defaultdict(list)
|
||||
|
||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
y_hat = self.model(x.to(self.device))
|
||||
if isinstance(y_hat, ModelOutput):
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
predictions = predict(y_hat.logits, classification_type="multilabel")
|
||||
y_pred = self.model(x.to(self.device))
|
||||
if isinstance(y_pred, ModelOutput):
|
||||
loss = self.loss_fn(y_pred.logits, y.to(self.device))
|
||||
predictions = predict(y_pred.logits, clf_type=self.clf_type)
|
||||
else:
|
||||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
predictions = predict(y_hat, classification_type="multilabel")
|
||||
loss = self.loss_fn(y_pred, y.to(self.device))
|
||||
predictions = predict(y_pred, clf_type=self.clf_type)
|
||||
|
||||
eval_losses.append(loss.item())
|
||||
|
||||
for l, _true, _pred in zip(lang, y, predictions):
|
||||
lY[l].append(_true.detach().cpu().numpy())
|
||||
lY_hat[l].append(_pred)
|
||||
lY_true[l].append(_true.detach().cpu().numpy())
|
||||
lY_pred[l].append(_pred)
|
||||
|
||||
for lang in lY:
|
||||
lY[lang] = np.vstack(lY[lang])
|
||||
lY_hat[lang] = np.vstack(lY_hat[lang])
|
||||
for lang in lY_true:
|
||||
lY_true[lang] = np.vstack(lY_true[lang])
|
||||
lY_pred[lang] = np.vstack(lY_pred[lang])
|
||||
|
||||
l_eval = evaluate(lY, lY_hat)
|
||||
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
||||
l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
|
||||
|
||||
return loss.item(), average_metrics[0] # macro-F1
|
||||
avg_metrics, lang_metrics = log_eval(
|
||||
l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
|
||||
)
|
||||
|
||||
return np.mean(eval_losses), avg_metrics, lang_metrics
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
|
@ -279,7 +339,7 @@ class EarlyStopping:
|
|||
print(
|
||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||
)
|
||||
if self.counter >= self.patience:
|
||||
if self.counter >= self.patience and self.patience != -1:
|
||||
print(f"- earlystopping: Early stopping at epoch {epoch}")
|
||||
return True
|
||||
|
||||
|
|
|
@ -7,11 +7,9 @@ from collections import defaultdict
|
|||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
# from sklearn.model_selection import train_test_split
|
||||
# from torch.optim import AdamW
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from gfun.vgfs.commons import Trainer
|
||||
from gfun.vgfs.transformerGen import TransformerGen
|
||||
from gfun.vgfs.viewGen import ViewGen
|
||||
|
@ -19,9 +17,6 @@ from gfun.vgfs.viewGen import ViewGen
|
|||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# TODO: add support to loggers
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -39,6 +34,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=10,
|
||||
verbose=False,
|
||||
patience=5,
|
||||
classification_type="multilabel",
|
||||
):
|
||||
super().__init__(
|
||||
self._validate_model_name(model_name),
|
||||
|
@ -56,6 +52,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
verbose,
|
||||
patience,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.fitted = False
|
||||
print(
|
||||
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
|
@ -143,6 +140,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
vgf_name="textual_trf",
|
||||
classification_type=self.clf_type,
|
||||
n_jobs=self.n_jobs,
|
||||
scheduler_name="ReduceLROnPlateau",
|
||||
)
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
|
|
|
@ -27,6 +27,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
device="cpu",
|
||||
probabilistic=False,
|
||||
patience=5,
|
||||
classification_type="multilabel",
|
||||
):
|
||||
super().__init__(
|
||||
model_name,
|
||||
|
@ -40,6 +41,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=patience,
|
||||
probabilistic=probabilistic,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.fitted = False
|
||||
print(
|
||||
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
|
@ -113,6 +115,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
vgf_name="visual_trf",
|
||||
classification_type=self.clf_type,
|
||||
n_jobs=self.n_jobs,
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
|
|
17
main.py
17
main.py
|
@ -1,3 +1,7 @@
|
|||
import os
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from time import time
|
||||
|
||||
|
@ -7,6 +11,7 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
|
|||
|
||||
"""
|
||||
TODO:
|
||||
- [!] LR scheduler
|
||||
- [!] CLS dataset is loading only "books" domain data
|
||||
- [!] documents should be trimmed to the same length (?)
|
||||
- [!] overall gfun results logger
|
||||
|
@ -42,6 +47,7 @@ def main(args):
|
|||
dataset_name=args.dataset,
|
||||
langs=dataset.langs(),
|
||||
num_labels=dataset.num_labels(),
|
||||
classification_type=args.clf_type,
|
||||
# Posterior VGF params ----------------
|
||||
posterior=args.posteriors,
|
||||
# Multilingual VGF params -------------
|
||||
|
@ -55,7 +61,8 @@ def main(args):
|
|||
batch_size=args.batch_size,
|
||||
eval_batch_size=args.eval_batch_size,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
textual_lr=args.textual_lr,
|
||||
visual_lr=args.visual_lr,
|
||||
max_length=args.max_length,
|
||||
patience=args.patience,
|
||||
evaluate_step=args.evaluate_step,
|
||||
|
@ -93,8 +100,8 @@ def main(args):
|
|||
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
||||
|
||||
gfun_preds = gfun.transform(lX_te)
|
||||
test_eval = evaluate(lY_te, gfun_preds)
|
||||
log_eval(test_eval, phase="test")
|
||||
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
|
||||
log_eval(test_eval, phase="test", clf_type=args.clf_type)
|
||||
|
||||
timeval = time()
|
||||
print(f"- testing completed in {timeval - timetr:.2f} seconds")
|
||||
|
@ -112,6 +119,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--nrows", type=int, default=None)
|
||||
parser.add_argument("--min_count", type=int, default=10)
|
||||
parser.add_argument("--max_labels", type=int, default=50)
|
||||
parser.add_argument("--clf_type", type=str, default="multilabel")
|
||||
# gFUN parameters ----------------------
|
||||
parser.add_argument("-p", "--posteriors", action="store_true")
|
||||
parser.add_argument("-m", "--multilingual", action="store_true")
|
||||
|
@ -127,7 +135,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--lr", type=float, default=1e-5)
|
||||
parser.add_argument("--textual_lr", type=float, default=1e-5)
|
||||
parser.add_argument("--visual_lr", type=float, default=1e-5)
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--patience", type=int, default=5)
|
||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||
|
|
Loading…
Reference in New Issue