Compare commits
22 Commits
binary-cls
...
master
| Author | SHA1 | Date |
|---|---|---|
|
|
ab7a310b34 | |
|
|
41647f974a | |
|
|
ee2a9481de | |
|
|
ee38bcda10 | |
|
|
b34da419d0 | |
|
|
17d0003e48 | |
|
|
9d43ebb23b | |
|
|
56faaf2615 | |
|
|
f32b9227ae | |
|
|
65407f51fa | |
|
|
26aa0b327a | |
|
|
fece8d059e | |
|
|
5e41b4517a | |
|
|
a3e183d7fc | |
|
|
57918ec523 | |
|
|
7d0d6ba1f6 | |
|
|
5ef0904e0e | |
|
|
7e1ec46ebd | |
|
|
3240150542 | |
|
|
84dd1f093e | |
|
|
6b7917ca47 | |
|
|
7dead90271 |
|
|
@ -181,4 +181,5 @@ models/*
|
|||
scripts/
|
||||
logger/*
|
||||
explore_data.ipynb
|
||||
run.sh
|
||||
run.sh
|
||||
wandb
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import os
|
||||
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from dataManager.glamiDataset import get_dataframe
|
||||
from dataManager.multilingualDataset import MultilingualDataset
|
||||
|
|
@ -22,7 +24,7 @@ class gFunDataset:
|
|||
self.labels = labels
|
||||
self.nrows = nrows
|
||||
self.dataset = {}
|
||||
self.load_dataset()
|
||||
self._load_dataset()
|
||||
|
||||
def get_label_binarizer(self, labels):
|
||||
if self.dataset_name in ["rcv1-2", "jrc", "cls"]:
|
||||
|
|
@ -35,7 +37,7 @@ class gFunDataset:
|
|||
mlb.fit(labels)
|
||||
return mlb
|
||||
|
||||
def load_dataset(self):
|
||||
def _load_dataset(self):
|
||||
if "glami" in self.dataset_dir.lower():
|
||||
print(f"- Loading GLAMI dataset from {self.dataset_dir}")
|
||||
self.dataset_name = "glami"
|
||||
|
|
@ -106,44 +108,19 @@ class gFunDataset:
|
|||
return dataset, labels, data_langs
|
||||
|
||||
def _load_glami(self, dataset_dir, nrows):
|
||||
def _balanced_sample(data, n, remainder=0):
|
||||
import pandas as pd
|
||||
|
||||
langs = sorted(data.geo.unique().tolist())
|
||||
dict_n = {lang: n for lang in langs}
|
||||
dict_n[langs[0]] += remainder
|
||||
|
||||
sampled = []
|
||||
for lang in langs:
|
||||
sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
|
||||
|
||||
return pd.concat(sampled, axis=0)
|
||||
|
||||
# TODO: make this sampling deterministic / dependent on the seed
|
||||
lang_nrows = (
|
||||
nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
|
||||
) # GLAMI 1-M has 13 languages
|
||||
remainder = (
|
||||
nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
|
||||
train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
|
||||
test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
|
||||
n=int(nrows / 10)
|
||||
)
|
||||
|
||||
train_split = get_dataframe("train", dataset_dir=dataset_dir)
|
||||
train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
|
||||
|
||||
if self.data_langs is None:
|
||||
data_langs = sorted(train_split.geo.unique().tolist())
|
||||
# TODO: if data langs is NOT none then we have a problem where we filter df by langs
|
||||
if self.labels is None:
|
||||
labels = train_split.category_name.unique().tolist()
|
||||
|
||||
# TODO: atm test data should contain same languages as train data
|
||||
test_split = get_dataframe("test", dataset_dir=dataset_dir)
|
||||
# TODO: atm we're using 1:1 train-test
|
||||
test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
|
||||
|
||||
gb_train = train_split.groupby("geo")
|
||||
gb_test = test_split.groupby("geo")
|
||||
|
||||
if self.data_langs is None:
|
||||
data_langs = sorted(train_split.geo.unique().tolist())
|
||||
if self.labels is None:
|
||||
labels = train_split.category_name.unique().tolist()
|
||||
|
||||
def _format_glami(data_df):
|
||||
text = (data_df.name + " " + data_df.description).tolist()
|
||||
image = data_df.image_file.tolist()
|
||||
|
|
@ -205,6 +182,14 @@ class gFunDataset:
|
|||
else:
|
||||
return self.labels
|
||||
|
||||
def save_as_pickle(self, path):
    """Pickle this dataset object to ``<path>/<dataset_name>_<nrows>.pkl``.

    Args:
        path: Directory that will hold the pickle file. It is created when
            it does not exist yet; an empty string means the current
            working directory.

    Returns:
        The full path of the file that was written.
    """
    import pickle

    # Create the target directory up front so a fresh checkout does not
    # fail inside open(). An empty path is skipped: makedirs("") raises.
    if path:
        os.makedirs(path, exist_ok=True)

    filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
    with open(filepath, "wb") as f:
        print(f"- saving dataset in {filepath}")
        pickle.dump(self, f)
    return filepath
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
|
|
|||
|
|
@ -1,2 +1,66 @@
|
|||
class TorchMultiNewsDataset:
|
||||
pass
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class MultilingualDatasetTorch(Dataset):
    """Torch dataset that flattens per-language tokenized inputs.

    ``lX`` maps a language code to an object exposing ``input_ids`` and
    ``lY`` maps a language code to its label rows. All languages are
    stacked, in dictionary order, into one ``X`` (and ``Y``), while
    ``langs`` records the language of every stacked row.

    When ``split == "whole"`` no labels are loaded and items are
    ``(x, lang)`` pairs; otherwise items are ``(x, y, lang)`` triples.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """Build the stacked tensors and the per-row language tags."""
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # One tag per row, in the same language order used for vstack above.
        # A flat comprehension avoids the quadratic sum(list_of_lists, []).
        self.langs = [
            lang
            for lang, data in self.lX.items()
            for _ in range(len(data.input_ids))
        ]

        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
|
||||
|
||||
|
||||
class MultimodalDatasetTorch(Dataset):
    """Torch dataset that flattens per-language tensors into one stack.

    ``lX`` maps a language code to a tensor (the loop variable in the
    original code is named ``imgs``, so presumably image tensors; confirm
    against the caller) and ``lY`` maps a language code to its label rows.
    All languages are stacked, in dictionary order, into one ``X`` (and
    ``Y``), while ``langs`` records the language of every stacked row.

    When ``split == "whole"`` no labels are loaded and items are
    ``(x, lang)`` pairs; otherwise items are ``(x, y, lang)`` triples.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """Build the stacked tensors and the per-row language tags."""
        self.X = torch.vstack(list(self.lX.values()))
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # One tag per row, in the same language order used for vstack above.
        # A flat comprehension avoids the quadratic sum(list_of_lists, []).
        self.langs = [
            lang for lang, data in self.lX.items() for _ in range(len(data))
        ]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,21 @@
|
|||
from os.path import expanduser
|
||||
from os.path import expanduser, join
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
from dataManager.multiNewsDataset import MultiNewsDataset
|
||||
from dataManager.amazonDataset import AmazonDataset
|
||||
|
||||
|
||||
def load_from_pickle(path, dataset_name, nrows):
    """Load a dataset object previously pickled as ``<dataset_name>_<nrows>.pkl``.

    Args:
        path: Directory containing the pickle file.
        dataset_name: Dataset identifier used in the file name.
        nrows: Row count used in the file name.

    Returns:
        The unpickled object, after calling its ``show_dimension()`` method.

    Raises:
        FileNotFoundError: If the expected pickle file does not exist.
    """
    import pickle

    filepath = join(path, f"{dataset_name}_{nrows}.pkl")

    try:
        f = open(filepath, "rb")
    except FileNotFoundError as err:
        # Same exception type as before, but the message says what to do.
        raise FileNotFoundError(
            f"No cached dataset at {filepath}; build and save the dataset first"
        ) from err

    with f:
        # NOTE(security): pickle.load executes arbitrary code from the file.
        # Only load pickles produced locally, never untrusted files.
        loaded = pickle.load(f)
        print(f"- Loaded dataset from {filepath}")
    loaded.show_dimension()
    return loaded
|
||||
|
||||
|
||||
def get_dataset(dataset_name, args):
|
||||
assert dataset_name in [
|
||||
"multinews",
|
||||
|
|
@ -58,13 +70,19 @@ def get_dataset(dataset_name, args):
|
|||
nrows=args.nrows,
|
||||
)
|
||||
elif dataset_name == "glami":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=GLAMI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
if args.save_dataset is False:
|
||||
dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
|
||||
else:
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=GLAMI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
nrows=args.nrows,
|
||||
)
|
||||
|
||||
dataset.save_as_pickle(GLAMI_DATAPATH)
|
||||
|
||||
elif dataset_name == "cls":
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=CLS_DATAPATH,
|
||||
|
|
|
|||
|
|
@ -1,51 +1,96 @@
|
|||
from joblib import Parallel, delayed
|
||||
from collections import defaultdict
|
||||
|
||||
from evaluation.metrics import *
|
||||
from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score
|
||||
|
||||
|
||||
def evaluation_metrics(y, y_):
|
||||
if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2: # single-label
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
def evaluation_metrics(y, y_, clf_type):
|
||||
if clf_type == "singlelabel":
|
||||
return (
|
||||
accuracy_score(y, y_),
|
||||
# TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
|
||||
# TODO: we need logits top_k_accuracy_score(y, y_, k=10),
|
||||
f1_score(y, y_, average="macro", zero_division=1),
|
||||
f1_score(y, y_, average="micro"),
|
||||
)
|
||||
elif clf_type == "multilabel":
|
||||
return (
|
||||
macroF1(y, y_),
|
||||
microF1(y, y_),
|
||||
macroK(y, y_),
|
||||
microK(y, y_),
|
||||
# macroAcc(y, y_),
|
||||
microAcc(
|
||||
y, y_
|
||||
), # TODO: we're using micro-averaging for accuracy, it is == to accuracy_score on binary classification
|
||||
)
|
||||
else:
|
||||
raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")
|
||||
|
||||
|
||||
def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
|
||||
def evaluate(
|
||||
ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
|
||||
):
|
||||
if n_jobs == 1:
|
||||
return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
|
||||
return {
|
||||
lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
|
||||
for lang in ly_true.keys()
|
||||
}
|
||||
else:
|
||||
langs = list(ly_true.keys())
|
||||
evals = Parallel(n_jobs=n_jobs)(
|
||||
delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
|
||||
delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
|
||||
)
|
||||
return {lang: evals[i] for i, lang in enumerate(langs)}
|
||||
|
||||
|
||||
def log_eval(l_eval, phase="training", verbose=True):
|
||||
def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
|
||||
if verbose:
|
||||
print(f"\n[Results {phase}]")
|
||||
metrics = []
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk, microAcc = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, macrok, microk, microAcc])
|
||||
if phase != "validation":
|
||||
|
||||
if clf_type == "multilabel":
|
||||
for lang in l_eval.keys():
|
||||
macrof1, microf1, macrok, microk = l_eval[lang]
|
||||
metrics.append([macrof1, microf1, macrok, microk])
|
||||
if phase != "validation":
|
||||
print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f} acc = {microAcc:.3f}"
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
"Averages: MF1, mF1, MK, mK",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
return averages
|
||||
return averages # TODO: return a dict avg and lang specific
|
||||
|
||||
elif clf_type == "singlelabel":
|
||||
lang_metrics = defaultdict(dict)
|
||||
_metrics = [
|
||||
"accuracy",
|
||||
# "acc5", # "accuracy-at-5",
|
||||
# "acc10", # "accuracy-at-10",
|
||||
"MF1", # "macro-F1",
|
||||
"mF1", # "micro-F1",
|
||||
]
|
||||
for lang in l_eval.keys():
|
||||
# acc, top5, top10, macrof1, microf1 = l_eval[lang]
|
||||
acc, macrof1, microf1 = l_eval[lang]
|
||||
# metrics.append([acc, top5, top10, macrof1, microf1])
|
||||
metrics.append([acc, macrof1, microf1])
|
||||
|
||||
for m, v in zip(_metrics, l_eval[lang]):
|
||||
lang_metrics[m][lang] = v
|
||||
|
||||
if phase != "validation":
|
||||
print(
|
||||
# f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
|
||||
)
|
||||
averages = np.mean(np.array(metrics), axis=0)
|
||||
if verbose:
|
||||
print(
|
||||
# "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
|
||||
"Averages: Acc, MF1, mF1",
|
||||
np.round(averages, 3),
|
||||
"\n",
|
||||
)
|
||||
avg_metrics = dict(zip(_metrics, averages))
|
||||
return avg_metrics, lang_metrics
|
||||
|
|
|
|||
|
|
@ -239,7 +239,3 @@ def microK(true_labels, predicted_labels):
|
|||
|
||||
def macroAcc(true_labels, predicted_labels):
|
||||
return macro_average(true_labels, predicted_labels, accuracy)
|
||||
|
||||
|
||||
def microAcc(true_labels, predicted_labels):
|
||||
return micro_average(true_labels, predicted_labels, accuracy)
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
# sys.path.append(os.path.join(os.getcwd(), "gfun"))
|
||||
|
||||
import pickle
|
||||
|
||||
import numpy as np
|
||||
from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
|
||||
|
||||
from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
|
||||
from gfun.vgfs.learners.svms import MetaClassifier, get_learner
|
||||
from gfun.vgfs.multilingualGen import MultilingualGen
|
||||
from gfun.vgfs.textualTransformerGen import TextualTransformerGen
|
||||
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
|
||||
from gfun.vgfs.vanillaFun import VanillaFunGen
|
||||
from gfun.vgfs.visualTransformerGen import VisualTransformerGen
|
||||
from gfun.vgfs.wceGen import WceGen
|
||||
|
||||
|
||||
|
|
@ -25,11 +22,14 @@ class GeneralizedFunnelling:
|
|||
visual_transformer,
|
||||
langs,
|
||||
num_labels,
|
||||
classification_type,
|
||||
embed_dir,
|
||||
n_jobs,
|
||||
batch_size,
|
||||
eval_batch_size,
|
||||
max_length,
|
||||
lr,
|
||||
textual_lr,
|
||||
visual_lr,
|
||||
epochs,
|
||||
patience,
|
||||
evaluate_step,
|
||||
|
|
@ -47,26 +47,31 @@ class GeneralizedFunnelling:
|
|||
self.posteriors_vgf = posterior
|
||||
self.wce_vgf = wce
|
||||
self.multilingual_vgf = multilingual
|
||||
self.trasformer_vgf = textual_transformer
|
||||
self.visual_transformer_vgf = visual_transformer
|
||||
self.textual_trf_vgf = textual_transformer
|
||||
self.visual_trf_vgf = visual_transformer
|
||||
self.probabilistic = probabilistic
|
||||
self.num_labels = num_labels
|
||||
self.clf_type = classification_type
|
||||
# ------------------------
|
||||
self.langs = langs
|
||||
self.embed_dir = embed_dir
|
||||
self.cached = True
|
||||
# Textual Transformer VGF params ----------
|
||||
self.textaul_transformer_name = textual_transformer_name
|
||||
self.textual_trf_name = textual_transformer_name
|
||||
self.epochs = epochs
|
||||
self.lr_transformer = lr
|
||||
self.batch_size_transformer = batch_size
|
||||
self.textual_trf_lr = textual_lr
|
||||
self.textual_scheduler = "ReduceLROnPlateau"
|
||||
self.batch_size_trf = batch_size
|
||||
self.eval_batch_size_trf = eval_batch_size
|
||||
self.max_length = max_length
|
||||
self.early_stopping = True
|
||||
self.patience = patience
|
||||
self.evaluate_step = evaluate_step
|
||||
self.device = device
|
||||
# Visual Transformer VGF params ----------
|
||||
self.visual_transformer_name = visual_transformer_name
|
||||
self.visual_trf_name = visual_transformer_name
|
||||
self.visual_trf_lr = visual_lr
|
||||
self.visual_scheduler = "ReduceLROnPlateau"
|
||||
# Metaclassifier params ------------
|
||||
self.optimc = optimc
|
||||
# -------------------
|
||||
|
|
@ -77,7 +82,7 @@ class GeneralizedFunnelling:
|
|||
self.aggfunc = aggfunc
|
||||
self.load_trained = load_trained
|
||||
self.load_first_tier = (
|
||||
True # TODO: i guess we're always going to load at least the fitst tier
|
||||
True # TODO: i guess we're always going to load at least the first tier
|
||||
)
|
||||
self.load_meta = load_meta
|
||||
self.dataset_name = dataset_name
|
||||
|
|
@ -112,7 +117,7 @@ class GeneralizedFunnelling:
|
|||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||
out_dim=self.num_labels,
|
||||
lr=self.lr_transformer,
|
||||
lr=self.textual_trf_lr,
|
||||
patience=self.patience,
|
||||
num_heads=1,
|
||||
device=self.device,
|
||||
|
|
@ -142,13 +147,15 @@ class GeneralizedFunnelling:
|
|||
wce_vgf = WceGen(n_jobs=self.n_jobs)
|
||||
self.first_tier_learners.append(wce_vgf)
|
||||
|
||||
if self.trasformer_vgf:
|
||||
if self.textual_trf_vgf:
|
||||
transformer_vgf = TextualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name=self.textaul_transformer_name,
|
||||
lr=self.lr_transformer,
|
||||
model_name=self.textual_trf_name,
|
||||
lr=self.textual_trf_lr,
|
||||
scheduler=self.textual_scheduler,
|
||||
epochs=self.epochs,
|
||||
batch_size=self.batch_size_transformer,
|
||||
batch_size=self.batch_size_trf,
|
||||
batch_size_eval=self.eval_batch_size_trf,
|
||||
max_length=self.max_length,
|
||||
print_steps=50,
|
||||
probabilistic=self.probabilistic,
|
||||
|
|
@ -156,21 +163,24 @@ class GeneralizedFunnelling:
|
|||
verbose=True,
|
||||
patience=self.patience,
|
||||
device=self.device,
|
||||
classification_type=self.clf_type,
|
||||
)
|
||||
self.first_tier_learners.append(transformer_vgf)
|
||||
|
||||
if self.visual_transformer_vgf:
|
||||
if self.visual_trf_vgf:
|
||||
visual_trasformer_vgf = VisualTransformerGen(
|
||||
dataset_name=self.dataset_name,
|
||||
model_name="vit",
|
||||
lr=1e-5, # self.lr_visual_transformer,
|
||||
lr=self.visual_trf_lr,
|
||||
scheduler=self.visual_scheduler,
|
||||
epochs=self.epochs,
|
||||
batch_size=32, # self.batch_size_visual_transformer,
|
||||
# batch_size_eval=128,
|
||||
batch_size=self.batch_size_trf,
|
||||
batch_size_eval=self.eval_batch_size_trf,
|
||||
probabilistic=self.probabilistic,
|
||||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
device=self.device,
|
||||
classification_type=self.clf_type,
|
||||
)
|
||||
self.first_tier_learners.append(visual_trasformer_vgf)
|
||||
|
||||
|
|
@ -179,7 +189,7 @@ class GeneralizedFunnelling:
|
|||
self.attn_aggregator = AttentionAggregator(
|
||||
embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
|
||||
out_dim=self.num_labels,
|
||||
lr=self.lr_transformer,
|
||||
lr=self.textual_trf_lr,
|
||||
patience=self.patience,
|
||||
num_heads=1,
|
||||
device=self.device,
|
||||
|
|
@ -198,7 +208,8 @@ class GeneralizedFunnelling:
|
|||
self.posteriors_vgf,
|
||||
self.multilingual_vgf,
|
||||
self.wce_vgf,
|
||||
self.trasformer_vgf,
|
||||
self.textual_trf_vgf,
|
||||
self.visual_trf_vgf,
|
||||
self.aggfunc,
|
||||
)
|
||||
print(f"- model id: {self._model_id}")
|
||||
|
|
@ -251,10 +262,9 @@ class GeneralizedFunnelling:
|
|||
projections.append(l_posteriors)
|
||||
agg = self.aggregate(projections)
|
||||
l_out = self.metaclassifier.predict_proba(agg)
|
||||
# converting to binary predictions
|
||||
# if self.dataset_name in ["cls"]: # TODO: better way to do this
|
||||
# for lang, preds in l_out.items():
|
||||
# l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
|
||||
if self.clf_type == "singlelabel":
|
||||
for lang, preds in l_out.items():
|
||||
l_out[lang] = predict(preds, clf_type=self.clf_type)
|
||||
return l_out
|
||||
|
||||
def fit_transform(self, lX, lY):
|
||||
|
|
@ -303,15 +313,21 @@ class GeneralizedFunnelling:
|
|||
return aggregated
|
||||
|
||||
def get_config(self):
|
||||
print("\n")
|
||||
print("-" * 50)
|
||||
print("[GeneralizedFunnelling config]")
|
||||
print(f"- model trained on langs: {self.langs}")
|
||||
print("-- View Generating Functions configurations:\n")
|
||||
c = {}
|
||||
|
||||
for vgf in self.first_tier_learners:
|
||||
print(vgf)
|
||||
print("-" * 50)
|
||||
vgf_config = vgf.get_config()
|
||||
c.update(vgf_config)
|
||||
|
||||
gfun_config = {
|
||||
"id": self._model_id,
|
||||
"aggfunc": self.aggfunc,
|
||||
"optimc": self.optimc,
|
||||
"dataset": self.dataset_name,
|
||||
}
|
||||
|
||||
c["gFun"] = gfun_config
|
||||
return c
|
||||
|
||||
def save(self, save_first_tier=True, save_meta=True):
|
||||
print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")
|
||||
|
|
@ -334,7 +350,7 @@ class GeneralizedFunnelling:
|
|||
pickle.dump(self.metaclassifier, f)
|
||||
return
|
||||
|
||||
def save_first_tier_learners(self):
|
||||
def save_first_tier_learners(self, model_id):
|
||||
for vgf in self.first_tier_learners:
|
||||
vgf.save_vgf(model_id=self._model_id)
|
||||
return self
|
||||
|
|
@ -372,7 +388,7 @@ class GeneralizedFunnelling:
|
|||
"rb",
|
||||
) as vgf:
|
||||
first_tier_learners.append(pickle.load(vgf))
|
||||
if self.trasformer_vgf:
|
||||
if self.textual_trf_vgf:
|
||||
with open(
|
||||
os.path.join(
|
||||
"models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"
|
||||
|
|
@ -427,7 +443,15 @@ def get_params(optimc=False):
|
|||
return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]
|
||||
|
||||
|
||||
def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
|
||||
def get_unique_id(
|
||||
dataset_name,
|
||||
posterior,
|
||||
multilingual,
|
||||
wce,
|
||||
textual_transformer,
|
||||
visual_transformer,
|
||||
aggfunc,
|
||||
):
|
||||
from datetime import datetime
|
||||
|
||||
now = datetime.now().strftime("%y%m%d")
|
||||
|
|
@ -435,6 +459,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfu
|
|||
model_id += "p" if posterior else ""
|
||||
model_id += "m" if multilingual else ""
|
||||
model_id += "w" if wce else ""
|
||||
model_id += "t" if transformer else ""
|
||||
model_id += "t" if textual_transformer else ""
|
||||
model_id += "v" if visual_transformer else ""
|
||||
model_id += f"_{aggfunc}"
|
||||
return f"{model_id}_{now}"
|
||||
|
|
|
|||
|
|
@ -9,9 +9,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer
|
|||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import normalize
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
import wandb
|
||||
from evaluation.evaluate import evaluate, log_eval
|
||||
|
||||
PRINT_ON_EPOCH = 1
|
||||
|
|
@ -21,6 +23,28 @@ def _normalize(lX, l2=True):
|
|||
return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX
|
||||
|
||||
|
||||
def verbosity_eval(epoch, print_eval):
    """Tell whether evaluation results should be printed for this epoch.

    Args:
        epoch: Zero-based epoch index.
        print_eval: Print every ``print_eval`` epochs. A falsy value
            (``0`` or ``None``) disables printing instead of raising
            ``ZeroDivisionError``.

    Returns:
        ``True`` when ``epoch + 1`` is a multiple of ``print_eval`` and the
        epoch is not the first one (index 0); ``False`` otherwise.
    """
    if not print_eval:
        return False
    return (epoch + 1) % print_eval == 0 and epoch != 0
|
||||
|
||||
|
||||
def format_langkey_wandb(lang_dict, vgf_name):
    """Flatten per-language metrics into a single-level logging dict.

    Args:
        lang_dict: Mapping ``{metric: {lang: value}}``.
        vgf_name: Prefix placed in front of every generated key.

    Returns:
        A dict whose keys are
        ``"<vgf_name>/language metric/<metric>/<lang>"`` and whose values
        are the corresponding metric values.
    """
    return {
        f"{vgf_name}/language metric/{metric}/{lang}": value
        for metric, l_dict in lang_dict.items()
        for lang, value in l_dict.items()
    }
|
||||
|
||||
|
||||
def format_average_wandb(avg_dict, vgf_name):
    """Prefix averaged metrics so they can be logged under one namespace.

    Args:
        avg_dict: Mapping ``{metric: value}``.
        vgf_name: Prefix placed in front of every generated key.

    Returns:
        A dict whose keys are ``"<vgf_name>/average metric/<metric>"`` and
        whose values are unchanged.
    """
    return {
        f"{vgf_name}/average metric/{metric}": value
        for metric, value in avg_dict.items()
    }
|
||||
|
||||
|
||||
def XdotM(X, M, sif):
|
||||
E = X.dot(M)
|
||||
if sif:
|
||||
|
|
@ -57,18 +81,23 @@ def compute_pc(X, npc=1):
|
|||
return svd.components_
|
||||
|
||||
|
||||
def predict(logits, classification_type="multilabel"):
|
||||
def predict(logits, clf_type="multilabel"):
|
||||
"""
|
||||
Converts soft predictions to hard predictions [0,1]
|
||||
"""
|
||||
if classification_type == "multilabel":
|
||||
if clf_type == "multilabel":
|
||||
prediction = torch.sigmoid(logits) > 0.5
|
||||
elif classification_type == "singlelabel":
|
||||
prediction = torch.argmax(logits, dim=1).view(-1, 1)
|
||||
return prediction.detach().cpu().numpy()
|
||||
elif clf_type == "singlelabel":
|
||||
if type(logits) != torch.Tensor:
|
||||
logits = torch.tensor(logits)
|
||||
prediction = torch.softmax(logits, dim=1)
|
||||
prediction = prediction.detach().cpu().numpy()
|
||||
_argmaxs = prediction.argmax(axis=1)
|
||||
prediction = np.eye(prediction.shape[1])[_argmaxs]
|
||||
return prediction
|
||||
else:
|
||||
print("unknown classification type")
|
||||
|
||||
return prediction.detach().cpu().numpy()
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class TfidfVectorizerMultilingual:
|
||||
|
|
@ -114,63 +143,138 @@ class Trainer:
|
|||
patience,
|
||||
experiment_name,
|
||||
checkpoint_path,
|
||||
classification_type,
|
||||
vgf_name,
|
||||
n_jobs,
|
||||
scheduler_name=None,
|
||||
):
|
||||
self.device = device
|
||||
self.model = model.to(device)
|
||||
self.optimizer = self.init_optimizer(optimizer_name, lr)
|
||||
self.optimizer, self.scheduler = self.init_optimizer(
|
||||
optimizer_name, lr, scheduler_name
|
||||
)
|
||||
self.evaluate_steps = evaluate_step
|
||||
self.loss_fn = loss_fn.to(device)
|
||||
self.print_steps = print_steps
|
||||
self.experiment_name = experiment_name
|
||||
self.patience = patience
|
||||
self.print_eval = evaluate_step
|
||||
self.print_eval = 10
|
||||
self.earlystopping = EarlyStopping(
|
||||
patience=patience,
|
||||
checkpoint_path=checkpoint_path,
|
||||
verbose=False,
|
||||
experiment_name=experiment_name,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.vgf_name = vgf_name
|
||||
self.scheduler_name = scheduler_name
|
||||
self.n_jobs = n_jobs
|
||||
self.monitored_metric = (
|
||||
"macro-F1" if self.clf_type == "multilabel" else "accuracy"
|
||||
) # TODO: make this configurable
|
||||
|
||||
def init_optimizer(self, optimizer_name, lr):
|
||||
def init_optimizer(self, optimizer_name, lr, scheduler_name):
|
||||
if optimizer_name.lower() == "adamw":
|
||||
return AdamW(self.model.parameters(), lr=lr)
|
||||
optim = AdamW(self.model.parameters(), lr=lr)
|
||||
else:
|
||||
raise ValueError(f"Optimizer {optimizer_name} not supported")
|
||||
if scheduler_name is None:
|
||||
scheduler = None
|
||||
elif scheduler_name == "ReduceLROnPlateau":
|
||||
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
|
||||
else:
|
||||
raise ValueError(f"Scheduler {scheduler_name} not supported")
|
||||
return optim, scheduler
|
||||
|
||||
def get_config(self, train_dataloader, eval_dataloader, epochs):
    """Collect the training hyper-parameters into a flat dictionary.

    Args:
        train_dataloader: Loader for the training data; its ``dataset``
            must support ``len()`` and expose ``X`` with a ``shape``.
        eval_dataloader: Loader for the evaluation data.
        epochs: Number of epochs the run is configured for.

    Returns:
        A dict of human-readable setting names to their values.
    """
    # Models that wrap an encoder under ``mt5encoder`` report the wrapped
    # encoder's name; every other model reports its own ``name_or_path``.
    if hasattr(self.model, "mt5encoder"):
        model_name = self.model.mt5encoder.name_or_path
    else:
        model_name = self.model.name_or_path

    return {
        "model name": model_name,
        "epochs": epochs,
        "learning rate": self.optimizer.defaults["lr"],
        "scheduler": self.scheduler_name,  # TODO: add scheduler params
        "train size": len(train_dataloader.dataset),
        "eval size": len(eval_dataloader.dataset),
        "train batch size": train_dataloader.batch_size,
        "eval batch size": eval_dataloader.batch_size,
        # Last dimension of the training inputs is reported as "max len".
        "max len": train_dataloader.dataset.X.shape[-1],
        "patience": self.earlystopping.patience,
        "evaluate every": self.evaluate_steps,
        "print eval every": self.print_eval,
        "print train steps": self.print_steps,
        "classification type": self.clf_type,
    }
|
||||
|
||||
def train(self, train_dataloader, eval_dataloader, epochs=10):
|
||||
print(
|
||||
f"""- Training params for {self.experiment_name}:
|
||||
- epochs: {epochs}
|
||||
- learning rate: {self.optimizer.defaults['lr']}
|
||||
- train batch size: {train_dataloader.batch_size}
|
||||
- eval batch size: {eval_dataloader.batch_size}
|
||||
- max len: {train_dataloader.dataset.X.shape[-1]}
|
||||
- patience: {self.earlystopping.patience}
|
||||
- evaluate every: {self.evaluate_steps}
|
||||
- print eval every: {self.print_eval}
|
||||
- print train steps: {self.print_steps}\n"""
|
||||
)
|
||||
_config = self.get_config(train_dataloader, eval_dataloader, epochs)
|
||||
|
||||
print(f"- Training params for {self.experiment_name}:")
|
||||
for k, v in _config.items():
|
||||
print(f"\t{k}: {v}")
|
||||
|
||||
for epoch in range(epochs):
|
||||
self.train_epoch(train_dataloader, epoch)
|
||||
if (epoch + 1) % self.evaluate_steps == 0:
|
||||
print_eval = (epoch + 1) % self.print_eval == 0
|
||||
metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
|
||||
stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
|
||||
train_loss = self.train_epoch(train_dataloader, epoch)
|
||||
|
||||
if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
|
||||
print_eval = verbosity_eval(epoch, self.print_eval)
|
||||
with torch.no_grad():
|
||||
eval_loss, avg_metrics, lang_metrics = self.evaluate(
|
||||
eval_dataloader,
|
||||
print_eval=print_eval,
|
||||
n_jobs=self.n_jobs,
|
||||
)
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
f"{self.vgf_name}/loss/val": eval_loss,
|
||||
**format_langkey_wandb(lang_metrics, self.vgf_name),
|
||||
**format_average_wandb(avg_metrics, self.vgf_name),
|
||||
},
|
||||
commit=False,
|
||||
)
|
||||
|
||||
stop = self.earlystopping(
|
||||
avg_metrics[self.monitored_metric], self.model, epoch + 1
|
||||
)
|
||||
if stop:
|
||||
print(
|
||||
f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
|
||||
)
|
||||
self.model = self.earlystopping.load_model(self.model).to(
|
||||
self.device
|
||||
)
|
||||
restored_model = self.earlystopping.load_model(self.model)
|
||||
|
||||
# swapping model on gpu
|
||||
del self.model
|
||||
self.model = restored_model.to(self.device)
|
||||
break
|
||||
|
||||
if self.scheduler is not None:
|
||||
self.scheduler.step(avg_metrics[self.monitored_metric])
|
||||
|
||||
wandb.log(
|
||||
{
|
||||
f"{self.vgf_name}/loss/train": train_loss,
|
||||
f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][
|
||||
"lr"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
print(f"- last swipe on eval set")
|
||||
self.train_epoch(eval_dataloader, epoch=0)
|
||||
self.train_epoch(
|
||||
DataLoader(
|
||||
eval_dataloader.dataset,
|
||||
batch_size=train_dataloader.batch_size,
|
||||
shuffle=True,
|
||||
),
|
||||
epoch=-1,
|
||||
)
|
||||
self.earlystopping.save_model(self.model)
|
||||
return self.model
|
||||
|
||||
def train_epoch(self, dataloader, epoch):
|
||||
self.model.train()
|
||||
batch_losses = []
|
||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
self.optimizer.zero_grad()
|
||||
y_hat = self.model(x.to(self.device))
|
||||
|
|
@ -180,37 +284,47 @@ class Trainer:
|
|||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
batch_losses.append(loss.item())
|
||||
if (epoch + 1) % PRINT_ON_EPOCH == 0:
|
||||
if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
|
||||
print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
|
||||
return self
|
||||
print(
|
||||
f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
|
||||
)
|
||||
return np.mean(batch_losses)
|
||||
|
||||
def evaluate(self, dataloader, print_eval=True):
|
||||
def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
|
||||
self.model.eval()
|
||||
eval_losses = []
|
||||
|
||||
lY = defaultdict(list)
|
||||
lY_hat = defaultdict(list)
|
||||
lY_true = defaultdict(list)
|
||||
lY_pred = defaultdict(list)
|
||||
|
||||
for b_idx, (x, y, lang) in enumerate(dataloader):
|
||||
y_hat = self.model(x.to(self.device))
|
||||
if isinstance(y_hat, ModelOutput):
|
||||
loss = self.loss_fn(y_hat.logits, y.to(self.device))
|
||||
predictions = predict(y_hat.logits, classification_type="multilabel")
|
||||
y_pred = self.model(x.to(self.device))
|
||||
if isinstance(y_pred, ModelOutput):
|
||||
loss = self.loss_fn(y_pred.logits, y.to(self.device))
|
||||
predictions = predict(y_pred.logits, clf_type=self.clf_type)
|
||||
else:
|
||||
loss = self.loss_fn(y_hat, y.to(self.device))
|
||||
predictions = predict(y_hat, classification_type="multilabel")
|
||||
loss = self.loss_fn(y_pred, y.to(self.device))
|
||||
predictions = predict(y_pred, clf_type=self.clf_type)
|
||||
|
||||
eval_losses.append(loss.item())
|
||||
|
||||
for l, _true, _pred in zip(lang, y, predictions):
|
||||
lY[l].append(_true.detach().cpu().numpy())
|
||||
lY_hat[l].append(_pred)
|
||||
lY_true[l].append(_true.detach().cpu().numpy())
|
||||
lY_pred[l].append(_pred)
|
||||
|
||||
for lang in lY:
|
||||
lY[lang] = np.vstack(lY[lang])
|
||||
lY_hat[lang] = np.vstack(lY_hat[lang])
|
||||
for lang in lY_true:
|
||||
lY_true[lang] = np.vstack(lY_true[lang])
|
||||
lY_pred[lang] = np.vstack(lY_pred[lang])
|
||||
|
||||
l_eval = evaluate(lY, lY_hat)
|
||||
average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
|
||||
return average_metrics[0] # macro-F1
|
||||
l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
|
||||
|
||||
avg_metrics, lang_metrics = log_eval(
|
||||
l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
|
||||
)
|
||||
|
||||
return np.mean(eval_losses), avg_metrics, lang_metrics
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
|
|
@ -232,7 +346,8 @@ class EarlyStopping:
|
|||
self.experiment_name = experiment_name
|
||||
|
||||
def __call__(self, validation, model, epoch):
|
||||
if validation > self.best_score:
|
||||
if validation >= self.best_score:
|
||||
wandb.log({"patience": self.patience - self.counter})
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"
|
||||
|
|
@ -244,11 +359,12 @@ class EarlyStopping:
|
|||
self.save_model(model)
|
||||
elif validation < (self.best_score + self.min_delta):
|
||||
self.counter += 1
|
||||
wandb.log({"patience": self.patience - self.counter})
|
||||
if self.verbose:
|
||||
print(
|
||||
f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
|
||||
)
|
||||
if self.counter >= self.patience:
|
||||
if self.counter >= self.patience and self.patience != -1:
|
||||
print(f"- earlystopping: Early stopping at epoch {epoch}")
|
||||
return True
|
||||
|
||||
|
|
|
|||
|
|
@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
|
|||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
_str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
|
||||
return _str
|
||||
|
||||
|
||||
def load_MUSEs(langs, l_vocab, dir_path, cached=False):
|
||||
dir_path = expanduser(dir_path)
|
||||
|
|
|
|||
|
|
@ -6,20 +6,50 @@ from collections import defaultdict
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
# from sklearn.model_selection import train_test_split
|
||||
# from torch.optim import AdamW
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import MT5EncoderModel
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
from gfun.vgfs.commons import Trainer
|
||||
from gfun.vgfs.transformerGen import TransformerGen
|
||||
from gfun.vgfs.viewGen import ViewGen
|
||||
from dataManager.torchDataset import MultilingualDatasetTorch
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
# TODO: add support to loggers
|
||||
class MT5ForSequenceClassification(nn.Module):
    """Sequence classifier built from an mT5 encoder and a linear head.

    The encoder's last hidden state is averaged along dimension 1, passed
    through dropout and projected to ``num_labels`` logits.
    """

    def __init__(self, model_name, num_labels, output_hidden_states):
        """Load the pretrained encoder and create the classification head.

        Args:
            model_name: identifier forwarded to ``MT5EncoderModel.from_pretrained``.
            num_labels: number of output logits.
            output_hidden_states: when truthy, ``forward`` also returns the
                pooled representation under the ``pooled`` key.
        """
        super().__init__()

        self.output_hidden_states = output_hidden_states
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.dropout = nn.Dropout(0.1)
        # Size the head from the loaded encoder rather than hard-coding 512,
        # so checkpoints with a different hidden width work as well.
        self.linear = nn.Linear(self.mt5encoder.config.d_model, num_labels)

    def forward(self, input_ids):
        """Return a ``ModelOutput`` with ``logits`` (and optionally ``pooled``)."""
        embed = self.mt5encoder(input_ids=input_ids)
        # NOTE(review): the mean runs over every position along dim 1; if the
        # inputs are padded, padding positions are averaged in too -- confirm
        # whether an attention mask should be applied here.
        pooled = torch.mean(embed.last_hidden_state, dim=1)
        outputs = self.dropout(pooled)
        logits = self.linear(outputs)
        if self.output_hidden_states:
            return ModelOutput(
                logits=logits,
                pooled=pooled,
            )
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
        """Write the module's state dict to ``checkpoint_dir + ".pt"``."""
        torch.save(self.state_dict(), checkpoint_dir + ".pt")

    def from_pretrained(self, checkpoint_dir):
        """Restore weights from ``checkpoint_dir + ".pt"`` into this instance.

        Unlike the Hugging Face classmethod of the same name, this is an
        instance method: it mutates ``self`` and returns whatever
        ``load_state_dict`` returns.
        """
        checkpoint_dir += ".pt"
        # Load onto CPU first so a checkpoint saved from a GPU run can be
        # restored on a CPU-only host; load_state_dict then copies the values
        # into the existing parameters.
        state = torch.load(checkpoint_dir, map_location="cpu")
        return self.load_state_dict(state)
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
|
|
@ -39,23 +69,27 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=10,
|
||||
verbose=False,
|
||||
patience=5,
|
||||
classification_type="multilabel",
|
||||
scheduler="ReduceLROnPlateau",
|
||||
):
|
||||
super().__init__(
|
||||
self._validate_model_name(model_name),
|
||||
dataset_name,
|
||||
epochs,
|
||||
lr,
|
||||
batch_size,
|
||||
batch_size_eval,
|
||||
max_length,
|
||||
print_steps,
|
||||
device,
|
||||
probabilistic,
|
||||
n_jobs,
|
||||
evaluate_step,
|
||||
verbose,
|
||||
patience,
|
||||
epochs=epochs,
|
||||
lr=lr,
|
||||
scheduler=scheduler,
|
||||
batch_size=batch_size,
|
||||
batch_size_eval=batch_size_eval,
|
||||
device=device,
|
||||
evaluate_step=evaluate_step,
|
||||
patience=patience,
|
||||
probabilistic=probabilistic,
|
||||
max_length=max_length,
|
||||
print_steps=print_steps,
|
||||
n_jobs=n_jobs,
|
||||
verbose=verbose,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.fitted = False
|
||||
print(
|
||||
f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
|
|
@ -66,15 +100,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
return "bert-base-uncased"
|
||||
elif "mbert" == model_name:
|
||||
return "bert-base-multilingual-uncased"
|
||||
elif "xlm" == model_name:
|
||||
elif "xlm-roberta" == model_name:
|
||||
return "xlm-roberta-base"
|
||||
elif "mt5" == model_name:
|
||||
return "google/mt5-small"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_pretrained_model(self, model_name, num_labels):
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
if model_name == "google/mt5-small":
|
||||
return MT5ForSequenceClassification(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
else:
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
|
||||
def load_tokenizer(self, model_name):
|
||||
return AutoTokenizer.from_pretrained(model_name)
|
||||
|
|
@ -127,9 +168,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
shuffle=False,
|
||||
)
|
||||
|
||||
experiment_name = (
|
||||
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
)
|
||||
experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
|
||||
trainer = Trainer(
|
||||
model=self.model,
|
||||
optimizer_name="adamW",
|
||||
|
|
@ -140,7 +180,16 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
evaluate_step=self.evaluate_step,
|
||||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
checkpoint_path=os.path.join(
|
||||
"models",
|
||||
"vgfs",
|
||||
"transformer",
|
||||
self._format_model_name(self.model_name),
|
||||
),
|
||||
vgf_name="textual_trf",
|
||||
classification_type=self.clf_type,
|
||||
n_jobs=self.n_jobs,
|
||||
scheduler_name=self.scheduler,
|
||||
)
|
||||
trainer.train(
|
||||
train_dataloader=tra_dataloader,
|
||||
|
|
@ -175,8 +224,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
with torch.no_grad():
|
||||
for input_ids, lang in dataloader:
|
||||
input_ids = input_ids.to(self.device)
|
||||
out = self.model(input_ids).hidden_states[-1]
|
||||
batch_embeddings = out[:, 0, :].cpu().numpy()
|
||||
# TODO: check this
|
||||
if isinstance(self.model, MT5ForSequenceClassification):
|
||||
batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
|
||||
else:
|
||||
out = self.model(input_ids).hidden_states[-1]
|
||||
batch_embeddings = out[:, 0, :].cpu().numpy()
|
||||
_embeds.append((batch_embeddings, lang))
|
||||
|
||||
for embed, lang in _embeds:
|
||||
|
|
@ -206,39 +259,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
|
||||
return str
|
||||
def freeze_model(self):
    """Turn off gradient tracking for every parameter of ``self.model``."""
    # TODO: up to n-layers? or all? avoid freezing the head, obviously...
    for weight in self.model.parameters():
        weight.requires_grad_(False)
|
||||
|
||||
def _format_model_name(self, model_name):
|
||||
if "mt5" in model_name:
|
||||
return "google-mt5"
|
||||
elif "bert" in model_name:
|
||||
if "multilingual" in model_name:
|
||||
return "mbert"
|
||||
elif "xlm-roberta" in model_name:
|
||||
return "xlm-roberta"
|
||||
else:
|
||||
return model_name
|
||||
|
||||
class MultilingualDatasetTorch(Dataset):
|
||||
def __init__(self, lX, lY, split="train"):
|
||||
self.lX = lX
|
||||
self.lY = lY
|
||||
self.split = split
|
||||
self.langs = []
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
self.X = torch.vstack([data.input_ids for data in self.lX.values()])
|
||||
if self.split != "whole":
|
||||
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
|
||||
self.langs = sum(
|
||||
[
|
||||
v
|
||||
for v in {
|
||||
lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
|
||||
}.values()
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.split == "whole":
|
||||
return self.X[index], self.langs[index]
|
||||
return self.X[index], self.Y[index], self.langs[index]
|
||||
def get_config(self):
|
||||
c = super().get_config()
|
||||
return {"textual_trf": c}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class TransformerGen:
|
|||
evaluate_step=10,
|
||||
verbose=False,
|
||||
patience=5,
|
||||
scheduler=None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.dataset_name = dataset_name
|
||||
|
|
@ -46,6 +47,7 @@ class TransformerGen:
|
|||
self.verbose = verbose
|
||||
self.patience = patience
|
||||
self.datasets = {}
|
||||
self.scheduler = scheduler
|
||||
self.feature2posterior_projector = (
|
||||
self.make_probabilistic() if probabilistic else None
|
||||
)
|
||||
|
|
@ -94,3 +96,22 @@ class TransformerGen:
|
|||
val_lY[lang] = val_Y
|
||||
|
||||
return tr_lX, tr_lY, val_lX, val_lY
|
||||
|
||||
def get_config(self):
    """Return the generator's settings as a dict keyed by attribute name."""
    # Keys are listed in the order they should appear in the resulting dict.
    exported = (
        "model_name",
        "dataset_name",
        "epochs",
        "lr",
        "scheduler",
        "batch_size",
        "batch_size_eval",
        "max_length",
        "print_steps",
        "device",
        "probabilistic",
        "n_jobs",
        "evaluate_step",
        "verbose",
        "patience",
    )
    return {field: getattr(self, field) for field in exported}
|
||||
|
|
|
|||
|
|
@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
|
|||
with open(_path, "wb") as f:
|
||||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
_str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
|
||||
# - parameters: {self.first_tier_parameters}
|
||||
return _str
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ import numpy as np
|
|||
import torch
|
||||
import transformers
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||
|
||||
from gfun.vgfs.commons import Trainer
|
||||
from gfun.vgfs.transformerGen import TransformerGen
|
||||
from gfun.vgfs.viewGen import ViewGen
|
||||
from dataManager.torchDataset import MultimodalDatasetTorch
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
|
@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
model_name,
|
||||
dataset_name,
|
||||
lr=1e-5,
|
||||
scheduler="ReduceLROnPlateau",
|
||||
epochs=10,
|
||||
batch_size=32,
|
||||
batch_size_eval=128,
|
||||
|
|
@ -27,12 +28,14 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
device="cpu",
|
||||
probabilistic=False,
|
||||
patience=5,
|
||||
classification_type="multilabel",
|
||||
):
|
||||
super().__init__(
|
||||
model_name,
|
||||
dataset_name,
|
||||
lr=lr,
|
||||
epochs=epochs,
|
||||
lr=lr,
|
||||
scheduler=scheduler,
|
||||
batch_size=batch_size,
|
||||
batch_size_eval=batch_size_eval,
|
||||
device=device,
|
||||
|
|
@ -40,6 +43,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=patience,
|
||||
probabilistic=probabilistic,
|
||||
)
|
||||
self.clf_type = classification_type
|
||||
self.fitted = False
|
||||
print(
|
||||
f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
|
||||
|
|
@ -97,7 +101,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
shuffle=False,
|
||||
)
|
||||
|
||||
experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
|
||||
experiment_name = (
|
||||
f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=self.model,
|
||||
optimizer_name="adamW",
|
||||
|
|
@ -109,6 +116,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
patience=self.patience,
|
||||
experiment_name=experiment_name,
|
||||
checkpoint_path="models/vgfs/transformer",
|
||||
vgf_name="visual_trf",
|
||||
classification_type=self.clf_type,
|
||||
n_jobs=self.n_jobs,
|
||||
)
|
||||
|
||||
trainer.train(
|
||||
|
|
@ -175,66 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
|
|||
pickle.dump(self, f)
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
|
||||
return str
|
||||
|
||||
|
||||
class MultimodalDatasetTorch(Dataset):
|
||||
def __init__(self, lX, lY, split="train"):
|
||||
self.lX = lX
|
||||
self.lY = lY
|
||||
self.split = split
|
||||
self.langs = []
|
||||
self.init()
|
||||
|
||||
def init(self):
|
||||
self.X = torch.vstack([imgs for imgs in self.lX.values()])
|
||||
if self.split != "whole":
|
||||
self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
|
||||
self.langs = sum(
|
||||
[
|
||||
v
|
||||
for v in {
|
||||
lang: [lang] * len(data) for lang, data in self.lX.items()
|
||||
}.values()
|
||||
],
|
||||
[],
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.X)
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.split == "whole":
|
||||
return self.X[index], self.langs[index]
|
||||
return self.X[index], self.Y[index], self.langs[index]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from os.path import expanduser
|
||||
|
||||
from dataManager.gFunDataset import gFunDataset
|
||||
|
||||
GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
|
||||
dataset = gFunDataset(
|
||||
dataset_dir=GLAMI_DATAPATH,
|
||||
is_textual=True,
|
||||
is_visual=True,
|
||||
is_multilabel=False,
|
||||
nrows=50,
|
||||
)
|
||||
|
||||
vg = VisualTransformerGen(
|
||||
dataset_name=dataset.dataset_name,
|
||||
model_name="vit",
|
||||
device="cuda",
|
||||
epochs=5,
|
||||
evaluate_step=10,
|
||||
patience=10,
|
||||
probabilistic=True,
|
||||
)
|
||||
lX, lY = dataset.training()
|
||||
vg.fit(lX, lY)
|
||||
out = vg.transform(lX)
|
||||
exit(0)
|
||||
def get_config(self):
|
||||
return {"visual_trf": super().get_config()}
|
||||
|
|
|
|||
|
|
@ -40,10 +40,6 @@ class WceGen(ViewGen):
|
|||
"sif": self.sif,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
_str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
|
||||
return _str
|
||||
|
||||
def save_vgf(self, model_id):
|
||||
import pickle
|
||||
from os.path import join
|
||||
|
|
|
|||
111
main.py
111
main.py
|
|
@ -1,3 +1,8 @@
|
|||
import os
|
||||
import wandb
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
from argparse import ArgumentParser
|
||||
from time import time
|
||||
|
||||
|
|
@ -5,18 +10,36 @@ from dataManager.utils import get_dataset
|
|||
from evaluation.evaluate import evaluate, log_eval
|
||||
from gfun.generalizedFunnelling import GeneralizedFunnelling
|
||||
|
||||
"""
|
||||
TODO:
|
||||
- [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
|
||||
- [!] documents should be trimmed to the same length (?)
|
||||
- [!] logging
|
||||
- add documentations sphinx
|
||||
- [!] zero-shot setup
|
||||
- FFNN posterior-probabilities' dependent
|
||||
- re-init langs when loading VGFs?
|
||||
- [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
|
||||
- [!] experiment with weight init of Attention-aggregator
|
||||
"""
|
||||
TODO:
|
||||
- Transformers VGFs:
|
||||
- scheduler with warmup and cosine
|
||||
- freeze params method
|
||||
- General:
|
||||
[!] zero-shot setup
|
||||
- CLS dataset is loading only "books" domain data
|
||||
- documents should be trimmed to the same length (for SVMs we are using way too long tokens)
|
||||
- Attention Aggregator:
|
||||
- experiment with weight init of Attention-aggregator
|
||||
- FFNN posterior-probabilities' dependent
|
||||
- Docs:
|
||||
- add documentations sphinx
|
||||
"""
|
||||
|
||||
|
||||
def get_config_name(args):
    """Build a run label from the view generators enabled in *args*.

    Each enabled generator contributes one tag (``P``, ``W``, ``M``,
    ``TT_<textual_trf_name>``, ``VT_<visual_trf_name>``); tags are joined
    with ``+`` and any trailing ``+`` characters are stripped.
    """
    tags = []
    if args.posteriors:
        tags.append("P")
    if args.wce:
        tags.append("W")
    if args.multilingual:
        tags.append("M")
    if args.textual_transformer:
        tags.append(f"TT_{args.textual_trf_name}")
    if args.visual_transformer:
        tags.append(f"VT_{args.visual_trf_name}")
    return "+".join(tags).rstrip("+")
|
||||
|
||||
|
||||
def main(args):
|
||||
|
|
@ -43,6 +66,7 @@ def main(args):
|
|||
dataset_name=args.dataset,
|
||||
langs=dataset.langs(),
|
||||
num_labels=dataset.num_labels(),
|
||||
classification_type=args.clf_type,
|
||||
# Posterior VGF params ----------------
|
||||
posterior=args.posteriors,
|
||||
# Multilingual VGF params -------------
|
||||
|
|
@ -52,24 +76,26 @@ def main(args):
|
|||
wce=args.wce,
|
||||
# Transformer VGF params --------------
|
||||
textual_transformer=args.textual_transformer,
|
||||
textual_transformer_name=args.transformer_name,
|
||||
textual_transformer_name=args.textual_trf_name,
|
||||
batch_size=args.batch_size,
|
||||
eval_batch_size=args.eval_batch_size,
|
||||
epochs=args.epochs,
|
||||
lr=args.lr,
|
||||
textual_lr=args.textual_lr,
|
||||
visual_lr=args.visual_lr,
|
||||
max_length=args.max_length,
|
||||
patience=args.patience,
|
||||
evaluate_step=args.evaluate_step,
|
||||
device=args.device,
|
||||
# Visual Transformer VGF params --------------
|
||||
visual_transformer=args.visual_transformer,
|
||||
visual_transformer_name=args.visual_transformer_name,
|
||||
visual_transformer_name=args.visual_trf_name,
|
||||
# batch_size=args.batch_size,
|
||||
# epochs=args.epochs,
|
||||
# lr=args.lr,
|
||||
# patience=args.patience,
|
||||
# evaluate_step=args.evaluate_step,
|
||||
# device="cuda",
|
||||
# General params ----------------------
|
||||
# General params ---------------------
|
||||
probabilistic=args.features,
|
||||
aggfunc=args.aggfunc,
|
||||
optimc=args.optimc,
|
||||
|
|
@ -78,27 +104,54 @@ def main(args):
|
|||
n_jobs=args.n_jobs,
|
||||
)
|
||||
|
||||
# gfun.get_config()
|
||||
config = gfun.get_config()
|
||||
|
||||
wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
|
||||
|
||||
gfun.fit(lX, lY)
|
||||
|
||||
if args.load_trained is None and not args.nosave:
|
||||
gfun.save(save_first_tier=True, save_meta=True)
|
||||
|
||||
# print("- Computing evaluation on training set")
|
||||
# preds = gfun.transform(lX)
|
||||
# train_eval = evaluate(lY, preds)
|
||||
# log_eval(train_eval, phase="train")
|
||||
|
||||
timetr = time()
|
||||
print(f"- training completed in {timetr - tinit:.2f} seconds")
|
||||
|
||||
gfun_preds = gfun.transform(lX_te)
|
||||
test_eval = evaluate(lY_te, gfun_preds)
|
||||
log_eval(test_eval, phase="test")
|
||||
test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
|
||||
avg_metrics_gfun, lang_metrics_gfun = log_eval(
|
||||
test_eval, phase="test", clf_type=args.clf_type
|
||||
)
|
||||
|
||||
timeval = time()
|
||||
print(f"- testing completed in {timeval - timetr:.2f} seconds")
|
||||
|
||||
def log_barplot_wandb(gfun_res, title_affix="per langauge"):
|
||||
if title_affix == "per language":
|
||||
for metric, lang_values in gfun_res.items():
|
||||
data = [[lang, v] for lang, v in lang_values.items()]
|
||||
table = wandb.Table(data=data, columns=["lang", f"{metric}"])
|
||||
wandb.log(
|
||||
{
|
||||
f"gFun/language {metric}": wandb.plot.bar(
|
||||
table, "lang", metric, title=f"{metric} {title_affix}"
|
||||
)
|
||||
}
|
||||
)
|
||||
else:
|
||||
data = [[metric, value] for metric, value in gfun_res.items()]
|
||||
table = wandb.Table(data=data, columns=["metric", "value"])
|
||||
wandb.log(
|
||||
{
|
||||
f"gFun/average metric": wandb.plot.bar(
|
||||
table, "metric", "value", title=f"metric {title_affix}"
|
||||
)
|
||||
}
|
||||
)
|
||||
wandb.log(gfun_res)
|
||||
|
||||
log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
|
||||
log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
|
|
@ -112,6 +165,8 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--nrows", type=int, default=None)
|
||||
parser.add_argument("--min_count", type=int, default=10)
|
||||
parser.add_argument("--max_labels", type=int, default=50)
|
||||
parser.add_argument("--clf_type", type=str, default="multilabel")
|
||||
parser.add_argument("--save_dataset", action="store_true")
|
||||
# gFUN parameters ----------------------
|
||||
parser.add_argument("-p", "--posteriors", action="store_true")
|
||||
parser.add_argument("-m", "--multilingual", action="store_true")
|
||||
|
|
@ -123,15 +178,17 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--features", action="store_false")
|
||||
parser.add_argument("--aggfunc", type=str, default="mean")
|
||||
# transformer parameters ---------------
|
||||
parser.add_argument("--transformer_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--lr", type=float, default=1e-5)
|
||||
parser.add_argument("--textual_trf_name", type=str, default="mbert")
|
||||
parser.add_argument("--batch_size", type=int, default=32)
|
||||
parser.add_argument("--eval_batch_size", type=int, default=128)
|
||||
parser.add_argument("--textual_lr", type=float, default=1e-4)
|
||||
parser.add_argument("--max_length", type=int, default=128)
|
||||
parser.add_argument("--patience", type=int, default=5)
|
||||
parser.add_argument("--evaluate_step", type=int, default=10)
|
||||
# Visual Transformer parameters --------------
|
||||
parser.add_argument("--visual_transformer_name", type=str, default="vit")
|
||||
parser.add_argument("--visual_trf_name", type=str, default="vit")
|
||||
parser.add_argument("--visual_lr", type=float, default=1e-4)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
beautifulsoup4==4.11.2
|
||||
joblib==1.2.0
|
||||
matplotlib==3.7.1
|
||||
numpy==1.24.2
|
||||
matplotlib==3.6.3
|
||||
numpy==1.24.1
|
||||
pandas==1.5.3
|
||||
Pillow==9.4.0
|
||||
requests==2.28.2
|
||||
scikit_learn==1.2.1
|
||||
scikit_learn==1.2.2
|
||||
scipy==1.10.1
|
||||
torch==1.13.1
|
||||
torchtext==0.14.1
|
||||
tqdm==4.65.0
|
||||
transformers==4.26.1
|
||||
tqdm==4.64.1
|
||||
transformers==4.26.0
|
||||
|
|
|
|||
Loading…
Reference in New Issue