Compare commits

..

22 Commits

Author SHA1 Message Date
Andrea Pedrotti ab7a310b34 todo updates 2023-03-17 10:44:45 +01:00
Andrea Pedrotti 41647f974a last training swipe on eval set is now performed on batch size equal to the training set batch size 2023-03-17 10:44:23 +01:00
Andrea Pedrotti ee2a9481de sampling GLAMI1-M dataset 2023-03-16 18:10:05 +01:00
Andrea Pedrotti ee38bcda10 fixed TransformerGen init 2023-03-16 12:12:39 +01:00
Andrea Pedrotti b34da419d0 fixed import 2023-03-16 11:49:49 +01:00
Andrea Pedrotti 17d0003e48 getter for gFun and VGFs config 2023-03-16 11:41:40 +01:00
Andrea Pedrotti 9d43ebb23b implemented save/load for MT5ForSequenceClassification. Moved torch Datasets to datamanager module 2023-03-16 10:31:34 +01:00
Andrea Pedrotti 56faaf2615 changed wandb logging to a global level to keep track of all the VGFs and overall gFun 2023-03-15 16:35:49 +01:00
Andrea Pedrotti f32b9227ae TODO: better stratified sampling for GLAMI-1M 2023-03-15 11:48:03 +01:00
Andrea Pedrotti 65407f51fa update trainer to handle mT5 2023-03-15 11:47:17 +01:00
Andrea Pedrotti 26aa0b327a average pooling for MT5ForSequenceClassification and standardized return data 2023-03-15 11:46:53 +01:00
Andrea Pedrotti fece8d059e updated argparse 2023-03-14 11:54:40 +01:00
Andrea Pedrotti 5e41b4517a implemented MT5ForSequenceClassification 2023-03-14 11:53:50 +01:00
Andrea Pedrotti a3e183d7fc avoid duplicating model on gpu when earlystop is triggered 2023-03-14 11:22:00 +01:00
Andrea Pedrotti 57918ec523 save and load datasets as pkl 2023-03-10 12:40:26 +01:00
andreapdr 7d0d6ba1f6 log average metrics via wandb 2023-03-10 11:21:33 +01:00
andreapdr 5ef0904e0e logging average metrics 2023-03-09 17:59:18 +01:00
andreapdr 7e1ec46ebd improved wandb logging 2023-03-09 17:03:17 +01:00
Andrea Pedrotti 3240150542 updated todo 2023-03-07 17:36:21 +01:00
Andrea Pedrotti 84dd1f093e logging via wandb 2023-03-07 17:34:25 +01:00
Andrea Pedrotti 6b7917ca47 typos 2023-03-07 14:33:30 +01:00
andreapdr 7dead90271 logging via wandb 2023-03-07 14:20:56 +01:00
16 changed files with 640 additions and 340 deletions

1
.gitignore vendored
View File

