Compare commits

..

22 Commits

Author SHA1 Message Date
Andrea Pedrotti ab7a310b34 todo updates 2023-03-17 10:44:45 +01:00
Andrea Pedrotti 41647f974a last training swipe on eval set is now performed on batch size equal to the training set batch size 2023-03-17 10:44:23 +01:00
Andrea Pedrotti ee2a9481de sampling GLAMI1-M dataset 2023-03-16 18:10:05 +01:00
Andrea Pedrotti ee38bcda10 fixed TransformerGen init 2023-03-16 12:12:39 +01:00
Andrea Pedrotti b34da419d0 fixed import 2023-03-16 11:49:49 +01:00
Andrea Pedrotti 17d0003e48 getter for gFun and VGFs config 2023-03-16 11:41:40 +01:00
Andrea Pedrotti 9d43ebb23b implemented save/load for MT5ForSequenceClassification. Moved torch Datasets to datamanager module 2023-03-16 10:31:34 +01:00
Andrea Pedrotti 56faaf2615 changed wandb logging to a global level to keep track of all the VGFs and overall gFun 2023-03-15 16:35:49 +01:00
Andrea Pedrotti f32b9227ae TODO: better stratified sampling for GLAMI-1M 2023-03-15 11:48:03 +01:00
Andrea Pedrotti 65407f51fa update trainer to handle mT5 2023-03-15 11:47:17 +01:00
Andrea Pedrotti 26aa0b327a average pooling for MT5ForSequenceClassification and standardized return data 2023-03-15 11:46:53 +01:00
Andrea Pedrotti fece8d059e updated argparse 2023-03-14 11:54:40 +01:00
Andrea Pedrotti 5e41b4517a implemented MT5ForSequenceClassification 2023-03-14 11:53:50 +01:00
Andrea Pedrotti a3e183d7fc avoid duplicating model on gpu when earlystop is triggered 2023-03-14 11:22:00 +01:00
Andrea Pedrotti 57918ec523 save and load datasets as pkl 2023-03-10 12:40:26 +01:00
andreapdr 7d0d6ba1f6 log average metrics via wandb 2023-03-10 11:21:33 +01:00
andreapdr 5ef0904e0e logging average metrics 2023-03-09 17:59:18 +01:00
andreapdr 7e1ec46ebd improved wandb logging 2023-03-09 17:03:17 +01:00
Andrea Pedrotti 3240150542 updated todo 2023-03-07 17:36:21 +01:00
Andrea Pedrotti 84dd1f093e logging via wandb 2023-03-07 17:34:25 +01:00
Andrea Pedrotti 6b7917ca47 typos 2023-03-07 14:33:30 +01:00
andreapdr 7dead90271 logging via wandb 2023-03-07 14:20:56 +01:00
16 changed files with 640 additions and 340 deletions

3
.gitignore vendored
View File