@ -182,3 +182,4 @@ scripts/
logger/*
explore_data.ipynb
run.sh
wandb

View File

@ -1,3 +1,5 @@
import os
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset
@ -22,7 +24,7 @@ class gFunDataset:
self.labels = labels
self.nrows = nrows
self.dataset = {}
self.load_dataset()
self._load_dataset()
def get_label_binarizer(self, labels):
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@ -35,7 +37,7 @@ class gFunDataset:
mlb.fit(labels)
return mlb
def load_dataset(self):
def _load_dataset(self):
if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami"
@ -106,44 +108,19 @@ class gFunDataset:
return dataset, labels, data_langs
def _load_glami(self, dataset_dir, nrows):
def _balanced_sample(data, n, remainder=0):
import pandas as pd
langs = sorted(data.geo.unique().tolist())
dict_n = {lang: n for lang in langs}
dict_n[langs[0]] += remainder
sampled = []
for lang in langs:
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
return pd.concat(sampled, axis=0)
# TODO: set this sampling as determinsitic/dependeing on the seed
lang_nrows = (
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
) # GLAMI 1-M has 13 languages
remainder = (
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
n=int(nrows / 10)
)
train_split = get_dataframe("train", dataset_dir=dataset_dir)
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
# TODO: if data langs is NOT none then we have a problem where we filter df by langs
if self.labels is None:
labels = train_split.category_name.unique().tolist()
# TODO: atm test data should contain same languages as train data
test_split = get_dataframe("test", dataset_dir=dataset_dir)
# TODO: atm we're using 1:1 train-test
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
gb_train = train_split.groupby("geo")
gb_test = test_split.groupby("geo")
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
if self.labels is None:
labels = train_split.category_name.unique().tolist()
def _format_glami(data_df):
text = (data_df.name + " " + data_df.description).tolist()
image = data_df.image_file.tolist()
@ -205,6 +182,14 @@ class gFunDataset:
else:
return self.labels
def save_as_pickle(self, path):
    """Pickle this dataset object to '<path>/<dataset_name>_<nrows>.pkl'."""
    import pickle

    filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
    with open(filepath, "wb") as handle:
        print(f"- saving dataset in {filepath}")
        pickle.dump(self, handle)
if __name__ == "__main__":
import os

View File

@ -1,2 +1,66 @@
class TorchMultiNewsDataset:
    """Placeholder for a torch-compatible MultiNews dataset (not implemented)."""

    pass
import torch
from torch.utils.data import Dataset
class MultilingualDatasetTorch(Dataset):
    """Torch Dataset over per-language tokenized text.

    `lX` maps lang -> tokenizer output exposing an `input_ids` tensor;
    `lY` maps lang -> label matrix (one row per sample). With
    split="whole" no labels are stored and __getitem__ yields
    (input_ids, lang) pairs instead of (input_ids, labels, lang).
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Stack every language's samples into one tensor; keep a parallel
        # list of language tags so each row can be routed back to its lang.
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # Flatten language tags in dict insertion order. The original
        # sum(list_of_lists, []) is quadratic; this comprehension is linear
        # and produces the identical list.
        self.langs = [
            lang
            for lang, data in self.lX.items()
            for _ in range(len(data.input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
class MultimodalDatasetTorch(Dataset):
    """Torch Dataset over per-language image batches.

    `lX` maps lang -> stacked image tensor; `lY` maps lang -> label matrix.
    With split="whole" labels are not stored and items are (image, lang)
    pairs instead of (image, labels, lang) triples.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # One stacked tensor covering every language, plus a parallel list
        # of language tags in dict insertion order.
        self.X = torch.vstack(list(self.lX.values()))
        if self.split != "whole":
            self.Y = torch.vstack(
                [torch.Tensor(labels) for labels in self.lY.values()]
            )
        self.langs = [
            tag for tag, batch in self.lX.items() for _ in range(len(batch))
        ]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]

View File

@ -1,9 +1,21 @@
from os.path import expanduser
from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset
def load_from_pickle(path, dataset_name, nrows):
    """Restore a pickled dataset saved as '<dataset_name>_<nrows>.pkl' under path."""
    import pickle

    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
    with open(filepath, "rb") as handle:
        loaded = pickle.load(handle)
        print(f"- Loaded dataset from {filepath}")
    loaded.show_dimension()
    return loaded
def get_dataset(dataset_name, args):
assert dataset_name in [
"multinews",
@ -58,6 +70,9 @@ def get_dataset(dataset_name, args):
nrows=args.nrows,
)
elif dataset_name == "glami":
if args.save_dataset is False:
dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
else:
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
@ -65,6 +80,9 @@ def get_dataset(dataset_name, args):
is_multilabel=False,
nrows=args.nrows,
)
dataset.save_as_pickle(GLAMI_DATAPATH)
elif dataset_name == "cls":
dataset = gFunDataset(
dataset_dir=CLS_DATAPATH,

View File

@ -1,46 +1,57 @@
from joblib import Parallel, delayed
from collections import defaultdict
from evaluation.metrics import *
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
def evaluation_metrics(y, y_):
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
raise NotImplementedError()
else:
def evaluation_metrics(y, y_, clf_type):
if clf_type == "singlelabel":
return (
accuracy_score(y, y_),
# TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"),
)
elif clf_type == "multilabel":
return (
macroF1(y, y_),
microF1(y, y_),
macroK(y, y_),
microK(y, y_),
# macroAcc(y, y_),
microAcc(
y, y_
), # TODO: we're using micro-averaging for accuracy, it is == to accuracy_score on binary classification
)
else:
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
def evaluate(
ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
):
if n_jobs == 1:
return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
return {
lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
for lang in ly_true.keys()
}
else:
langs = list(ly_true.keys())
evals = Parallel(n_jobs=n_jobs)(
delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
)
return {lang: evals[i] for i, lang in enumerate(langs)}
def log_eval(l_eval, phase="training", verbose=True):
def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if verbose:
print(f"\n[Results {phase}]")
metrics = []
if clf_type == "multilabel":
for lang in l_eval.keys():
macrof1, microf1, macrok, microk, microAcc = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk, microAcc])
macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk])
if phase != "validation":
print(
f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f} acc = {microAcc:.3f}"
)
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
@ -48,4 +59,38 @@ def log_eval(l_eval, phase="training", verbose=True):
np.round(averages, 3),
"\n",
)
return averages
return averages # TODO: return a dict avg and lang specific
elif clf_type == "singlelabel":
lang_metrics = defaultdict(dict)
_metrics = [
"accuracy",
# "acc5", # "accuracy-at-5",
# "acc10", # "accuracy-at-10",
"MF1", # "macro-F1",
"mF1", # "micro-F1",
]
for lang in l_eval.keys():
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1 = l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1])
metrics.append([acc, macrof1, microf1])
for m, v in zip(_metrics, l_eval[lang]):
lang_metrics[m][lang] = v
if phase != "validation":
print(
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
)
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
# "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
"Averages: Acc, MF1, mF1",
np.round(averages, 3),
"\n",
)
avg_metrics = dict(zip(_metrics, averages))
return avg_metrics, lang_metrics

View File

@ -239,7 +239,3 @@ def microK(true_labels, predicted_labels):
def macroAcc(true_labels, predicted_labels):
return macro_average(true_labels, predicted_labels, accuracy)
def microAcc(true_labels, predicted_labels):
return micro_average(true_labels, predicted_labels, accuracy)

View File

@ -1,17 +1,14 @@
import os
import sys
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle
import numpy as np
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen
@ -25,11 +22,14 @@ class GeneralizedFunnelling:
visual_transformer,
langs,
num_labels,
classification_type,
embed_dir,
n_jobs,
batch_size,
eval_batch_size,
max_length,
lr,
textual_lr,
visual_lr,
epochs,
patience,
evaluate_step,
@ -47,26 +47,31 @@ class GeneralizedFunnelling:
self.posteriors_vgf = posterior
self.wce_vgf = wce
self.multilingual_vgf = multilingual
self.trasformer_vgf = textual_transformer
self.visual_transformer_vgf = visual_transformer
self.textual_trf_vgf = textual_transformer
self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic
self.num_labels = num_labels
self.clf_type = classification_type
# ------------------------
self.langs = langs
self.embed_dir = embed_dir
self.cached = True
# Textual Transformer VGF params ----------
self.textaul_transformer_name = textual_transformer_name
self.textual_trf_name = textual_transformer_name
self.epochs = epochs
self.lr_transformer = lr
self.batch_size_transformer = batch_size
self.textual_trf_lr = textual_lr
self.textual_scheduler = "ReduceLROnPlateau"
self.batch_size_trf = batch_size
self.eval_batch_size_trf = eval_batch_size
self.max_length = max_length
self.early_stopping = True
self.patience = patience
self.evaluate_step = evaluate_step
self.device = device
# Visual Transformer VGF params ----------
self.visual_transformer_name = visual_transformer_name
self.visual_trf_name = visual_transformer_name
self.visual_trf_lr = visual_lr
self.visual_scheduler = "ReduceLROnPlateau"
# Metaclassifier params ------------
self.optimc = optimc
# -------------------
@ -77,7 +82,7 @@ class GeneralizedFunnelling:
self.aggfunc = aggfunc
self.load_trained = load_trained
self.load_first_tier = (
True # TODO: i guess we're always going to load at least the fitst tier
True # TODO: i guess we're always going to load at least the first tier
)
self.load_meta = load_meta
self.dataset_name = dataset_name
@ -112,7 +117,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.lr_transformer,
lr=self.textual_trf_lr,
patience=self.patience,
num_heads=1,
device=self.device,
@ -142,13 +147,15 @@ class GeneralizedFunnelling:
wce_vgf = WceGen(n_jobs=self.n_jobs)
self.first_tier_learners.append(wce_vgf)
if self.trasformer_vgf:
if self.textual_trf_vgf:
transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name,
model_name=self.textaul_transformer_name,
lr=self.lr_transformer,
model_name=self.textual_trf_name,
lr=self.textual_trf_lr,
scheduler=self.textual_scheduler,
epochs=self.epochs,
batch_size=self.batch_size_transformer,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
max_length=self.max_length,
print_steps=50,
probabilistic=self.probabilistic,
@ -156,21 +163,24 @@ class GeneralizedFunnelling:
verbose=True,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(transformer_vgf)
if self.visual_transformer_vgf:
if self.visual_trf_vgf:
visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name,
model_name="vit",
lr=1e-5, # self.lr_visual_transformer,
lr=self.visual_trf_lr,
scheduler=self.visual_scheduler,
epochs=self.epochs,
batch_size=32, # self.batch_size_visual_transformer,
# batch_size_eval=128,
batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step,
patience=self.patience,
device=self.device,
classification_type=self.clf_type,
)
self.first_tier_learners.append(visual_trasformer_vgf)
@ -179,7 +189,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels,
lr=self.lr_transformer,
lr=self.textual_trf_lr,
patience=self.patience,
num_heads=1,
device=self.device,
@ -198,7 +208,8 @@ class GeneralizedFunnelling:
self.posteriors_vgf,
self.multilingual_vgf,
self.wce_vgf,
self.trasformer_vgf,
self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc,
)
print(f"- model id: {self._model_id}")
@ -251,10 +262,9 @@ class GeneralizedFunnelling:
projections.append(l_posteriors)
agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg)
# converting to binary predictions
# if self.dataset_name in ["cls"]: # TODO: better way to do this
# for lang, preds in l_out.items():
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
if self.clf_type == "singlelabel":
for lang, preds in l_out.items():
l_out[lang] = predict(preds, clf_type=self.clf_type)
return l_out
def fit_transform(self, lX, lY):
@ -303,15 +313,21 @@ class GeneralizedFunnelling:
return aggregated
def get_config(self):
print("\n")
print("-" * 50)
print("[GeneralizedFunnelling config]")
print(f"- model trained on langs: {self.langs}")
print("-- View Generating Functions configurations:\n")
c = {}
for vgf in self.first_tier_learners:
print(vgf)
print("-" * 50)
vgf_config = vgf.get_config()
c.update(vgf_config)
gfun_config = {
"id": self._model_id,
"aggfunc": self.aggfunc,
"optimc": self.optimc,
"dataset": self.dataset_name,
}
c["gFun"] = gfun_config
return c
def save(self, save_first_tier=True, save_meta=True):
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
@ -334,7 +350,7 @@ class GeneralizedFunnelling:
pickle.dump(self.metaclassifier, f)
return
def save_first_tier_learners(self):
def save_first_tier_learners(self, model_id):
for vgf in self.first_tier_learners:
vgf.save_vgf(model_id=self._model_id)
return self
@ -372,7 +388,7 @@ class GeneralizedFunnelling:
"rb",
) as vgf:
first_tier_learners.append(pickle.load(vgf))
if self.trasformer_vgf:
if self.textual_trf_vgf:
with open(
os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
@ -427,7 +443,15 @@ def get_params(optimc=False):
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
def get_unique_id(
dataset_name,
posterior,
multilingual,
wce,
textual_transformer,
visual_transformer,
aggfunc,
):
from datetime import datetime
now = datetime.now().strftime("%y%m%d")
@ -435,6 +459,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
model_id += "p" if posterior else ""
model_id += "m" if multilingual else ""
model_id += "w" if wce else ""
model_id += "t" if transformer else ""
model_id += "t" if textual_transformer else ""
model_id += "v" if visual_transformer else ""
model_id += f"_{aggfunc}"
return f"{model_id}_{now}"

View File

@ -9,9 +9,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput
import wandb
from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 1
@ -21,6 +23,28 @@ def _normalize(lX, l2=True):
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
def verbosity_eval(epoch, print_eval):
    """Return True when eval results should be printed at this (0-based) epoch."""
    return (epoch + 1) % print_eval == 0 and epoch != 0
def format_langkey_wandb(lang_dict, vgf_name):
    """Flatten {metric: {lang: value}} into wandb keys
    '<vgf_name>/language metric/<metric>/<lang>' -> value."""
    return {
        f"{vgf_name}/language metric/{metric}/{lang}": value
        for metric, per_lang in lang_dict.items()
        for lang, value in per_lang.items()
    }
def format_average_wandb(avg_dict, vgf_name):
    """Flatten {metric: value} into wandb keys
    '<vgf_name>/average metric/<metric>' -> value."""
    return {
        f"{vgf_name}/average metric/{metric}": value
        for metric, value in avg_dict.items()
    }
def XdotM(X, M, sif):
E = X.dot(M)
if sif:
@ -57,18 +81,23 @@ def compute_pc(X, npc=1):
return svd.components_
def predict(logits, classification_type="multilabel"):
def predict(logits, clf_type="multilabel"):
"""
Converts soft precictions to hard predictions [0,1]
"""
if classification_type == "multilabel":
if clf_type == "multilabel":
prediction = torch.sigmoid(logits) > 0.5
elif classification_type == "singlelabel":
prediction = torch.argmax(logits, dim=1).view(-1, 1)
else:
print("unknown classification type")
return prediction.detach().cpu().numpy()
elif clf_type == "singlelabel":
if type(logits) != torch.Tensor:
logits = torch.tensor(logits)
prediction = torch.softmax(logits, dim=1)
prediction = prediction.detach().cpu().numpy()
_argmaxs = prediction.argmax(axis=1)
prediction = np.eye(prediction.shape[1])[_argmaxs]
return prediction
else:
raise NotImplementedError()
class TfidfVectorizerMultilingual:
@ -114,63 +143,138 @@ class Trainer:
patience,
experiment_name,
checkpoint_path,
classification_type,
vgf_name,
n_jobs,
scheduler_name=None,
):
self.device = device
self.model = model.to(device)
self.optimizer = self.init_optimizer(optimizer_name, lr)
self.optimizer, self.scheduler = self.init_optimizer(
optimizer_name, lr, scheduler_name
)
self.evaluate_steps = evaluate_step
self.loss_fn = loss_fn.to(device)
self.print_steps = print_steps
self.experiment_name = experiment_name
self.patience = patience
self.print_eval = evaluate_step
self.print_eval = 10
self.earlystopping = EarlyStopping(
patience=patience,
checkpoint_path=checkpoint_path,
verbose=False,
experiment_name=experiment_name,
)
self.clf_type = classification_type
self.vgf_name = vgf_name
self.scheduler_name = scheduler_name
self.n_jobs = n_jobs
self.monitored_metric = (
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
) # TODO: make this configurable
def init_optimizer(self, optimizer_name, lr):
def init_optimizer(self, optimizer_name, lr, scheduler_name):
if optimizer_name.lower() == "adamw":
return AdamW(self.model.parameters(), lr=lr)
optim = AdamW(self.model.parameters(), lr=lr)
else:
raise ValueError(f"Optimizer {optimizer_name} not supported")
if scheduler_name is None:
scheduler = None
elif scheduler_name == "ReduceLROnPlateau":
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
else:
raise ValueError(f"Scheduler {scheduler_name} not supported")
return optim, scheduler
def get_config(self, train_dataloader, eval_dataloader, epochs):
    """Collect the run configuration (model, optimizer, data sizes) as a dict
    for printing and wandb logging."""
    # MT5ForSequenceClassification wraps the HF encoder, so the hub name
    # lives on the inner `mt5encoder`; plain HF models expose it directly.
    if hasattr(self.model, "mt5encoder"):
        model_name = self.model.mt5encoder.name_or_path
    else:
        model_name = self.model.name_or_path
    config = {
        "model name": model_name,
        "epochs": epochs,
        "learning rate": self.optimizer.defaults["lr"],
        "scheduler": self.scheduler_name,  # TODO: add scheduler params
        "train size": len(train_dataloader.dataset),
        "eval size": len(eval_dataloader.dataset),
        "train batch size": train_dataloader.batch_size,
        "eval batch size": eval_dataloader.batch_size,
        "max len": train_dataloader.dataset.X.shape[-1],
        "patience": self.earlystopping.patience,
        "evaluate every": self.evaluate_steps,
        "print eval every": self.print_eval,
        "print train steps": self.print_steps,
        "classification type": self.clf_type,
    }
    return config
def train(self, train_dataloader, eval_dataloader, epochs=10):
print(
f"""- Training params for {self.experiment_name}:
- epochs: {epochs}
- learning rate: {self.optimizer.defaults['lr']}
- train batch size: {train_dataloader.batch_size}
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}
- evaluate every: {self.evaluate_steps}
- print eval every: {self.print_eval}
- print train steps: {self.print_steps}\n"""
)
_config = self.get_config(train_dataloader, eval_dataloader, epochs)
print(f"- Training params for {self.experiment_name}:")
for k, v in _config.items():
print(f"\t{k}: {v}")
for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % self.print_eval == 0
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
train_loss = self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
print_eval = verbosity_eval(epoch, self.print_eval)
with torch.no_grad():
eval_loss, avg_metrics, lang_metrics = self.evaluate(
eval_dataloader,
print_eval=print_eval,
n_jobs=self.n_jobs,
)
wandb.log(
{
f"{self.vgf_name}/loss/val": eval_loss,
**format_langkey_wandb(lang_metrics, self.vgf_name),
**format_average_wandb(avg_metrics, self.vgf_name),
},
commit=False,
)
stop = self.earlystopping(
avg_metrics[self.monitored_metric], self.model, epoch + 1
)
if stop:
print(
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
)
self.model = self.earlystopping.load_model(self.model).to(
self.device
)
restored_model = self.earlystopping.load_model(self.model)
# swapping model on gpu
del self.model
self.model = restored_model.to(self.device)
break
if self.scheduler is not None:
self.scheduler.step(avg_metrics[self.monitored_metric])
wandb.log(
{
f"{self.vgf_name}/loss/train": train_loss,
f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][
"lr"
],
}
)
print(f"- last swipe on eval set")
self.train_epoch(eval_dataloader, epoch=0)
self.train_epoch(
DataLoader(
eval_dataloader.dataset,
batch_size=train_dataloader.batch_size,
shuffle=True,
),
epoch=-1,
)
self.earlystopping.save_model(self.model)
return self.model
def train_epoch(self, dataloader, epoch):
self.model.train()
batch_losses = []
for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device))
@ -180,37 +284,47 @@ class Trainer:
loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward()
self.optimizer.step()
batch_losses.append(loss.item())
if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
return self
print(
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
)
return np.mean(batch_losses)
def evaluate(self, dataloader, print_eval=True):
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
self.model.eval()
eval_losses = []
lY = defaultdict(list)
lY_hat = defaultdict(list)
lY_true = defaultdict(list)
lY_pred = defaultdict(list)
for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device))
if isinstance(y_hat, ModelOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel")
y_pred = self.model(x.to(self.device))
if isinstance(y_pred, ModelOutput):
loss = self.loss_fn(y_pred.logits, y.to(self.device))
predictions = predict(y_pred.logits, clf_type=self.clf_type)
else:
loss = self.loss_fn(y_hat, y.to(self.device))
predictions = predict(y_hat, classification_type="multilabel")
loss = self.loss_fn(y_pred, y.to(self.device))
predictions = predict(y_pred, clf_type=self.clf_type)
eval_losses.append(loss.item())
for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy())
lY_hat[l].append(_pred)
lY_true[l].append(_true.detach().cpu().numpy())
lY_pred[l].append(_pred)
for lang in lY:
lY[lang] = np.vstack(lY[lang])
lY_hat[lang] = np.vstack(lY_hat[lang])
for lang in lY_true:
lY_true[lang] = np.vstack(lY_true[lang])
lY_pred[lang] = np.vstack(lY_pred[lang])
l_eval = evaluate(lY, lY_hat)
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
return average_metrics[0] # macro-F1
l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
avg_metrics, lang_metrics = log_eval(
l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
)
return np.mean(eval_losses), avg_metrics, lang_metrics
class EarlyStopping:
@ -232,7 +346,8 @@ class EarlyStopping:
self.experiment_name = experiment_name
def __call__(self, validation, model, epoch):
if validation > self.best_score:
if validation >= self.best_score:
wandb.log({"patience": self.patience - self.counter})
if self.verbose:
print(
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
@ -244,11 +359,12 @@ class EarlyStopping:
self.save_model(model)
elif validation < (self.best_score + self.min_delta):
self.counter += 1
wandb.log({"patience": self.patience - self.counter})
if self.verbose:
print(
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
)
if self.counter >= self.patience:
if self.counter >= self.patience and self.patience != -1:
print(f"- earlystopping: Early stopping at epoch {epoch}")
return True

View File

@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
pickle.dump(self, f)
return self
def __str__(self):
_str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
return _str
def load_MUSEs(langs, l_vocab, dir_path, cached=False):
dir_path = expanduser(dir_path)

View File

@ -6,20 +6,50 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import transformers
# from sklearn.model_selection import train_test_split
# from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import MT5EncoderModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultilingualDatasetTorch
transformers.logging.set_verbosity_error()
# TODO: add support to loggers
class MT5ForSequenceClassification(nn.Module):
    """mT5 encoder-only classifier: mean-pools the encoder's last hidden
    state and projects it to `num_labels` logits.

    Args:
        model_name: HF hub id of an MT5 checkpoint (e.g. "google/mt5-small").
        num_labels: output size of the classification head.
        output_hidden_states: if True, forward() also returns the pooled
            sentence embedding alongside the logits.
    """

    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.output_hidden_states = output_hidden_states
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.dropout = nn.Dropout(0.1)
        # Size the head from the encoder config instead of hard-coding 512,
        # which is only correct for mt5-small.
        self.linear = nn.Linear(self.mt5encoder.config.d_model, num_labels)

    def forward(self, input_ids):
        embed = self.mt5encoder(input_ids=input_ids)
        # Average pooling over the sequence dimension.
        # NOTE(review): this also averages padding positions — an
        # attention-mask-weighted mean may be intended; verify against the
        # tokenizer setup.
        pooled = torch.mean(embed.last_hidden_state, dim=1)
        outputs = self.dropout(pooled)
        logits = self.linear(outputs)
        if self.output_hidden_states:
            return ModelOutput(
                logits=logits,
                pooled=pooled,
            )
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
        """Serialize the full state dict to '<checkpoint_dir>.pt'."""
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return

    def from_pretrained(self, checkpoint_dir):
        """Load weights from '<checkpoint_dir>.pt' into this instance.

        NOTE(review): returns load_state_dict()'s key report rather than
        self — callers expecting a model instance back should be checked.
        """
        checkpoint_dir += ".pt"
        return self.load_state_dict(torch.load(checkpoint_dir))
class TextualTransformerGen(ViewGen, TransformerGen):
@ -39,23 +69,27 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=10,
verbose=False,
patience=5,
classification_type="multilabel",
scheduler="ReduceLROnPlateau",
):
super().__init__(
self._validate_model_name(model_name),
dataset_name,
epochs,
lr,
batch_size,
batch_size_eval,
max_length,
print_steps,
device,
probabilistic,
n_jobs,
evaluate_step,
verbose,
patience,
epochs=epochs,
lr=lr,
scheduler=scheduler,
batch_size=batch_size,
batch_size_eval=batch_size_eval,
device=device,
evaluate_step=evaluate_step,
patience=patience,
probabilistic=probabilistic,
max_length=max_length,
print_steps=print_steps,
n_jobs=n_jobs,
verbose=verbose,
)
self.clf_type = classification_type
self.fitted = False
print(
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -66,12 +100,19 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return "bert-base-uncased"
elif "mbert" == model_name:
return "bert-base-multilingual-uncased"
elif "xlm" == model_name:
elif "xlm-roberta" == model_name:
return "xlm-roberta-base"
elif "mt5" == model_name:
return "google/mt5-small"
else:
raise NotImplementedError
def load_pretrained_model(self, model_name, num_labels):
    """Instantiate the backbone: the custom MT5 classification head for
    mt5-small, otherwise a stock HF sequence-classification model."""
    if model_name == "google/mt5-small":
        return MT5ForSequenceClassification(
            model_name, num_labels=num_labels, output_hidden_states=True
        )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, output_hidden_states=True
    )
@ -127,9 +168,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
shuffle=False,
)
experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
@ -140,7 +180,16 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step,
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
checkpoint_path=os.path.join(
"models",
"vgfs",
"transformer",
self._format_model_name(self.model_name),
),
vgf_name="textual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
scheduler_name=self.scheduler,
)
trainer.train(
train_dataloader=tra_dataloader,
@ -175,6 +224,10 @@ class TextualTransformerGen(ViewGen, TransformerGen):
with torch.no_grad():
for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device)
# TODO: check this
if isinstance(self.model, MT5ForSequenceClassification):
batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
else:
out = self.model(input_ids).hidden_states[-1]
batch_embeddings = out[:, 0, :].cpu().numpy()
_embeds.append((batch_embeddings, lang))
@ -206,39 +259,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
pickle.dump(self, f)
return self
def __str__(self):
    """Return a multi-line, human-readable summary of this VGF's config."""
    # Fix: the previous implementation bound the result to a local named
    # `str`, shadowing the builtin within this scope.
    description = (
        f"[Transformer VGF (t)]\n"
        f"- model_name: {self.model_name}\n"
        f"- max_length: {self.max_length}\n"
        f"- batch_size: {self.batch_size}\n"
        f"- batch_size_eval: {self.batch_size_eval}\n"
        f"- lr: {self.lr}\n"
        f"- epochs: {self.epochs}\n"
        f"- device: {self.device}\n"
        f"- print_steps: {self.print_steps}\n"
        f"- evaluate_step: {self.evaluate_step}\n"
        f"- patience: {self.patience}\n"
        f"- probabilistic: {self.probabilistic}\n"
    )
    return description
def freeze_model(self):
    """Disable gradient updates for every parameter of the wrapped model."""
    # TODO(review): consider freezing only the first n encoder layers and
    # keeping the classification head trainable.
    for parameter in self.model.parameters():
        parameter.requires_grad = False
def _format_model_name(self, model_name):
    """Map a (possibly HF-style) model identifier to a short, path-safe name.

    Used to build the per-model checkpoint directory, so the result must
    never contain '/' and must never be None.
    """
    if "mt5" in model_name:
        return "google-mt5"
    # Check "xlm-roberta" BEFORE "bert": "roberta" contains the substring
    # "bert", so the old ordering sent xlm-roberta down the bert branch and
    # fell through returning None (breaking os.path.join downstream).
    elif "xlm-roberta" in model_name:
        return "xlm-roberta"
    elif "bert" in model_name:
        if "multilingual" in model_name:
            return "mbert"
        # Plain (non-multilingual) BERT previously returned None here.
        return "bert"
    else:
        return model_name
class MultilingualDatasetTorch(Dataset):
    """Torch Dataset over per-language tokenized inputs.

    ``lX`` maps language -> tokenizer output exposing ``input_ids`` (a 2-D
    tensor); ``lY`` maps language -> label rows. All languages are stacked
    into a single tensor, and ``self.langs`` keeps one language tag per row
    so batches can be routed back to their language.

    ``split="whole"`` means unlabeled data: ``__getitem__`` then yields
    ``(x, lang)`` instead of ``(x, y, lang)``.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Row order follows dict insertion order of lX; langs below must
        # use the same iteration order to stay aligned.
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # One tag per stacked row. The previous sum(list-of-lists, [])
        # flatten was quadratic; this comprehension is linear.
        self.langs = [
            lang
            for lang, data in self.lX.items()
            for _ in range(len(data.input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
def get_config(self):
    """Wrap the shared TransformerGen config under the 'textual_trf' key."""
    return {"textual_trf": super().get_config()}

View File

@ -26,6 +26,7 @@ class TransformerGen:
evaluate_step=10,
verbose=False,
patience=5,
scheduler=None,
):
self.model_name = model_name
self.dataset_name = dataset_name
@ -46,6 +47,7 @@ class TransformerGen:
self.verbose = verbose
self.patience = patience
self.datasets = {}
self.scheduler = scheduler
self.feature2posterior_projector = (
self.make_probabilistic() if probabilistic else None
)
@ -94,3 +96,22 @@ class TransformerGen:
val_lY[lang] = val_Y
return tr_lX, tr_lY, val_lX, val_lY
def get_config(self):
    """Return the generator's hyper-parameters as a plain dict (for logging)."""
    config_keys = (
        "model_name",
        "dataset_name",
        "epochs",
        "lr",
        "scheduler",
        "batch_size",
        "batch_size_eval",
        "max_length",
        "print_steps",
        "device",
        "probabilistic",
        "n_jobs",
        "evaluate_step",
        "verbose",
        "patience",
    )
    return {key: getattr(self, key) for key in config_keys}

View File

@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f:
pickle.dump(self, f)
return self
def __str__(self):
    """One-line-per-field summary of the VanillaFunGen configuration."""
    return (
        f"[VanillaFunGen (-p)]\n"
        f"- base learner: {self.learners}\n"
        f"- n_jobs: {self.n_jobs}\n"
    )

View File

@ -4,12 +4,12 @@ import numpy as np
import torch
import transformers
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultimodalDatasetTorch
transformers.logging.set_verbosity_error()
@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
model_name,
dataset_name,
lr=1e-5,
scheduler="ReduceLROnPlateau",
epochs=10,
batch_size=32,
batch_size_eval=128,
@ -27,12 +28,14 @@ class VisualTransformerGen(ViewGen, TransformerGen):
device="cpu",
probabilistic=False,
patience=5,
classification_type="multilabel",
):
super().__init__(
model_name,
dataset_name,
lr=lr,
epochs=epochs,
lr=lr,
scheduler=scheduler,
batch_size=batch_size,
batch_size_eval=batch_size_eval,
device=device,
@ -40,6 +43,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=patience,
probabilistic=probabilistic,
)
self.clf_type = classification_type
self.fitted = False
print(
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -97,7 +101,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
shuffle=False,
)
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer(
model=self.model,
optimizer_name="adamW",
@ -109,6 +116,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=self.patience,
experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer",
vgf_name="visual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
)
trainer.train(
@ -175,66 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
pickle.dump(self, f)
return self
def __str__(self):
    """Return a multi-line, human-readable summary of this VGF's config."""
    # Fix: the previous implementation bound the result to a local named
    # `str`, shadowing the builtin within this scope.
    description = (
        f"[Visual Transformer VGF (v)]\n"
        f"- model_name: {self.model_name}\n"
        f"- batch_size: {self.batch_size}\n"
        f"- batch_size_eval: {self.batch_size_eval}\n"
        f"- lr: {self.lr}\n"
        f"- epochs: {self.epochs}\n"
        f"- device: {self.device}\n"
        f"- print_steps: {self.print_steps}\n"
        f"- evaluate_step: {self.evaluate_step}\n"
        f"- patience: {self.patience}\n"
        f"- probabilistic: {self.probabilistic}\n"
    )
    return description
class MultimodalDatasetTorch(Dataset):
    """Torch Dataset over per-language image tensors.

    ``lX`` maps language -> a tensor of preprocessed images; ``lY`` maps
    language -> label rows. All languages are stacked into one tensor and
    ``self.langs`` keeps one language tag per row, aligned with the stack.

    ``split="whole"`` means unlabeled data: ``__getitem__`` then yields
    ``(x, lang)`` instead of ``(x, y, lang)``.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # Row order follows dict insertion order of lX; langs below must
        # use the same iteration order to stay aligned.
        self.X = torch.vstack([imgs for imgs in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # One tag per stacked row. The previous sum(list-of-lists, [])
        # flatten was quadratic; this comprehension is linear.
        self.langs = [
            lang for lang, data in self.lX.items() for _ in range(len(data))
        ]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
if __name__ == "__main__":
    # Smoke test: train a VisualTransformerGen on a tiny slice of GLAMI-1M.
    # Requires the dataset to be present locally and a CUDA device.
    from os.path import expanduser
    from dataManager.gFunDataset import gFunDataset

    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
    dataset = gFunDataset(
        dataset_dir=GLAMI_DATAPATH,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        nrows=50,  # tiny sample: sanity run only, not a real experiment
    )
    vg = VisualTransformerGen(
        dataset_name=dataset.dataset_name,
        model_name="vit",
        device="cuda",
        epochs=5,
        evaluate_step=10,
        patience=10,
        probabilistic=True,
    )
    lX, lY = dataset.training()
    vg.fit(lX, lY)
    # Transform the training views back through the fitted generator.
    out = vg.transform(lX)
    exit(0)
def get_config(self):
    """Wrap the shared TransformerGen config under the 'visual_trf' key."""
    base_config = super().get_config()
    return {"visual_trf": base_config}

View File

@ -40,10 +40,6 @@ class WceGen(ViewGen):
"sif": self.sif,
}
def __str__(self):
    """One-line-per-field summary of the WordClass VGF configuration."""
    return f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
def save_vgf(self, model_id):
import pickle
from os.path import join

105
main.py
View File

@ -1,3 +1,8 @@
import os
import wandb
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser
from time import time
@ -7,18 +12,36 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] logging
- add documentations sphinx
- [!] zero-shot setup
- Transformers VGFs:
- scheduler with warmup and cosine
- freeze params method
- General:
[!] zero-shot setup
- CLS dataset is loading only "books" domain data
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
- Attention Aggregator:
- experiment with weight init of Attention-aggregator
- FFNN posterior-probabilities' dependent
- re-init langs when loading VGFs?
- [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
- [!] experiment with weight init of Attention-aggregator
- Docs:
- add documentation (Sphinx)
"""
def get_config_name(args):
    """Build a short run name from the active VGF flags, joined with '+'.

    Example: posteriors + textual transformer (mbert) -> "P+TT_mbert".
    Returns the empty string when no VGF flag is set.
    """
    parts = []
    if args.posteriors:
        parts.append("P")
    if args.wce:
        parts.append("W")
    if args.multilingual:
        parts.append("M")
    if args.textual_transformer:
        parts.append(f"TT_{args.textual_trf_name}")
    if args.visual_transformer:
        parts.append(f"VT_{args.visual_trf_name}")
    return "+".join(parts)
def main(args):
dataset = get_dataset(args.dataset, args)
lX, lY = dataset.training()
@ -43,6 +66,7 @@ def main(args):
dataset_name=args.dataset,
langs=dataset.langs(),
num_labels=dataset.num_labels(),
classification_type=args.clf_type,
# Posterior VGF params ----------------
posterior=args.posteriors,
# Multilingual VGF params -------------
@ -52,24 +76,26 @@ def main(args):
wce=args.wce,
# Transformer VGF params --------------
textual_transformer=args.textual_transformer,
textual_transformer_name=args.transformer_name,
textual_transformer_name=args.textual_trf_name,
batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs,
lr=args.lr,
textual_lr=args.textual_lr,
visual_lr=args.visual_lr,
max_length=args.max_length,
patience=args.patience,
evaluate_step=args.evaluate_step,
device=args.device,
# Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_transformer_name,
visual_transformer_name=args.visual_trf_name,
# batch_size=args.batch_size,
# epochs=args.epochs,
# lr=args.lr,
# patience=args.patience,
# evaluate_step=args.evaluate_step,
# device="cuda",
# General params ----------------------
# General params ---------------------
probabilistic=args.features,
aggfunc=args.aggfunc,
optimc=args.optimc,
@ -78,27 +104,54 @@ def main(args):
n_jobs=args.n_jobs,
)
# gfun.get_config()
config = gfun.get_config()
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
gfun.fit(lX, lY)
if args.load_trained is None and not args.nosave:
gfun.save(save_first_tier=True, save_meta=True)
# print("- Computing evaluation on training set")
# preds = gfun.transform(lX)
# train_eval = evaluate(lY, preds)
# log_eval(train_eval, phase="train")
timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds")
gfun_preds = gfun.transform(lX_te)
test_eval = evaluate(lY_te, gfun_preds)
log_eval(test_eval, phase="test")
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
avg_metrics_gfun, lang_metrics_gfun = log_eval(
test_eval, phase="test", clf_type=args.clf_type
)
timeval = time()
print(f"- testing completed in {timeval - timetr:.2f} seconds")
def log_barplot_wandb(gfun_res, title_affix="per language"):
    """Log gFun evaluation results to wandb as bar plots.

    When ``title_affix == "per language"``, ``gfun_res`` maps
    metric -> {lang: value} and one bar plot per metric is logged.
    Otherwise ``gfun_res`` maps metric -> value and a single summary
    bar plot of the averaged metrics is logged.
    """
    # Fix: the default used to be the misspelled "per langauge", which could
    # never match the branch below, silently sending default calls to the
    # averages branch.
    if title_affix == "per language":
        for metric, lang_values in gfun_res.items():
            data = [[lang, v] for lang, v in lang_values.items()]
            table = wandb.Table(data=data, columns=["lang", f"{metric}"])
            wandb.log(
                {
                    f"gFun/language {metric}": wandb.plot.bar(
                        table, "lang", metric, title=f"{metric} {title_affix}"
                    )
                }
            )
    else:
        data = [[metric, value] for metric, value in gfun_res.items()]
        table = wandb.Table(data=data, columns=["metric", "value"])
        wandb.log(
            {
                "gFun/average metric": wandb.plot.bar(
                    table, "metric", "value", title=f"metric {title_affix}"
                )
            }
        )
wandb.log(gfun_res)
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
if __name__ == "__main__":
parser = ArgumentParser()
@ -112,6 +165,8 @@ if __name__ == "__main__":
parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50)
parser.add_argument("--clf_type", type=str, default="multilabel")
parser.add_argument("--save_dataset", action="store_true")
# gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true")
@ -123,15 +178,17 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--textual_lr", type=float, default=1e-4)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters --------------
parser.add_argument("--visual_transformer_name", type=str, default="vit")
parser.add_argument("--visual_trf_name", type=str, default="vit")
parser.add_argument("--visual_lr", type=float, default=1e-4)
args = parser.parse_args()

View File

@ -1,13 +1,13 @@
beautifulsoup4==4.11.2
joblib==1.2.0
matplotlib==3.7.1
numpy==1.24.2
matplotlib==3.6.3
numpy==1.24.1
pandas==1.5.3
Pillow==9.4.0
requests==2.28.2
scikit_learn==1.2.1
scikit_learn==1.2.2
scipy==1.10.1
torch==1.13.1
torchtext==0.14.1
tqdm==4.65.0
transformers==4.26.1
tqdm==4.64.1
transformers==4.26.0