@ -181,4 +181,5 @@ models/*
scripts/ scripts/
logger/* logger/*
explore_data.ipynb explore_data.ipynb
run.sh run.sh
wandb

View File

@ -1,3 +1,5 @@
import os
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from dataManager.glamiDataset import get_dataframe from dataManager.glamiDataset import get_dataframe
from dataManager.multilingualDataset import MultilingualDataset from dataManager.multilingualDataset import MultilingualDataset
@ -22,7 +24,7 @@ class gFunDataset:
self.labels = labels self.labels = labels
self.nrows = nrows self.nrows = nrows
self.dataset = {} self.dataset = {}
self.load_dataset() self._load_dataset()
def get_label_binarizer(self, labels): def get_label_binarizer(self, labels):
if self.dataset_name in ["rcv1-2", "jrc", "cls"]: if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
@ -35,7 +37,7 @@ class gFunDataset:
mlb.fit(labels) mlb.fit(labels)
return mlb return mlb
def load_dataset(self): def _load_dataset(self):
if "glami" in self.dataset_dir.lower(): if "glami" in self.dataset_dir.lower():
print(f"- Loading GLAMI dataset from {self.dataset_dir}") print(f"- Loading GLAMI dataset from {self.dataset_dir}")
self.dataset_name = "glami" self.dataset_name = "glami"
@ -106,44 +108,19 @@ class gFunDataset:
return dataset, labels, data_langs return dataset, labels, data_langs
def _load_glami(self, dataset_dir, nrows): def _load_glami(self, dataset_dir, nrows):
def _balanced_sample(data, n, remainder=0): train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
import pandas as pd test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
n=int(nrows / 10)
langs = sorted(data.geo.unique().tolist())
dict_n = {lang: n for lang in langs}
dict_n[langs[0]] += remainder
sampled = []
for lang in langs:
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
return pd.concat(sampled, axis=0)
# TODO: set this sampling as determinsitic/dependeing on the seed
lang_nrows = (
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
) # GLAMI 1-M has 13 languages
remainder = (
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
) )
train_split = get_dataframe("train", dataset_dir=dataset_dir)
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
# TODO: if data langs is NOT none then we have a problem where we filter df by langs
if self.labels is None:
labels = train_split.category_name.unique().tolist()
# TODO: atm test data should contain same languages as train data
test_split = get_dataframe("test", dataset_dir=dataset_dir)
# TODO: atm we're using 1:1 train-test
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
gb_train = train_split.groupby("geo") gb_train = train_split.groupby("geo")
gb_test = test_split.groupby("geo") gb_test = test_split.groupby("geo")
if self.data_langs is None:
data_langs = sorted(train_split.geo.unique().tolist())
if self.labels is None:
labels = train_split.category_name.unique().tolist()
def _format_glami(data_df): def _format_glami(data_df):
text = (data_df.name + " " + data_df.description).tolist() text = (data_df.name + " " + data_df.description).tolist()
image = data_df.image_file.tolist() image = data_df.image_file.tolist()
@ -205,6 +182,14 @@ class gFunDataset:
else: else:
return self.labels return self.labels
def save_as_pickle(self, path):
import pickle
filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
with open(filepath, "wb") as f:
print(f"- saving dataset in {filepath}")
pickle.dump(self, f)
if __name__ == "__main__": if __name__ == "__main__":
import os import os

View File

@ -1,2 +1,66 @@
class TorchMultiNewsDataset: import torch
pass from torch.utils.data import Dataset
class MultilingualDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([data.input_ids for data in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
}.values()
],
[],
)
return self
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]
class MultimodalDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([imgs for imgs in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data) for lang, data in self.lX.items()
}.values()
],
[],
)
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]

View File

@ -1,9 +1,21 @@
from os.path import expanduser from os.path import expanduser, join
from dataManager.gFunDataset import gFunDataset from dataManager.gFunDataset import gFunDataset
from dataManager.multiNewsDataset import MultiNewsDataset from dataManager.multiNewsDataset import MultiNewsDataset
from dataManager.amazonDataset import AmazonDataset from dataManager.amazonDataset import AmazonDataset
def load_from_pickle(path, dataset_name, nrows):
import pickle
filepath = join(path, f"{dataset_name}_{nrows}.pkl")
with open(filepath, "rb") as f:
loaded = pickle.load(f)
print(f"- Loaded dataset from {filepath}")
loaded.show_dimension()
return loaded
def get_dataset(dataset_name, args): def get_dataset(dataset_name, args):
assert dataset_name in [ assert dataset_name in [
"multinews", "multinews",
@ -58,13 +70,19 @@ def get_dataset(dataset_name, args):
nrows=args.nrows, nrows=args.nrows,
) )
elif dataset_name == "glami": elif dataset_name == "glami":
dataset = gFunDataset( if args.save_dataset is False:
dataset_dir=GLAMI_DATAPATH, dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
is_textual=True, else:
is_visual=True, dataset = gFunDataset(
is_multilabel=False, dataset_dir=GLAMI_DATAPATH,
nrows=args.nrows, is_textual=True,
) is_visual=True,
is_multilabel=False,
nrows=args.nrows,
)
dataset.save_as_pickle(GLAMI_DATAPATH)
elif dataset_name == "cls": elif dataset_name == "cls":
dataset = gFunDataset( dataset = gFunDataset(
dataset_dir=CLS_DATAPATH, dataset_dir=CLS_DATAPATH,

View File

@ -1,51 +1,96 @@
from joblib import Parallel, delayed from joblib import Parallel, delayed
from collections import defaultdict
from evaluation.metrics import * from evaluation.metrics import *
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
def evaluation_metrics(y, y_): def evaluation_metrics(y, y_, clf_type):
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label if clf_type == "singlelabel":
raise NotImplementedError() return (
else: accuracy_score(y, y_),
# TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
f1_score(y, y_, average="macro", zero_division=1),
f1_score(y, y_, average="micro"),
)
elif clf_type == "multilabel":
return ( return (
macroF1(y, y_), macroF1(y, y_),
microF1(y, y_), microF1(y, y_),
macroK(y, y_), macroK(y, y_),
microK(y, y_), microK(y, y_),
# macroAcc(y, y_),
microAcc(
y, y_
), # TODO: we're using micro-averaging for accuracy, it is == to accuracy_score on binary classification
) )
else:
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1): def evaluate(
ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
):
if n_jobs == 1: if n_jobs == 1:
return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()} return {
lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
for lang in ly_true.keys()
}
else: else:
langs = list(ly_true.keys()) langs = list(ly_true.keys())
evals = Parallel(n_jobs=n_jobs)( evals = Parallel(n_jobs=n_jobs)(
delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
) )
return {lang: evals[i] for i, lang in enumerate(langs)} return {lang: evals[i] for i, lang in enumerate(langs)}
def log_eval(l_eval, phase="training", verbose=True): def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
if verbose: if verbose:
print(f"\n[Results {phase}]") print(f"\n[Results {phase}]")
metrics = [] metrics = []
for lang in l_eval.keys():
macrof1, microf1, macrok, microk, microAcc = l_eval[lang] if clf_type == "multilabel":
metrics.append([macrof1, microf1, macrok, microk, microAcc]) for lang in l_eval.keys():
if phase != "validation": macrof1, microf1, macrok, microk = l_eval[lang]
metrics.append([macrof1, microf1, macrok, microk])
if phase != "validation":
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print( print(
f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f} acc = {microAcc:.3f}" "Averages: MF1, mF1, MK, mK",
np.round(averages, 3),
"\n",
) )
averages = np.mean(np.array(metrics), axis=0) return averages # TODO: return a dict avg and lang specific
if verbose:
print( elif clf_type == "singlelabel":
"Averages: MF1, mF1, MK, mK", lang_metrics = defaultdict(dict)
np.round(averages, 3), _metrics = [
"\n", "accuracy",
) # "acc5", # "accuracy-at-5",
return averages # "acc10", # "accuracy-at-10",
"MF1", # "macro-F1",
"mF1", # "micro-F1",
]
for lang in l_eval.keys():
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
acc, macrof1, microf1 = l_eval[lang]
# metrics.append([acc, top5, top10, macrof1, microf1])
metrics.append([acc, macrof1, microf1])
for m, v in zip(_metrics, l_eval[lang]):
lang_metrics[m][lang] = v
if phase != "validation":
print(
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
)
averages = np.mean(np.array(metrics), axis=0)
if verbose:
print(
# "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
"Averages: Acc, MF1, mF1",
np.round(averages, 3),
"\n",
)
avg_metrics = dict(zip(_metrics, averages))
return avg_metrics, lang_metrics

View File

@ -239,7 +239,3 @@ def microK(true_labels, predicted_labels):
def macroAcc(true_labels, predicted_labels): def macroAcc(true_labels, predicted_labels):
return macro_average(true_labels, predicted_labels, accuracy) return macro_average(true_labels, predicted_labels, accuracy)
def microAcc(true_labels, predicted_labels):
return micro_average(true_labels, predicted_labels, accuracy)

View File

@ -1,17 +1,14 @@
import os import os
import sys
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
import pickle import pickle
import numpy as np import numpy as np
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
from gfun.vgfs.learners.svms import MetaClassifier, get_learner from gfun.vgfs.learners.svms import MetaClassifier, get_learner
from gfun.vgfs.multilingualGen import MultilingualGen from gfun.vgfs.multilingualGen import MultilingualGen
from gfun.vgfs.textualTransformerGen import TextualTransformerGen from gfun.vgfs.textualTransformerGen import TextualTransformerGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.vanillaFun import VanillaFunGen from gfun.vgfs.vanillaFun import VanillaFunGen
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
from gfun.vgfs.wceGen import WceGen from gfun.vgfs.wceGen import WceGen
@ -25,11 +22,14 @@ class GeneralizedFunnelling:
visual_transformer, visual_transformer,
langs, langs,
num_labels, num_labels,
classification_type,
embed_dir, embed_dir,
n_jobs, n_jobs,
batch_size, batch_size,
eval_batch_size,
max_length, max_length,
lr, textual_lr,
visual_lr,
epochs, epochs,
patience, patience,
evaluate_step, evaluate_step,
@ -47,26 +47,31 @@ class GeneralizedFunnelling:
self.posteriors_vgf = posterior self.posteriors_vgf = posterior
self.wce_vgf = wce self.wce_vgf = wce
self.multilingual_vgf = multilingual self.multilingual_vgf = multilingual
self.trasformer_vgf = textual_transformer self.textual_trf_vgf = textual_transformer
self.visual_transformer_vgf = visual_transformer self.visual_trf_vgf = visual_transformer
self.probabilistic = probabilistic self.probabilistic = probabilistic
self.num_labels = num_labels self.num_labels = num_labels
self.clf_type = classification_type
# ------------------------ # ------------------------
self.langs = langs self.langs = langs
self.embed_dir = embed_dir self.embed_dir = embed_dir
self.cached = True self.cached = True
# Textual Transformer VGF params ---------- # Textual Transformer VGF params ----------
self.textaul_transformer_name = textual_transformer_name self.textual_trf_name = textual_transformer_name
self.epochs = epochs self.epochs = epochs
self.lr_transformer = lr self.textual_trf_lr = textual_lr
self.batch_size_transformer = batch_size self.textual_scheduler = "ReduceLROnPlateau"
self.batch_size_trf = batch_size
self.eval_batch_size_trf = eval_batch_size
self.max_length = max_length self.max_length = max_length
self.early_stopping = True self.early_stopping = True
self.patience = patience self.patience = patience
self.evaluate_step = evaluate_step self.evaluate_step = evaluate_step
self.device = device self.device = device
# Visual Transformer VGF params ---------- # Visual Transformer VGF params ----------
self.visual_transformer_name = visual_transformer_name self.visual_trf_name = visual_transformer_name
self.visual_trf_lr = visual_lr
self.visual_scheduler = "ReduceLROnPlateau"
# Metaclassifier params ------------ # Metaclassifier params ------------
self.optimc = optimc self.optimc = optimc
# ------------------- # -------------------
@ -77,7 +82,7 @@ class GeneralizedFunnelling:
self.aggfunc = aggfunc self.aggfunc = aggfunc
self.load_trained = load_trained self.load_trained = load_trained
self.load_first_tier = ( self.load_first_tier = (
True # TODO: i guess we're always going to load at least the fitst tier True # TODO: i guess we're always going to load at least the first tier
) )
self.load_meta = load_meta self.load_meta = load_meta
self.dataset_name = dataset_name self.dataset_name = dataset_name
@ -112,7 +117,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator( self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking), embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels, out_dim=self.num_labels,
lr=self.lr_transformer, lr=self.textual_trf_lr,
patience=self.patience, patience=self.patience,
num_heads=1, num_heads=1,
device=self.device, device=self.device,
@ -142,13 +147,15 @@ class GeneralizedFunnelling:
wce_vgf = WceGen(n_jobs=self.n_jobs) wce_vgf = WceGen(n_jobs=self.n_jobs)
self.first_tier_learners.append(wce_vgf) self.first_tier_learners.append(wce_vgf)
if self.trasformer_vgf: if self.textual_trf_vgf:
transformer_vgf = TextualTransformerGen( transformer_vgf = TextualTransformerGen(
dataset_name=self.dataset_name, dataset_name=self.dataset_name,
model_name=self.textaul_transformer_name, model_name=self.textual_trf_name,
lr=self.lr_transformer, lr=self.textual_trf_lr,
scheduler=self.textual_scheduler,
epochs=self.epochs, epochs=self.epochs,
batch_size=self.batch_size_transformer, batch_size=self.batch_size_trf,
batch_size_eval=self.eval_batch_size_trf,
max_length=self.max_length, max_length=self.max_length,
print_steps=50, print_steps=50,
probabilistic=self.probabilistic, probabilistic=self.probabilistic,
@ -156,21 +163,24 @@ class GeneralizedFunnelling:
verbose=True, verbose=True,
patience=self.patience, patience=self.patience,
device=self.device, device=self.device,
classification_type=self.clf_type,
) )
self.first_tier_learners.append(transformer_vgf) self.first_tier_learners.append(transformer_vgf)
if self.visual_transformer_vgf: if self.visual_trf_vgf:
visual_trasformer_vgf = VisualTransformerGen( visual_trasformer_vgf = VisualTransformerGen(
dataset_name=self.dataset_name, dataset_name=self.dataset_name,
model_name="vit", model_name="vit",
lr=1e-5, # self.lr_visual_transformer, lr=self.visual_trf_lr,
scheduler=self.visual_scheduler,
epochs=self.epochs, epochs=self.epochs,
batch_size=32, # self.batch_size_visual_transformer, batch_size=self.batch_size_trf,
# batch_size_eval=128, batch_size_eval=self.eval_batch_size_trf,
probabilistic=self.probabilistic, probabilistic=self.probabilistic,
evaluate_step=self.evaluate_step, evaluate_step=self.evaluate_step,
patience=self.patience, patience=self.patience,
device=self.device, device=self.device,
classification_type=self.clf_type,
) )
self.first_tier_learners.append(visual_trasformer_vgf) self.first_tier_learners.append(visual_trasformer_vgf)
@ -179,7 +189,7 @@ class GeneralizedFunnelling:
self.attn_aggregator = AttentionAggregator( self.attn_aggregator = AttentionAggregator(
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking), embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
out_dim=self.num_labels, out_dim=self.num_labels,
lr=self.lr_transformer, lr=self.textual_trf_lr,
patience=self.patience, patience=self.patience,
num_heads=1, num_heads=1,
device=self.device, device=self.device,
@ -198,7 +208,8 @@ class GeneralizedFunnelling:
self.posteriors_vgf, self.posteriors_vgf,
self.multilingual_vgf, self.multilingual_vgf,
self.wce_vgf, self.wce_vgf,
self.trasformer_vgf, self.textual_trf_vgf,
self.visual_trf_vgf,
self.aggfunc, self.aggfunc,
) )
print(f"- model id: {self._model_id}") print(f"- model id: {self._model_id}")
@ -251,10 +262,9 @@ class GeneralizedFunnelling:
projections.append(l_posteriors) projections.append(l_posteriors)
agg = self.aggregate(projections) agg = self.aggregate(projections)
l_out = self.metaclassifier.predict_proba(agg) l_out = self.metaclassifier.predict_proba(agg)
# converting to binary predictions if self.clf_type == "singlelabel":
# if self.dataset_name in ["cls"]: # TODO: better way to do this for lang, preds in l_out.items():
# for lang, preds in l_out.items(): l_out[lang] = predict(preds, clf_type=self.clf_type)
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
return l_out return l_out
def fit_transform(self, lX, lY): def fit_transform(self, lX, lY):
@ -303,15 +313,21 @@ class GeneralizedFunnelling:
return aggregated return aggregated
def get_config(self): def get_config(self):
print("\n") c = {}
print("-" * 50)
print("[GeneralizedFunnelling config]")
print(f"- model trained on langs: {self.langs}")
print("-- View Generating Functions configurations:\n")
for vgf in self.first_tier_learners: for vgf in self.first_tier_learners:
print(vgf) vgf_config = vgf.get_config()
print("-" * 50) c.update(vgf_config)
gfun_config = {
"id": self._model_id,
"aggfunc": self.aggfunc,
"optimc": self.optimc,
"dataset": self.dataset_name,
}
c["gFun"] = gfun_config
return c
def save(self, save_first_tier=True, save_meta=True): def save(self, save_first_tier=True, save_meta=True):
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}") print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
@ -334,7 +350,7 @@ class GeneralizedFunnelling:
pickle.dump(self.metaclassifier, f) pickle.dump(self.metaclassifier, f)
return return
def save_first_tier_learners(self): def save_first_tier_learners(self, model_id):
for vgf in self.first_tier_learners: for vgf in self.first_tier_learners:
vgf.save_vgf(model_id=self._model_id) vgf.save_vgf(model_id=self._model_id)
return self return self
@ -372,7 +388,7 @@ class GeneralizedFunnelling:
"rb", "rb",
) as vgf: ) as vgf:
first_tier_learners.append(pickle.load(vgf)) first_tier_learners.append(pickle.load(vgf))
if self.trasformer_vgf: if self.textual_trf_vgf:
with open( with open(
os.path.join( os.path.join(
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl" "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
@ -427,7 +443,15 @@ def get_params(optimc=False):
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}] return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc): def get_unique_id(
dataset_name,
posterior,
multilingual,
wce,
textual_transformer,
visual_transformer,
aggfunc,
):
from datetime import datetime from datetime import datetime
now = datetime.now().strftime("%y%m%d") now = datetime.now().strftime("%y%m%d")
@ -435,6 +459,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
model_id += "p" if posterior else "" model_id += "p" if posterior else ""
model_id += "m" if multilingual else "" model_id += "m" if multilingual else ""
model_id += "w" if wce else "" model_id += "w" if wce else ""
model_id += "t" if transformer else "" model_id += "t" if textual_transformer else ""
model_id += "v" if visual_transformer else ""
model_id += f"_{aggfunc}" model_id += f"_{aggfunc}"
return f"{model_id}_{now}" return f"{model_id}_{now}"

View File

@ -9,9 +9,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize from sklearn.preprocessing import normalize
from torch.optim import AdamW from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from transformers.modeling_outputs import ModelOutput from transformers.modeling_outputs import ModelOutput
import wandb
from evaluation.evaluate import evaluate, log_eval from evaluation.evaluate import evaluate, log_eval
PRINT_ON_EPOCH = 1 PRINT_ON_EPOCH = 1
@ -21,6 +23,28 @@ def _normalize(lX, l2=True):
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
def verbosity_eval(epoch, print_eval):
if (epoch + 1) % print_eval == 0 and epoch != 0:
return True
else:
return False
def format_langkey_wandb(lang_dict, vgf_name):
log_dict = {}
for metric, l_dict in lang_dict.items():
for lang, value in l_dict.items():
log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value
return log_dict
def format_average_wandb(avg_dict, vgf_name):
log_dict = {}
for metric, value in avg_dict.items():
log_dict[f"{vgf_name}/average metric/{metric}"] = value
return log_dict
def XdotM(X, M, sif): def XdotM(X, M, sif):
E = X.dot(M) E = X.dot(M)
if sif: if sif:
@ -57,18 +81,23 @@ def compute_pc(X, npc=1):
return svd.components_ return svd.components_
def predict(logits, classification_type="multilabel"): def predict(logits, clf_type="multilabel"):
""" """
Converts soft precictions to hard predictions [0,1] Converts soft precictions to hard predictions [0,1]
""" """
if classification_type == "multilabel": if clf_type == "multilabel":
prediction = torch.sigmoid(logits) > 0.5 prediction = torch.sigmoid(logits) > 0.5
elif classification_type == "singlelabel": return prediction.detach().cpu().numpy()
prediction = torch.argmax(logits, dim=1).view(-1, 1) elif clf_type == "singlelabel":
if type(logits) != torch.Tensor:
logits = torch.tensor(logits)
prediction = torch.softmax(logits, dim=1)
prediction = prediction.detach().cpu().numpy()
_argmaxs = prediction.argmax(axis=1)
prediction = np.eye(prediction.shape[1])[_argmaxs]
return prediction
else: else:
print("unknown classification type") raise NotImplementedError()
return prediction.detach().cpu().numpy()
class TfidfVectorizerMultilingual: class TfidfVectorizerMultilingual:
@ -114,63 +143,138 @@ class Trainer:
patience, patience,
experiment_name, experiment_name,
checkpoint_path, checkpoint_path,
classification_type,
vgf_name,
n_jobs,
scheduler_name=None,
): ):
self.device = device self.device = device
self.model = model.to(device) self.model = model.to(device)
self.optimizer = self.init_optimizer(optimizer_name, lr) self.optimizer, self.scheduler = self.init_optimizer(
optimizer_name, lr, scheduler_name
)
self.evaluate_steps = evaluate_step self.evaluate_steps = evaluate_step
self.loss_fn = loss_fn.to(device) self.loss_fn = loss_fn.to(device)
self.print_steps = print_steps self.print_steps = print_steps
self.experiment_name = experiment_name self.experiment_name = experiment_name
self.patience = patience self.patience = patience
self.print_eval = evaluate_step self.print_eval = 10
self.earlystopping = EarlyStopping( self.earlystopping = EarlyStopping(
patience=patience, patience=patience,
checkpoint_path=checkpoint_path, checkpoint_path=checkpoint_path,
verbose=False, verbose=False,
experiment_name=experiment_name, experiment_name=experiment_name,
) )
self.clf_type = classification_type
self.vgf_name = vgf_name
self.scheduler_name = scheduler_name
self.n_jobs = n_jobs
self.monitored_metric = (
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
) # TODO: make this configurable
def init_optimizer(self, optimizer_name, lr): def init_optimizer(self, optimizer_name, lr, scheduler_name):
if optimizer_name.lower() == "adamw": if optimizer_name.lower() == "adamw":
return AdamW(self.model.parameters(), lr=lr) optim = AdamW(self.model.parameters(), lr=lr)
else: else:
raise ValueError(f"Optimizer {optimizer_name} not supported") raise ValueError(f"Optimizer {optimizer_name} not supported")
if scheduler_name is None:
scheduler = None
elif scheduler_name == "ReduceLROnPlateau":
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
else:
raise ValueError(f"Scheduler {scheduler_name} not supported")
return optim, scheduler
def get_config(self, train_dataloader, eval_dataloader, epochs):
return {
"model name": self.model.name_or_path
if not hasattr(self.model, "mt5encoder")
else self.model.mt5encoder.name_or_path,
"epochs": epochs,
"learning rate": self.optimizer.defaults["lr"],
"scheduler": self.scheduler_name, # TODO: add scheduler params
"train size": len(train_dataloader.dataset),
"eval size": len(eval_dataloader.dataset),
"train batch size": train_dataloader.batch_size,
"eval batch size": eval_dataloader.batch_size,
"max len": train_dataloader.dataset.X.shape[-1],
"patience": self.earlystopping.patience,
"evaluate every": self.evaluate_steps,
"print eval every": self.print_eval,
"print train steps": self.print_steps,
"classification type": self.clf_type,
}
def train(self, train_dataloader, eval_dataloader, epochs=10): def train(self, train_dataloader, eval_dataloader, epochs=10):
print( _config = self.get_config(train_dataloader, eval_dataloader, epochs)
f"""- Training params for {self.experiment_name}:
- epochs: {epochs} print(f"- Training params for {self.experiment_name}:")
- learning rate: {self.optimizer.defaults['lr']} for k, v in _config.items():
- train batch size: {train_dataloader.batch_size} print(f"\t{k}: {v}")
- eval batch size: {eval_dataloader.batch_size}
- max len: {train_dataloader.dataset.X.shape[-1]}
- patience: {self.earlystopping.patience}
- evaluate every: {self.evaluate_steps}
- print eval every: {self.print_eval}
- print train steps: {self.print_steps}\n"""
)
for epoch in range(epochs): for epoch in range(epochs):
self.train_epoch(train_dataloader, epoch) train_loss = self.train_epoch(train_dataloader, epoch)
if (epoch + 1) % self.evaluate_steps == 0:
print_eval = (epoch + 1) % self.print_eval == 0 if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval) print_eval = verbosity_eval(epoch, self.print_eval)
stop = self.earlystopping(metric_watcher, self.model, epoch + 1) with torch.no_grad():
eval_loss, avg_metrics, lang_metrics = self.evaluate(
eval_dataloader,
print_eval=print_eval,
n_jobs=self.n_jobs,
)
wandb.log(
{
f"{self.vgf_name}/loss/val": eval_loss,
**format_langkey_wandb(lang_metrics, self.vgf_name),
**format_average_wandb(avg_metrics, self.vgf_name),
},
commit=False,
)
stop = self.earlystopping(
avg_metrics[self.monitored_metric], self.model, epoch + 1
)
if stop: if stop:
print( print(
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}" f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
) )
self.model = self.earlystopping.load_model(self.model).to( restored_model = self.earlystopping.load_model(self.model)
self.device
) # swapping model on gpu
del self.model
self.model = restored_model.to(self.device)
break break
if self.scheduler is not None:
self.scheduler.step(avg_metrics[self.monitored_metric])
wandb.log(
{
f"{self.vgf_name}/loss/train": train_loss,
f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][
"lr"
],
}
)
print(f"- last swipe on eval set") print(f"- last swipe on eval set")
self.train_epoch(eval_dataloader, epoch=0) self.train_epoch(
DataLoader(
eval_dataloader.dataset,
batch_size=train_dataloader.batch_size,
shuffle=True,
),
epoch=-1,
)
self.earlystopping.save_model(self.model) self.earlystopping.save_model(self.model)
return self.model return self.model
def train_epoch(self, dataloader, epoch): def train_epoch(self, dataloader, epoch):
self.model.train() self.model.train()
batch_losses = []
for b_idx, (x, y, lang) in enumerate(dataloader): for b_idx, (x, y, lang) in enumerate(dataloader):
self.optimizer.zero_grad() self.optimizer.zero_grad()
y_hat = self.model(x.to(self.device)) y_hat = self.model(x.to(self.device))
@ -180,37 +284,47 @@ class Trainer:
loss = self.loss_fn(y_hat, y.to(self.device)) loss = self.loss_fn(y_hat, y.to(self.device))
loss.backward() loss.backward()
self.optimizer.step() self.optimizer.step()
batch_losses.append(loss.item())
if (epoch + 1) % PRINT_ON_EPOCH == 0: if (epoch + 1) % PRINT_ON_EPOCH == 0:
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0: if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}") print(
return self f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
)
return np.mean(batch_losses)
def evaluate(self, dataloader, print_eval=True): def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
self.model.eval() self.model.eval()
eval_losses = []
lY = defaultdict(list) lY_true = defaultdict(list)
lY_hat = defaultdict(list) lY_pred = defaultdict(list)
for b_idx, (x, y, lang) in enumerate(dataloader): for b_idx, (x, y, lang) in enumerate(dataloader):
y_hat = self.model(x.to(self.device)) y_pred = self.model(x.to(self.device))
if isinstance(y_hat, ModelOutput): if isinstance(y_pred, ModelOutput):
loss = self.loss_fn(y_hat.logits, y.to(self.device)) loss = self.loss_fn(y_pred.logits, y.to(self.device))
predictions = predict(y_hat.logits, classification_type="multilabel") predictions = predict(y_pred.logits, clf_type=self.clf_type)
else: else:
loss = self.loss_fn(y_hat, y.to(self.device)) loss = self.loss_fn(y_pred, y.to(self.device))
predictions = predict(y_hat, classification_type="multilabel") predictions = predict(y_pred, clf_type=self.clf_type)
eval_losses.append(loss.item())
for l, _true, _pred in zip(lang, y, predictions): for l, _true, _pred in zip(lang, y, predictions):
lY[l].append(_true.detach().cpu().numpy()) lY_true[l].append(_true.detach().cpu().numpy())
lY_hat[l].append(_pred) lY_pred[l].append(_pred)
for lang in lY: for lang in lY_true:
lY[lang] = np.vstack(lY[lang]) lY_true[lang] = np.vstack(lY_true[lang])
lY_hat[lang] = np.vstack(lY_hat[lang]) lY_pred[lang] = np.vstack(lY_pred[lang])
l_eval = evaluate(lY, lY_hat) l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
return average_metrics[0] # macro-F1 avg_metrics, lang_metrics = log_eval(
l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
)
return np.mean(eval_losses), avg_metrics, lang_metrics
class EarlyStopping: class EarlyStopping:
@ -232,7 +346,8 @@ class EarlyStopping:
self.experiment_name = experiment_name self.experiment_name = experiment_name
def __call__(self, validation, model, epoch): def __call__(self, validation, model, epoch):
if validation > self.best_score: if validation >= self.best_score:
wandb.log({"patience": self.patience - self.counter})
if self.verbose: if self.verbose:
print( print(
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}" f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
@ -244,11 +359,12 @@ class EarlyStopping:
self.save_model(model) self.save_model(model)
elif validation < (self.best_score + self.min_delta): elif validation < (self.best_score + self.min_delta):
self.counter += 1 self.counter += 1
wandb.log({"patience": self.patience - self.counter})
if self.verbose: if self.verbose:
print( print(
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}" f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
) )
if self.counter >= self.patience: if self.counter >= self.patience and self.patience != -1:
print(f"- earlystopping: Early stopping at epoch {epoch}") print(f"- earlystopping: Early stopping at epoch {epoch}")
return True return True

View File

@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
pickle.dump(self, f) pickle.dump(self, f)
return self return self
def __str__(self):
_str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
return _str
def load_MUSEs(langs, l_vocab, dir_path, cached=False): def load_MUSEs(langs, l_vocab, dir_path, cached=False):
dir_path = expanduser(dir_path) dir_path = expanduser(dir_path)

View File

@ -6,20 +6,50 @@ from collections import defaultdict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
import transformers import transformers
from transformers import MT5EncoderModel
# from sklearn.model_selection import train_test_split
# from torch.optim import AdamW
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from gfun.vgfs.commons import Trainer from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultilingualDatasetTorch
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
# TODO: add support to loggers class MT5ForSequenceClassification(nn.Module):
def __init__(self, model_name, num_labels, output_hidden_states):
super().__init__()
self.output_hidden_states = output_hidden_states
self.mt5encoder = MT5EncoderModel.from_pretrained(
model_name, output_hidden_states=True
)
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(512, num_labels)
def forward(self, input_ids):
embed = self.mt5encoder(input_ids=input_ids)
pooled = torch.mean(embed.last_hidden_state, dim=1)
outputs = self.dropout(pooled)
logits = self.linear(outputs)
if self.output_hidden_states:
return ModelOutput(
logits=logits,
pooled=pooled,
)
return ModelOutput(logits=logits)
def save_pretrained(self, checkpoint_dir):
torch.save(self.state_dict(), checkpoint_dir + ".pt")
return
def from_pretrained(self, checkpoint_dir):
checkpoint_dir += ".pt"
return self.load_state_dict(torch.load(checkpoint_dir))
class TextualTransformerGen(ViewGen, TransformerGen): class TextualTransformerGen(ViewGen, TransformerGen):
@ -39,23 +69,27 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=10, evaluate_step=10,
verbose=False, verbose=False,
patience=5, patience=5,
classification_type="multilabel",
scheduler="ReduceLROnPlateau",
): ):
super().__init__( super().__init__(
self._validate_model_name(model_name), self._validate_model_name(model_name),
dataset_name, dataset_name,
epochs, epochs=epochs,
lr, lr=lr,
batch_size, scheduler=scheduler,
batch_size_eval, batch_size=batch_size,
max_length, batch_size_eval=batch_size_eval,
print_steps, device=device,
device, evaluate_step=evaluate_step,
probabilistic, patience=patience,
n_jobs, probabilistic=probabilistic,
evaluate_step, max_length=max_length,
verbose, print_steps=print_steps,
patience, n_jobs=n_jobs,
verbose=verbose,
) )
self.clf_type = classification_type
self.fitted = False self.fitted = False
print( print(
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]" f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -66,15 +100,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return "bert-base-uncased" return "bert-base-uncased"
elif "mbert" == model_name: elif "mbert" == model_name:
return "bert-base-multilingual-uncased" return "bert-base-multilingual-uncased"
elif "xlm" == model_name: elif "xlm-roberta" == model_name:
return "xlm-roberta-base" return "xlm-roberta-base"
elif "mt5" == model_name:
return "google/mt5-small"
else: else:
raise NotImplementedError raise NotImplementedError
def load_pretrained_model(self, model_name, num_labels): def load_pretrained_model(self, model_name, num_labels):
return AutoModelForSequenceClassification.from_pretrained( if model_name == "google/mt5-small":
model_name, num_labels=num_labels, output_hidden_states=True return MT5ForSequenceClassification(
) model_name, num_labels=num_labels, output_hidden_states=True
)
else:
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)
def load_tokenizer(self, model_name): def load_tokenizer(self, model_name):
return AutoTokenizer.from_pretrained(model_name) return AutoTokenizer.from_pretrained(model_name)
@ -127,9 +168,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
shuffle=False, shuffle=False,
) )
experiment_name = ( experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer( trainer = Trainer(
model=self.model, model=self.model,
optimizer_name="adamW", optimizer_name="adamW",
@ -140,7 +180,16 @@ class TextualTransformerGen(ViewGen, TransformerGen):
evaluate_step=self.evaluate_step, evaluate_step=self.evaluate_step,
patience=self.patience, patience=self.patience,
experiment_name=experiment_name, experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer", checkpoint_path=os.path.join(
"models",
"vgfs",
"transformer",
self._format_model_name(self.model_name),
),
vgf_name="textual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
scheduler_name=self.scheduler,
) )
trainer.train( trainer.train(
train_dataloader=tra_dataloader, train_dataloader=tra_dataloader,
@ -175,8 +224,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
with torch.no_grad(): with torch.no_grad():
for input_ids, lang in dataloader: for input_ids, lang in dataloader:
input_ids = input_ids.to(self.device) input_ids = input_ids.to(self.device)
out = self.model(input_ids).hidden_states[-1] # TODO: check this
batch_embeddings = out[:, 0, :].cpu().numpy() if isinstance(self.model, MT5ForSequenceClassification):
batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
else:
out = self.model(input_ids).hidden_states[-1]
batch_embeddings = out[:, 0, :].cpu().numpy()
_embeds.append((batch_embeddings, lang)) _embeds.append((batch_embeddings, lang))
for embed, lang in _embeds: for embed, lang in _embeds:
@ -206,39 +259,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
pickle.dump(self, f) pickle.dump(self, f)
return self return self
def __str__(self): def freeze_model(self):
str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n" # TODO: up to n-layers? or all? avoid freezing head ovb...
return str for param in self.model.parameters():
param.requires_grad = False
def _format_model_name(self, model_name):
if "mt5" in model_name:
return "google-mt5"
elif "bert" in model_name:
if "multilingual" in model_name:
return "mbert"
elif "xlm-roberta" in model_name:
return "xlm-roberta"
else:
return model_name
class MultilingualDatasetTorch(Dataset): def get_config(self):
def __init__(self, lX, lY, split="train"): c = super().get_config()
self.lX = lX return {"textual_trf": c}
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([data.input_ids for data in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
}.values()
],
[],
)
return self
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]

View File

@ -26,6 +26,7 @@ class TransformerGen:
evaluate_step=10, evaluate_step=10,
verbose=False, verbose=False,
patience=5, patience=5,
scheduler=None,
): ):
self.model_name = model_name self.model_name = model_name
self.dataset_name = dataset_name self.dataset_name = dataset_name
@ -46,6 +47,7 @@ class TransformerGen:
self.verbose = verbose self.verbose = verbose
self.patience = patience self.patience = patience
self.datasets = {} self.datasets = {}
self.scheduler = scheduler
self.feature2posterior_projector = ( self.feature2posterior_projector = (
self.make_probabilistic() if probabilistic else None self.make_probabilistic() if probabilistic else None
) )
@ -94,3 +96,22 @@ class TransformerGen:
val_lY[lang] = val_Y val_lY[lang] = val_Y
return tr_lX, tr_lY, val_lX, val_lY return tr_lX, tr_lY, val_lX, val_lY
def get_config(self):
return {
"model_name": self.model_name,
"dataset_name": self.dataset_name,
"epochs": self.epochs,
"lr": self.lr,
"scheduler": self.scheduler,
"batch_size": self.batch_size,
"batch_size_eval": self.batch_size_eval,
"max_length": self.max_length,
"print_steps": self.print_steps,
"device": self.device,
"probabilistic": self.probabilistic,
"n_jobs": self.n_jobs,
"evaluate_step": self.evaluate_step,
"verbose": self.verbose,
"patience": self.patience,
}

View File

@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
with open(_path, "wb") as f: with open(_path, "wb") as f:
pickle.dump(self, f) pickle.dump(self, f)
return self return self
def __str__(self):
_str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
# - parameters: {self.first_tier_parameters}
return _str

View File

@ -4,12 +4,12 @@ import numpy as np
import torch import torch
import transformers import transformers
from PIL import Image from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification from transformers import AutoImageProcessor, AutoModelForImageClassification
from gfun.vgfs.commons import Trainer from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultimodalDatasetTorch
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
model_name, model_name,
dataset_name, dataset_name,
lr=1e-5, lr=1e-5,
scheduler="ReduceLROnPlateau",
epochs=10, epochs=10,
batch_size=32, batch_size=32,
batch_size_eval=128, batch_size_eval=128,
@ -27,12 +28,14 @@ class VisualTransformerGen(ViewGen, TransformerGen):
device="cpu", device="cpu",
probabilistic=False, probabilistic=False,
patience=5, patience=5,
classification_type="multilabel",
): ):
super().__init__( super().__init__(
model_name, model_name,
dataset_name, dataset_name,
lr=lr,
epochs=epochs, epochs=epochs,
lr=lr,
scheduler=scheduler,
batch_size=batch_size, batch_size=batch_size,
batch_size_eval=batch_size_eval, batch_size_eval=batch_size_eval,
device=device, device=device,
@ -40,6 +43,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=patience, patience=patience,
probabilistic=probabilistic, probabilistic=probabilistic,
) )
self.clf_type = classification_type
self.fitted = False self.fitted = False
print( print(
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]" f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@ -97,7 +101,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
shuffle=False, shuffle=False,
) )
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}" experiment_name = (
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
)
trainer = Trainer( trainer = Trainer(
model=self.model, model=self.model,
optimizer_name="adamW", optimizer_name="adamW",
@ -109,6 +116,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
patience=self.patience, patience=self.patience,
experiment_name=experiment_name, experiment_name=experiment_name,
checkpoint_path="models/vgfs/transformer", checkpoint_path="models/vgfs/transformer",
vgf_name="visual_trf",
classification_type=self.clf_type,
n_jobs=self.n_jobs,
) )
trainer.train( trainer.train(
@ -175,66 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
pickle.dump(self, f) pickle.dump(self, f)
return self return self
def __str__(self): def get_config(self):
str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n" return {"visual_trf": super().get_config()}
return str
class MultimodalDatasetTorch(Dataset):
def __init__(self, lX, lY, split="train"):
self.lX = lX
self.lY = lY
self.split = split
self.langs = []
self.init()
def init(self):
self.X = torch.vstack([imgs for imgs in self.lX.values()])
if self.split != "whole":
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
self.langs = sum(
[
v
for v in {
lang: [lang] * len(data) for lang, data in self.lX.items()
}.values()
],
[],
)
def __len__(self):
return len(self.X)
def __getitem__(self, index):
if self.split == "whole":
return self.X[index], self.langs[index]
return self.X[index], self.Y[index], self.langs[index]
if __name__ == "__main__":
from os.path import expanduser
from dataManager.gFunDataset import gFunDataset
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
dataset = gFunDataset(
dataset_dir=GLAMI_DATAPATH,
is_textual=True,
is_visual=True,
is_multilabel=False,
nrows=50,
)
vg = VisualTransformerGen(
dataset_name=dataset.dataset_name,
model_name="vit",
device="cuda",
epochs=5,
evaluate_step=10,
patience=10,
probabilistic=True,
)
lX, lY = dataset.training()
vg.fit(lX, lY)
out = vg.transform(lX)
exit(0)

View File

@ -40,10 +40,6 @@ class WceGen(ViewGen):
"sif": self.sif, "sif": self.sif,
} }
def __str__(self):
_str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
return _str
def save_vgf(self, model_id): def save_vgf(self, model_id):
import pickle import pickle
from os.path import join from os.path import join

111
main.py
View File

@ -1,3 +1,8 @@
import os
import wandb
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser from argparse import ArgumentParser
from time import time from time import time
@ -5,18 +10,36 @@ from dataManager.utils import get_dataset
from evaluation.evaluate import evaluate, log_eval from evaluation.evaluate import evaluate, log_eval
from gfun.generalizedFunnelling import GeneralizedFunnelling from gfun.generalizedFunnelling import GeneralizedFunnelling
"""
TODO:
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
- [!] documents should be trimmed to the same length (?)
- [!] logging
- add documentations sphinx
- [!] zero-shot setup
- FFNN posterior-probabilities' dependent
- re-init langs when loading VGFs?
- [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
- [!] experiment with weight init of Attention-aggregator
""" """
TODO:
- Transformers VGFs:
- scheduler with warmup and cosine
- freeze params method
- General:
[!] zero-shot setup
- CLS dataset is loading only "books" domain data
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
- Attention Aggregator:
- experiment with weight init of Attention-aggregator
- FFNN posterior-probabilities' dependent
- Docs:
- add documentations sphinx
"""
def get_config_name(args):
config_name = ""
if args.posteriors:
config_name += "P+"
if args.wce:
config_name += "W+"
if args.multilingual:
config_name += "M+"
if args.textual_transformer:
config_name += f"TT_{args.textual_trf_name}+"
if args.visual_transformer:
config_name += f"VT_{args.visual_trf_name}+"
return config_name.rstrip("+")
def main(args): def main(args):
@ -43,6 +66,7 @@ def main(args):
dataset_name=args.dataset, dataset_name=args.dataset,
langs=dataset.langs(), langs=dataset.langs(),
num_labels=dataset.num_labels(), num_labels=dataset.num_labels(),
classification_type=args.clf_type,
# Posterior VGF params ---------------- # Posterior VGF params ----------------
posterior=args.posteriors, posterior=args.posteriors,
# Multilingual VGF params ------------- # Multilingual VGF params -------------
@ -52,24 +76,26 @@ def main(args):
wce=args.wce, wce=args.wce,
# Transformer VGF params -------------- # Transformer VGF params --------------
textual_transformer=args.textual_transformer, textual_transformer=args.textual_transformer,
textual_transformer_name=args.transformer_name, textual_transformer_name=args.textual_trf_name,
batch_size=args.batch_size, batch_size=args.batch_size,
eval_batch_size=args.eval_batch_size,
epochs=args.epochs, epochs=args.epochs,
lr=args.lr, textual_lr=args.textual_lr,
visual_lr=args.visual_lr,
max_length=args.max_length, max_length=args.max_length,
patience=args.patience, patience=args.patience,
evaluate_step=args.evaluate_step, evaluate_step=args.evaluate_step,
device=args.device, device=args.device,
# Visual Transformer VGF params -------------- # Visual Transformer VGF params --------------
visual_transformer=args.visual_transformer, visual_transformer=args.visual_transformer,
visual_transformer_name=args.visual_transformer_name, visual_transformer_name=args.visual_trf_name,
# batch_size=args.batch_size, # batch_size=args.batch_size,
# epochs=args.epochs, # epochs=args.epochs,
# lr=args.lr, # lr=args.lr,
# patience=args.patience, # patience=args.patience,
# evaluate_step=args.evaluate_step, # evaluate_step=args.evaluate_step,
# device="cuda", # device="cuda",
# General params ---------------------- # General params ---------------------
probabilistic=args.features, probabilistic=args.features,
aggfunc=args.aggfunc, aggfunc=args.aggfunc,
optimc=args.optimc, optimc=args.optimc,
@ -78,27 +104,54 @@ def main(args):
n_jobs=args.n_jobs, n_jobs=args.n_jobs,
) )
# gfun.get_config() config = gfun.get_config()
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
gfun.fit(lX, lY) gfun.fit(lX, lY)
if args.load_trained is None and not args.nosave: if args.load_trained is None and not args.nosave:
gfun.save(save_first_tier=True, save_meta=True) gfun.save(save_first_tier=True, save_meta=True)
# print("- Computing evaluation on training set")
# preds = gfun.transform(lX)
# train_eval = evaluate(lY, preds)
# log_eval(train_eval, phase="train")
timetr = time() timetr = time()
print(f"- training completed in {timetr - tinit:.2f} seconds") print(f"- training completed in {timetr - tinit:.2f} seconds")
gfun_preds = gfun.transform(lX_te) gfun_preds = gfun.transform(lX_te)
test_eval = evaluate(lY_te, gfun_preds) test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
log_eval(test_eval, phase="test") avg_metrics_gfun, lang_metrics_gfun = log_eval(
test_eval, phase="test", clf_type=args.clf_type
)
timeval = time() timeval = time()
print(f"- testing completed in {timeval - timetr:.2f} seconds") print(f"- testing completed in {timeval - timetr:.2f} seconds")
def log_barplot_wandb(gfun_res, title_affix="per langauge"):
if title_affix == "per language":
for metric, lang_values in gfun_res.items():
data = [[lang, v] for lang, v in lang_values.items()]
table = wandb.Table(data=data, columns=["lang", f"{metric}"])
wandb.log(
{
f"gFun/language {metric}": wandb.plot.bar(
table, "lang", metric, title=f"{metric} {title_affix}"
)
}
)
else:
data = [[metric, value] for metric, value in gfun_res.items()]
table = wandb.Table(data=data, columns=["metric", "value"])
wandb.log(
{
f"gFun/average metric": wandb.plot.bar(
table, "metric", "value", title=f"metric {title_affix}"
)
}
)
wandb.log(gfun_res)
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
@ -112,6 +165,8 @@ if __name__ == "__main__":
parser.add_argument("--nrows", type=int, default=None) parser.add_argument("--nrows", type=int, default=None)
parser.add_argument("--min_count", type=int, default=10) parser.add_argument("--min_count", type=int, default=10)
parser.add_argument("--max_labels", type=int, default=50) parser.add_argument("--max_labels", type=int, default=50)
parser.add_argument("--clf_type", type=str, default="multilabel")
parser.add_argument("--save_dataset", action="store_true")
# gFUN parameters ---------------------- # gFUN parameters ----------------------
parser.add_argument("-p", "--posteriors", action="store_true") parser.add_argument("-p", "--posteriors", action="store_true")
parser.add_argument("-m", "--multilingual", action="store_true") parser.add_argument("-m", "--multilingual", action="store_true")
@ -123,15 +178,17 @@ if __name__ == "__main__":
parser.add_argument("--features", action="store_false") parser.add_argument("--features", action="store_false")
parser.add_argument("--aggfunc", type=str, default="mean") parser.add_argument("--aggfunc", type=str, default="mean")
# transformer parameters --------------- # transformer parameters ---------------
parser.add_argument("--transformer_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--lr", type=float, default=1e-5) parser.add_argument("--textual_trf_name", type=str, default="mbert")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=128)
parser.add_argument("--textual_lr", type=float, default=1e-4)
parser.add_argument("--max_length", type=int, default=128) parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--patience", type=int, default=5) parser.add_argument("--patience", type=int, default=5)
parser.add_argument("--evaluate_step", type=int, default=10) parser.add_argument("--evaluate_step", type=int, default=10)
# Visual Transformer parameters -------------- # Visual Transformer parameters --------------
parser.add_argument("--visual_transformer_name", type=str, default="vit") parser.add_argument("--visual_trf_name", type=str, default="vit")
parser.add_argument("--visual_lr", type=float, default=1e-4)
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,13 +1,13 @@
beautifulsoup4==4.11.2 beautifulsoup4==4.11.2
joblib==1.2.0 joblib==1.2.0
matplotlib==3.7.1 matplotlib==3.6.3
numpy==1.24.2 numpy==1.24.1
pandas==1.5.3 pandas==1.5.3
Pillow==9.4.0 Pillow==9.4.0
requests==2.28.2 requests==2.28.2
scikit_learn==1.2.1 scikit_learn==1.2.2
scipy==1.10.1 scipy==1.10.1
torch==1.13.1 torch==1.13.1
torchtext==0.14.1 torchtext==0.14.1
tqdm==4.65.0 tqdm==4.64.1
transformers==4.26.1 transformers==4.26.0