Compare commits
22 Commits: binary-cls...master

| Author | SHA1 | Date |
|---|---|---|
|  | ab7a310b34 |  |
|  | 41647f974a |  |
|  | ee2a9481de |  |
|  | ee38bcda10 |  |
|  | b34da419d0 |  |
|  | 17d0003e48 |  |
|  | 9d43ebb23b |  |
|  | 56faaf2615 |  |
|  | f32b9227ae |  |
|  | 65407f51fa |  |
|  | 26aa0b327a |  |
|  | fece8d059e |  |
|  | 5e41b4517a |  |
|  | a3e183d7fc |  |
|  | 57918ec523 |  |
|  | 7d0d6ba1f6 |  |
|  | 5ef0904e0e |  |
|  | 7e1ec46ebd |  |
|  | 3240150542 |  |
|  | 84dd1f093e |  |
|  | 6b7917ca47 |  |
|  | 7dead90271 |  |
@@ -182,3 +182,4 @@ scripts/
 logger/*
 explore_data.ipynb
 run.sh
+wandb
@@ -1,3 +1,5 @@
+import os
+
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from dataManager.glamiDataset import get_dataframe
 from dataManager.multilingualDataset import MultilingualDataset

@@ -22,7 +24,7 @@ class gFunDataset:
         self.labels = labels
         self.nrows = nrows
         self.dataset = {}
-        self.load_dataset()
+        self._load_dataset()

     def get_label_binarizer(self, labels):
         if self.dataset_name in ["rcv1-2", "jrc", "cls"]:

@@ -35,7 +37,7 @@ class gFunDataset:
         mlb.fit(labels)
         return mlb

-    def load_dataset(self):
+    def _load_dataset(self):
         if "glami" in self.dataset_dir.lower():
             print(f"- Loading GLAMI dataset from {self.dataset_dir}")
             self.dataset_name = "glami"

@@ -106,44 +108,19 @@ class gFunDataset:
         return dataset, labels, data_langs

     def _load_glami(self, dataset_dir, nrows):
-        def _balanced_sample(data, n, remainder=0):
-            import pandas as pd
-
-            langs = sorted(data.geo.unique().tolist())
-            dict_n = {lang: n for lang in langs}
-            dict_n[langs[0]] += remainder
-
-            sampled = []
-            for lang in langs:
-                sampled.append(data[data.geo == lang].sample(n=dict_n[lang]))
-
-            return pd.concat(sampled, axis=0)
-
-        # TODO: set this sampling as determinsitic/dependeing on the seed
-        lang_nrows = (
-            nrows // 13 if self.data_langs is None else nrows // len(self.data_langs)
-        )  # GLAMI 1-M has 13 languages
-        remainder = (
-            nrows % 13 if self.data_langs is None else nrows % len(self.data_langs)
-        )
-
-        train_split = get_dataframe("train", dataset_dir=dataset_dir)
-        train_split = _balanced_sample(train_split, lang_nrows, remainder=remainder)
-
-        if self.data_langs is None:
-            data_langs = sorted(train_split.geo.unique().tolist())
-        # TODO: if data langs is NOT none then we have a problem where we filter df by langs
-        if self.labels is None:
-            labels = train_split.category_name.unique().tolist()
-
-        # TODO: atm test data should contain same languages as train data
-        test_split = get_dataframe("test", dataset_dir=dataset_dir)
-        # TODO: atm we're using 1:1 train-test
-        test_split = _balanced_sample(test_split, lang_nrows, remainder=remainder)
+        train_split = get_dataframe("train", dataset_dir=dataset_dir).sample(n=nrows)
+        test_split = get_dataframe("test", dataset_dir=dataset_dir).sample(
+            n=int(nrows / 10)
+        )

         gb_train = train_split.groupby("geo")
         gb_test = test_split.groupby("geo")

+        if self.data_langs is None:
+            data_langs = sorted(train_split.geo.unique().tolist())
+        if self.labels is None:
+            labels = train_split.category_name.unique().tolist()
+
         def _format_glami(data_df):
             text = (data_df.name + " " + data_df.description).tolist()
             image = data_df.image_file.tolist()

@@ -205,6 +182,14 @@ class gFunDataset:
         else:
             return self.labels

+    def save_as_pickle(self, path):
+        import pickle
+
+        filepath = os.path.join(path, f"{self.dataset_name}_{self.nrows}.pkl")
+        with open(filepath, "wb") as f:
+            print(f"- saving dataset in {filepath}")
+            pickle.dump(self, f)
+

 if __name__ == "__main__":
     import os
@@ -1,2 +1,66 @@
-class TorchMultiNewsDataset:
-    pass
+import torch
+from torch.utils.data import Dataset
+
+
+class MultilingualDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+        return self
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
+
+
+class MultimodalDatasetTorch(Dataset):
+    def __init__(self, lX, lY, split="train"):
+        self.lX = lX
+        self.lY = lY
+        self.split = split
+        self.langs = []
+        self.init()
+
+    def init(self):
+        self.X = torch.vstack([imgs for imgs in self.lX.values()])
+        if self.split != "whole":
+            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
+        self.langs = sum(
+            [
+                v
+                for v in {
+                    lang: [lang] * len(data) for lang, data in self.lX.items()
+                }.values()
+            ],
+            [],
+        )
+
+    def __len__(self):
+        return len(self.X)
+
+    def __getitem__(self, index):
+        if self.split == "whole":
+            return self.X[index], self.langs[index]
+        return self.X[index], self.Y[index], self.langs[index]
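For context, a minimal sketch of how the new `MultilingualDatasetTorch` is consumed; the tokenizer checkpoint, padding length, and toy labels below are assumptions for illustration, not part of the diff.

```python
# Illustrative only: lX maps language -> tokenizer output (needs .input_ids),
# lY maps language -> label matrix; padding to a fixed max_length keeps the
# per-language input_ids stackable with torch.vstack.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from dataManager.torchDataset import MultilingualDatasetTorch

tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-uncased")
lX = {
    "en": tokenizer(["a shirt", "red shoes"], padding="max_length", max_length=16, return_tensors="pt"),
    "it": tokenizer(["una camicia", "scarpe rosse"], padding="max_length", max_length=16, return_tensors="pt"),
}
lY = {"en": [[1, 0], [0, 1]], "it": [[1, 0], [0, 1]]}

dataset = MultilingualDatasetTorch(lX, lY, split="train")
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for x, y, lang in loader:
    print(x.shape, y.shape, lang)  # e.g. torch.Size([2, 16]) torch.Size([2, 2]) ('en', 'it')
```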
@@ -1,9 +1,21 @@
-from os.path import expanduser
+from os.path import expanduser, join
 from dataManager.gFunDataset import gFunDataset
 from dataManager.multiNewsDataset import MultiNewsDataset
 from dataManager.amazonDataset import AmazonDataset


+def load_from_pickle(path, dataset_name, nrows):
+    import pickle
+
+    filepath = join(path, f"{dataset_name}_{nrows}.pkl")
+
+    with open(filepath, "rb") as f:
+        loaded = pickle.load(f)
+    print(f"- Loaded dataset from {filepath}")
+    loaded.show_dimension()
+    return loaded
+
+
 def get_dataset(dataset_name, args):
     assert dataset_name in [
         "multinews",

@@ -58,13 +70,19 @@ def get_dataset(dataset_name, args):
             nrows=args.nrows,
         )
     elif dataset_name == "glami":
-        dataset = gFunDataset(
-            dataset_dir=GLAMI_DATAPATH,
-            is_textual=True,
-            is_visual=True,
-            is_multilabel=False,
-            nrows=args.nrows,
-        )
+        if args.save_dataset is False:
+            dataset = load_from_pickle(GLAMI_DATAPATH, dataset_name, args.nrows)
+        else:
+            dataset = gFunDataset(
+                dataset_dir=GLAMI_DATAPATH,
+                is_textual=True,
+                is_visual=True,
+                is_multilabel=False,
+                nrows=args.nrows,
+            )
+
+            dataset.save_as_pickle(GLAMI_DATAPATH)
+
     elif dataset_name == "cls":
         dataset = gFunDataset(
             dataset_dir=CLS_DATAPATH,
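A hypothetical sketch of the caching round trip introduced by `save_as_pickle` and `load_from_pickle`; the argparse field names mirror the diff, but the concrete `nrows` value and the import path of `get_dataset` are assumptions.

```python
# Hypothetical usage of the caching flow above (not repo code).
from argparse import Namespace

args = Namespace(nrows=1300, save_dataset=True)

# First run: build the GLAMI dataset from the raw files, then cache it as
# <GLAMI_DATAPATH>/glami_1300.pkl via gFunDataset.save_as_pickle().
dataset = get_dataset("glami", args)

# Later runs: skip the expensive build and reload the pickled object instead.
args.save_dataset = False
dataset = get_dataset("glami", args)  # goes through load_from_pickle(...)
```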
@@ -1,51 +1,96 @@
 from joblib import Parallel, delayed
+from collections import defaultdict
+
 from evaluation.metrics import *
+from sklearn.metrics import accuracy_score, top_k_accuracy_score, f1_score


-def evaluation_metrics(y, y_):
-    if len(y.shape) == len(y_.shape) == 1 and len(np.unique(y)) > 2:  # single-label
-        raise NotImplementedError()
-    else:
+def evaluation_metrics(y, y_, clf_type):
+    if clf_type == "singlelabel":
+        return (
+            accuracy_score(y, y_),
+            # TODO: we need the logits to compute this top_k_accuracy_score(y, y_, k=5),
+            # TODO: we need logits top_k_accuracy_score(y, y_, k=10),
+            f1_score(y, y_, average="macro", zero_division=1),
+            f1_score(y, y_, average="micro"),
+        )
+    elif clf_type == "multilabel":
         return (
             macroF1(y, y_),
             microF1(y, y_),
             macroK(y, y_),
             microK(y, y_),
-            # macroAcc(y, y_),
-            microAcc(
-                y, y_
-            ),  # TODO: we're using micro-averaging for accuracy, it is == to accuracy_score on binary classification
         )
+    else:
+        raise ValueError("clf_type must be either 'singlelabel' or 'multilabel'")


-def evaluate(ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1):
+def evaluate(
+    ly_true, ly_pred, metrics=evaluation_metrics, n_jobs=-1, clf_type="multilabel"
+):
     if n_jobs == 1:
-        return {lang: metrics(ly_true[lang], ly_pred[lang]) for lang in ly_true.keys()}
+        return {
+            lang: metrics(ly_true[lang], ly_pred[lang], clf_type)
+            for lang in ly_true.keys()
+        }
     else:
         langs = list(ly_true.keys())
         evals = Parallel(n_jobs=n_jobs)(
-            delayed(metrics)(ly_true[lang], ly_pred[lang]) for lang in langs
+            delayed(metrics)(ly_true[lang], ly_pred[lang], clf_type) for lang in langs
         )
         return {lang: evals[i] for i, lang in enumerate(langs)}


-def log_eval(l_eval, phase="training", verbose=True):
+def log_eval(l_eval, phase="training", clf_type="multilabel", verbose=True):
     if verbose:
         print(f"\n[Results {phase}]")
     metrics = []
-    for lang in l_eval.keys():
-        macrof1, microf1, macrok, microk, microAcc = l_eval[lang]
-        metrics.append([macrof1, microf1, macrok, microk, microAcc])
-        if phase != "validation":
-            print(
-                f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f} acc = {microAcc:.3f}"
-            )
-    averages = np.mean(np.array(metrics), axis=0)
-    if verbose:
-        print(
-            "Averages: MF1, mF1, MK, mK",
-            np.round(averages, 3),
-            "\n",
-        )
-    return averages
+    if clf_type == "multilabel":
+        for lang in l_eval.keys():
+            macrof1, microf1, macrok, microk = l_eval[lang]
+            metrics.append([macrof1, microf1, macrok, microk])
+            if phase != "validation":
+                print(f"Lang {lang}: macro-F1 = {macrof1:.3f} micro-F1 = {microf1:.3f}")
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                "Averages: MF1, mF1, MK, mK",
+                np.round(averages, 3),
+                "\n",
+            )
+        return averages  # TODO: return a dict avg and lang specific
+
+    elif clf_type == "singlelabel":
+        lang_metrics = defaultdict(dict)
+        _metrics = [
+            "accuracy",
+            # "acc5",  # "accuracy-at-5",
+            # "acc10",  # "accuracy-at-10",
+            "MF1",  # "macro-F1",
+            "mF1",  # "micro-F1",
+        ]
+        for lang in l_eval.keys():
+            # acc, top5, top10, macrof1, microf1 = l_eval[lang]
+            acc, macrof1, microf1 = l_eval[lang]
+            # metrics.append([acc, top5, top10, macrof1, microf1])
+            metrics.append([acc, macrof1, microf1])
+
+            for m, v in zip(_metrics, l_eval[lang]):
+                lang_metrics[m][lang] = v
+
+            if phase != "validation":
+                print(
+                    # f"Lang {lang}: acc = {acc:.3f} acc-top5 = {top5:.3f} acc-top10 = {top10:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                    f"Lang {lang}: acc = {acc:.3f} macro-F1: {macrof1:.3f} micro-F1 = {microf1:.3f}"
+                )
+        averages = np.mean(np.array(metrics), axis=0)
+        if verbose:
+            print(
+                # "Averages: Acc, Acc-5, Acc-10, MF1, mF1",
+                "Averages: Acc, MF1, mF1",
+                np.round(averages, 3),
+                "\n",
+            )
+        avg_metrics = dict(zip(_metrics, averages))
+        return avg_metrics, lang_metrics
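A small sketch of the reworked evaluation entry points under the single-label setting; the toy one-hot arrays are illustrative only and assume `evaluate`/`log_eval` are imported from `evaluation.evaluate`.

```python
# Illustrative only: exercises evaluate()/log_eval() with clf_type="singlelabel".
import numpy as np

from evaluation.evaluate import evaluate, log_eval

ly_true = {"en": np.array([[1, 0], [0, 1]]), "it": np.array([[0, 1], [0, 1]])}
ly_pred = {"en": np.array([[1, 0], [0, 1]]), "it": np.array([[0, 1], [1, 0]])}

l_eval = evaluate(ly_true, ly_pred, n_jobs=1, clf_type="singlelabel")
avg_metrics, lang_metrics = log_eval(l_eval, phase="validation", clf_type="singlelabel")
print(avg_metrics["accuracy"], lang_metrics["MF1"]["it"])
```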
@@ -239,7 +239,3 @@ def microK(true_labels, predicted_labels):

 def macroAcc(true_labels, predicted_labels):
     return macro_average(true_labels, predicted_labels, accuracy)
-
-
-def microAcc(true_labels, predicted_labels):
-    return micro_average(true_labels, predicted_labels, accuracy)
@@ -1,17 +1,14 @@
 import os
-import sys
-
-# sys.path.append(os.path.join(os.getcwd(), "gfun"))
-
 import pickle

 import numpy as np
-from gfun.vgfs.commons import TfidfVectorizerMultilingual, AttentionAggregator
+from gfun.vgfs.commons import AttentionAggregator, TfidfVectorizerMultilingual, predict
 from gfun.vgfs.learners.svms import MetaClassifier, get_learner
 from gfun.vgfs.multilingualGen import MultilingualGen
 from gfun.vgfs.textualTransformerGen import TextualTransformerGen
-from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.vanillaFun import VanillaFunGen
+from gfun.vgfs.visualTransformerGen import VisualTransformerGen
 from gfun.vgfs.wceGen import WceGen

@@ -25,11 +22,14 @@ class GeneralizedFunnelling:
         visual_transformer,
         langs,
         num_labels,
+        classification_type,
         embed_dir,
         n_jobs,
         batch_size,
+        eval_batch_size,
         max_length,
-        lr,
+        textual_lr,
+        visual_lr,
         epochs,
         patience,
         evaluate_step,

@@ -47,26 +47,31 @@ class GeneralizedFunnelling:
         self.posteriors_vgf = posterior
         self.wce_vgf = wce
         self.multilingual_vgf = multilingual
-        self.trasformer_vgf = textual_transformer
-        self.visual_transformer_vgf = visual_transformer
+        self.textual_trf_vgf = textual_transformer
+        self.visual_trf_vgf = visual_transformer
         self.probabilistic = probabilistic
         self.num_labels = num_labels
+        self.clf_type = classification_type
         # ------------------------
         self.langs = langs
         self.embed_dir = embed_dir
         self.cached = True
         # Textual Transformer VGF params ----------
-        self.textaul_transformer_name = textual_transformer_name
+        self.textual_trf_name = textual_transformer_name
         self.epochs = epochs
-        self.lr_transformer = lr
-        self.batch_size_transformer = batch_size
+        self.textual_trf_lr = textual_lr
+        self.textual_scheduler = "ReduceLROnPlateau"
+        self.batch_size_trf = batch_size
+        self.eval_batch_size_trf = eval_batch_size
         self.max_length = max_length
         self.early_stopping = True
         self.patience = patience
         self.evaluate_step = evaluate_step
         self.device = device
         # Visual Transformer VGF params ----------
-        self.visual_transformer_name = visual_transformer_name
+        self.visual_trf_name = visual_transformer_name
+        self.visual_trf_lr = visual_lr
+        self.visual_scheduler = "ReduceLROnPlateau"
         # Metaclassifier params ------------
         self.optimc = optimc
         # -------------------

@@ -77,7 +82,7 @@ class GeneralizedFunnelling:
         self.aggfunc = aggfunc
         self.load_trained = load_trained
         self.load_first_tier = (
-            True  # TODO: i guess we're always going to load at least the fitst tier
+            True  # TODO: i guess we're always going to load at least the first tier
         )
         self.load_meta = load_meta
         self.dataset_name = dataset_name

@@ -112,7 +117,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,

@@ -142,13 +147,15 @@ class GeneralizedFunnelling:
             wce_vgf = WceGen(n_jobs=self.n_jobs)
             self.first_tier_learners.append(wce_vgf)

-        if self.trasformer_vgf:
+        if self.textual_trf_vgf:
             transformer_vgf = TextualTransformerGen(
                 dataset_name=self.dataset_name,
-                model_name=self.textaul_transformer_name,
-                lr=self.lr_transformer,
+                model_name=self.textual_trf_name,
+                lr=self.textual_trf_lr,
+                scheduler=self.textual_scheduler,
                 epochs=self.epochs,
-                batch_size=self.batch_size_transformer,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 max_length=self.max_length,
                 print_steps=50,
                 probabilistic=self.probabilistic,

@@ -156,21 +163,24 @@ class GeneralizedFunnelling:
                 verbose=True,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(transformer_vgf)

-        if self.visual_transformer_vgf:
+        if self.visual_trf_vgf:
             visual_trasformer_vgf = VisualTransformerGen(
                 dataset_name=self.dataset_name,
                 model_name="vit",
-                lr=1e-5,  # self.lr_visual_transformer,
+                lr=self.visual_trf_lr,
+                scheduler=self.visual_scheduler,
                 epochs=self.epochs,
-                batch_size=32,  # self.batch_size_visual_transformer,
-                # batch_size_eval=128,
+                batch_size=self.batch_size_trf,
+                batch_size_eval=self.eval_batch_size_trf,
                 probabilistic=self.probabilistic,
                 evaluate_step=self.evaluate_step,
                 patience=self.patience,
                 device=self.device,
+                classification_type=self.clf_type,
             )
             self.first_tier_learners.append(visual_trasformer_vgf)

@@ -179,7 +189,7 @@ class GeneralizedFunnelling:
             self.attn_aggregator = AttentionAggregator(
                 embed_dim=self.get_attn_agg_dim(attn_stacking_type=attn_stacking),
                 out_dim=self.num_labels,
-                lr=self.lr_transformer,
+                lr=self.textual_trf_lr,
                 patience=self.patience,
                 num_heads=1,
                 device=self.device,

@@ -198,7 +208,8 @@ class GeneralizedFunnelling:
             self.posteriors_vgf,
             self.multilingual_vgf,
             self.wce_vgf,
-            self.trasformer_vgf,
+            self.textual_trf_vgf,
+            self.visual_trf_vgf,
             self.aggfunc,
         )
         print(f"- model id: {self._model_id}")

@@ -251,10 +262,9 @@ class GeneralizedFunnelling:
             projections.append(l_posteriors)
         agg = self.aggregate(projections)
         l_out = self.metaclassifier.predict_proba(agg)
-        # converting to binary predictions
-        # if self.dataset_name in ["cls"]: # TODO: better way to do this
-        # for lang, preds in l_out.items():
-        # l_out[lang] = np.expand_dims(np.argmax(preds, axis=1), axis=1)
+        if self.clf_type == "singlelabel":
+            for lang, preds in l_out.items():
+                l_out[lang] = predict(preds, clf_type=self.clf_type)
         return l_out

     def fit_transform(self, lX, lY):

@@ -303,15 +313,21 @@ class GeneralizedFunnelling:
         return aggregated

     def get_config(self):
-        print("\n")
-        print("-" * 50)
-        print("[GeneralizedFunnelling config]")
-        print(f"- model trained on langs: {self.langs}")
-        print("-- View Generating Functions configurations:\n")
+        c = {}

         for vgf in self.first_tier_learners:
-            print(vgf)
-            print("-" * 50)
+            vgf_config = vgf.get_config()
+            c.update(vgf_config)
+
+        gfun_config = {
+            "id": self._model_id,
+            "aggfunc": self.aggfunc,
+            "optimc": self.optimc,
+            "dataset": self.dataset_name,
+        }
+
+        c["gFun"] = gfun_config
+        return c

     def save(self, save_first_tier=True, save_meta=True):
         print(f"- Saving GeneralizedFunnelling model with id: {self._model_id}")

@@ -334,7 +350,7 @@ class GeneralizedFunnelling:
             pickle.dump(self.metaclassifier, f)
         return

-    def save_first_tier_learners(self):
+    def save_first_tier_learners(self, model_id):
         for vgf in self.first_tier_learners:
             vgf.save_vgf(model_id=self._model_id)
         return self

@@ -372,7 +388,7 @@ class GeneralizedFunnelling:
                 "rb",
             ) as vgf:
                 first_tier_learners.append(pickle.load(vgf))
-        if self.trasformer_vgf:
+        if self.textual_trf_vgf:
             with open(
                 os.path.join(
                     "models", "vgfs", "transformer", f"transformerGen_{model_id}.pkl"

@@ -427,7 +443,15 @@ def get_params(optimc=False):
     return [{"kernel": [kernel], "C": c_range, "gamma": ["auto"]}]


-def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
+def get_unique_id(
+    dataset_name,
+    posterior,
+    multilingual,
+    wce,
+    textual_transformer,
+    visual_transformer,
+    aggfunc,
+):
     from datetime import datetime

     now = datetime.now().strftime("%y%m%d")

@@ -435,6 +459,7 @@ def get_unique_id(dataset_name, posterior, multilingual, wce, transformer, aggfunc):
     model_id += "p" if posterior else ""
     model_id += "m" if multilingual else ""
     model_id += "w" if wce else ""
-    model_id += "t" if transformer else ""
+    model_id += "t" if textual_transformer else ""
+    model_id += "v" if visual_transformer else ""
     model_id += f"_{aggfunc}"
     return f"{model_id}_{now}"
@@ -9,9 +9,11 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import normalize
 from torch.optim import AdamW
+from torch.optim.lr_scheduler import ReduceLROnPlateau
 from torch.utils.data import DataLoader, Dataset
 from transformers.modeling_outputs import ModelOutput

+import wandb
 from evaluation.evaluate import evaluate, log_eval

 PRINT_ON_EPOCH = 1

@@ -21,6 +23,28 @@ def _normalize(lX, l2=True):
     return {lang: normalize(np.asarray(X)) for lang, X in lX.items()} if l2 else lX


+def verbosity_eval(epoch, print_eval):
+    if (epoch + 1) % print_eval == 0 and epoch != 0:
+        return True
+    else:
+        return False
+
+
+def format_langkey_wandb(lang_dict, vgf_name):
+    log_dict = {}
+    for metric, l_dict in lang_dict.items():
+        for lang, value in l_dict.items():
+            log_dict[f"{vgf_name}/language metric/{metric}/{lang}"] = value
+    return log_dict
+
+
+def format_average_wandb(avg_dict, vgf_name):
+    log_dict = {}
+    for metric, value in avg_dict.items():
+        log_dict[f"{vgf_name}/average metric/{metric}"] = value
+    return log_dict
+
+
 def XdotM(X, M, sif):
     E = X.dot(M)
     if sif:
@@ -57,18 +81,23 @@ def compute_pc(X, npc=1):
     return svd.components_


-def predict(logits, classification_type="multilabel"):
+def predict(logits, clf_type="multilabel"):
     """
     Converts soft precictions to hard predictions [0,1]
     """
-    if classification_type == "multilabel":
+    if clf_type == "multilabel":
         prediction = torch.sigmoid(logits) > 0.5
-    elif classification_type == "singlelabel":
-        prediction = torch.argmax(logits, dim=1).view(-1, 1)
+        return prediction.detach().cpu().numpy()
+    elif clf_type == "singlelabel":
+        if type(logits) != torch.Tensor:
+            logits = torch.tensor(logits)
+        prediction = torch.softmax(logits, dim=1)
+        prediction = prediction.detach().cpu().numpy()
+        _argmaxs = prediction.argmax(axis=1)
+        prediction = np.eye(prediction.shape[1])[_argmaxs]
+        return prediction
     else:
-        print("unknown classification type")
-
-    return prediction.detach().cpu().numpy()
+        raise NotImplementedError()


 class TfidfVectorizerMultilingual:
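The single-label branch of `predict` now goes softmax, then argmax, then one-hot, so downstream code can keep treating predictions as an indicator matrix. A tiny standalone illustration (not repo code) of that conversion:

```python
import numpy as np
import torch

logits = torch.tensor([[2.0, 0.1, -1.0], [0.2, 0.3, 1.5]])
probs = torch.softmax(logits, dim=1).numpy()
hard = np.eye(probs.shape[1])[probs.argmax(axis=1)]  # one-hot rows
print(hard)  # [[1. 0. 0.] [0. 0. 1.]]
```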
@@ -114,63 +143,138 @@ class Trainer:
         patience,
         experiment_name,
         checkpoint_path,
+        classification_type,
+        vgf_name,
+        n_jobs,
+        scheduler_name=None,
     ):
         self.device = device
         self.model = model.to(device)
-        self.optimizer = self.init_optimizer(optimizer_name, lr)
+        self.optimizer, self.scheduler = self.init_optimizer(
+            optimizer_name, lr, scheduler_name
+        )
         self.evaluate_steps = evaluate_step
         self.loss_fn = loss_fn.to(device)
         self.print_steps = print_steps
         self.experiment_name = experiment_name
         self.patience = patience
-        self.print_eval = evaluate_step
+        self.print_eval = 10
         self.earlystopping = EarlyStopping(
             patience=patience,
             checkpoint_path=checkpoint_path,
             verbose=False,
             experiment_name=experiment_name,
         )
+        self.clf_type = classification_type
+        self.vgf_name = vgf_name
+        self.scheduler_name = scheduler_name
+        self.n_jobs = n_jobs
+        self.monitored_metric = (
+            "macro-F1" if self.clf_type == "multilabel" else "accuracy"
+        )  # TODO: make this configurable

-    def init_optimizer(self, optimizer_name, lr):
+    def init_optimizer(self, optimizer_name, lr, scheduler_name):
         if optimizer_name.lower() == "adamw":
-            return AdamW(self.model.parameters(), lr=lr)
+            optim = AdamW(self.model.parameters(), lr=lr)
         else:
             raise ValueError(f"Optimizer {optimizer_name} not supported")
+        if scheduler_name is None:
+            scheduler = None
+        elif scheduler_name == "ReduceLROnPlateau":
+            scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5)
+        else:
+            raise ValueError(f"Scheduler {scheduler_name} not supported")
+        return optim, scheduler
+
+    def get_config(self, train_dataloader, eval_dataloader, epochs):
+        return {
+            "model name": self.model.name_or_path
+            if not hasattr(self.model, "mt5encoder")
+            else self.model.mt5encoder.name_or_path,
+            "epochs": epochs,
+            "learning rate": self.optimizer.defaults["lr"],
+            "scheduler": self.scheduler_name,  # TODO: add scheduler params
+            "train size": len(train_dataloader.dataset),
+            "eval size": len(eval_dataloader.dataset),
+            "train batch size": train_dataloader.batch_size,
+            "eval batch size": eval_dataloader.batch_size,
+            "max len": train_dataloader.dataset.X.shape[-1],
+            "patience": self.earlystopping.patience,
+            "evaluate every": self.evaluate_steps,
+            "print eval every": self.print_eval,
+            "print train steps": self.print_steps,
+            "classification type": self.clf_type,
+        }

     def train(self, train_dataloader, eval_dataloader, epochs=10):
-        print(
-            f"""- Training params for {self.experiment_name}:
-            - epochs: {epochs}
-            - learning rate: {self.optimizer.defaults['lr']}
-            - train batch size: {train_dataloader.batch_size}
-            - eval batch size: {eval_dataloader.batch_size}
-            - max len: {train_dataloader.dataset.X.shape[-1]}
-            - patience: {self.earlystopping.patience}
-            - evaluate every: {self.evaluate_steps}
-            - print eval every: {self.print_eval}
-            - print train steps: {self.print_steps}\n"""
-        )
+        _config = self.get_config(train_dataloader, eval_dataloader, epochs)
+        print(f"- Training params for {self.experiment_name}:")
+        for k, v in _config.items():
+            print(f"\t{k}: {v}")
+
         for epoch in range(epochs):
-            self.train_epoch(train_dataloader, epoch)
-            if (epoch + 1) % self.evaluate_steps == 0:
-                print_eval = (epoch + 1) % self.print_eval == 0
-                metric_watcher = self.evaluate(eval_dataloader, print_eval=print_eval)
-                stop = self.earlystopping(metric_watcher, self.model, epoch + 1)
+            train_loss = self.train_epoch(train_dataloader, epoch)
+
+            if (epoch + 1) % self.evaluate_steps == 0 or (epoch + 1) == 1:
+                print_eval = verbosity_eval(epoch, self.print_eval)
+                with torch.no_grad():
+                    eval_loss, avg_metrics, lang_metrics = self.evaluate(
+                        eval_dataloader,
+                        print_eval=print_eval,
+                        n_jobs=self.n_jobs,
+                    )
+
+                wandb.log(
+                    {
+                        f"{self.vgf_name}/loss/val": eval_loss,
+                        **format_langkey_wandb(lang_metrics, self.vgf_name),
+                        **format_average_wandb(avg_metrics, self.vgf_name),
+                    },
+                    commit=False,
+                )
+
+                stop = self.earlystopping(
+                    avg_metrics[self.monitored_metric], self.model, epoch + 1
+                )
                 if stop:
                     print(
                         f"- restoring best model from epoch {self.earlystopping.best_epoch} with best metric: {self.earlystopping.best_score:3f}"
                     )
-                    self.model = self.earlystopping.load_model(self.model).to(
-                        self.device
-                    )
+                    restored_model = self.earlystopping.load_model(self.model)
+                    # swapping model on gpu
+                    del self.model
+                    self.model = restored_model.to(self.device)
                     break

+                if self.scheduler is not None:
+                    self.scheduler.step(avg_metrics[self.monitored_metric])
+
+            wandb.log(
+                {
+                    f"{self.vgf_name}/loss/train": train_loss,
+                    f"{self.vgf_name}/learning rate": self.optimizer.param_groups[0][
+                        "lr"
+                    ],
+                }
+            )
+
         print(f"- last swipe on eval set")
-        self.train_epoch(eval_dataloader, epoch=0)
+        self.train_epoch(
+            DataLoader(
+                eval_dataloader.dataset,
+                batch_size=train_dataloader.batch_size,
+                shuffle=True,
+            ),
+            epoch=-1,
+        )
         self.earlystopping.save_model(self.model)
         return self.model

     def train_epoch(self, dataloader, epoch):
         self.model.train()
+        batch_losses = []
         for b_idx, (x, y, lang) in enumerate(dataloader):
             self.optimizer.zero_grad()
             y_hat = self.model(x.to(self.device))

@@ -180,37 +284,47 @@ class Trainer:
             loss = self.loss_fn(y_hat, y.to(self.device))
             loss.backward()
             self.optimizer.step()
+            batch_losses.append(loss.item())
             if (epoch + 1) % PRINT_ON_EPOCH == 0:
                 if ((b_idx + 1) % self.print_steps == 0) or b_idx == 0:
-                    print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}")
-        return self
+                    print(
+                        f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {np.mean(batch_losses):.4f}"
+                    )
+        return np.mean(batch_losses)

-    def evaluate(self, dataloader, print_eval=True):
+    def evaluate(self, dataloader, print_eval=True, n_jobs=-1):
         self.model.eval()
+        eval_losses = []

-        lY = defaultdict(list)
-        lY_hat = defaultdict(list)
+        lY_true = defaultdict(list)
+        lY_pred = defaultdict(list)

         for b_idx, (x, y, lang) in enumerate(dataloader):
-            y_hat = self.model(x.to(self.device))
-            if isinstance(y_hat, ModelOutput):
-                loss = self.loss_fn(y_hat.logits, y.to(self.device))
-                predictions = predict(y_hat.logits, classification_type="multilabel")
+            y_pred = self.model(x.to(self.device))
+            if isinstance(y_pred, ModelOutput):
+                loss = self.loss_fn(y_pred.logits, y.to(self.device))
+                predictions = predict(y_pred.logits, clf_type=self.clf_type)
             else:
-                loss = self.loss_fn(y_hat, y.to(self.device))
-                predictions = predict(y_hat, classification_type="multilabel")
+                loss = self.loss_fn(y_pred, y.to(self.device))
+                predictions = predict(y_pred, clf_type=self.clf_type)
+
+            eval_losses.append(loss.item())

             for l, _true, _pred in zip(lang, y, predictions):
-                lY[l].append(_true.detach().cpu().numpy())
-                lY_hat[l].append(_pred)
+                lY_true[l].append(_true.detach().cpu().numpy())
+                lY_pred[l].append(_pred)

-        for lang in lY:
-            lY[lang] = np.vstack(lY[lang])
-            lY_hat[lang] = np.vstack(lY_hat[lang])
+        for lang in lY_true:
+            lY_true[lang] = np.vstack(lY_true[lang])
+            lY_pred[lang] = np.vstack(lY_pred[lang])

-        l_eval = evaluate(lY, lY_hat)
-        average_metrics = log_eval(l_eval, phase="validation", verbose=print_eval)
-        return average_metrics[0]  # macro-F1
+        l_eval = evaluate(lY_true, lY_pred, clf_type=self.clf_type, n_jobs=n_jobs)
+        avg_metrics, lang_metrics = log_eval(
+            l_eval, phase="validation", clf_type=self.clf_type, verbose=print_eval
+        )
+
+        return np.mean(eval_losses), avg_metrics, lang_metrics


 class EarlyStopping:
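The Trainer now steps a `ReduceLROnPlateau` scheduler on the monitored validation metric (in "max" mode, since higher macro-F1 or accuracy is better). A standalone illustration of that wiring, with made-up metric values and an explicit patience that is an assumption:

```python
# Illustrative only: ReduceLROnPlateau in "max" mode halves the learning rate once
# the monitored metric stops improving for `patience` evaluations.
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(4, 2)
optim = AdamW(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optim, "max", factor=0.5, min_lr=1e-5, patience=2)

for epoch, macro_f1 in enumerate([0.40, 0.41, 0.41, 0.41, 0.41]):
    scheduler.step(macro_f1)  # called with the monitored metric, as in Trainer.train
    print(epoch, optim.param_groups[0]["lr"])  # lr drops to 5e-05 on the last epoch
```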
@@ -232,7 +346,8 @@ class EarlyStopping:
         self.experiment_name = experiment_name

     def __call__(self, validation, model, epoch):
-        if validation > self.best_score:
+        if validation >= self.best_score:
+            wandb.log({"patience": self.patience - self.counter})
             if self.verbose:
                 print(
                     f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}"

@@ -244,11 +359,12 @@
             self.save_model(model)
         elif validation < (self.best_score + self.min_delta):
             self.counter += 1
+            wandb.log({"patience": self.patience - self.counter})
             if self.verbose:
                 print(
                     f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}"
                 )
-            if self.counter >= self.patience:
+            if self.counter >= self.patience and self.patience != -1:
                 print(f"- earlystopping: Early stopping at epoch {epoch}")
                 return True
@@ -104,10 +104,6 @@ class MultilingualGen(ViewGen):
             pickle.dump(self, f)
         return self

-    def __str__(self):
-        _str = f"[Multilingual VGF (m)]\n- embed_dir: {self.embed_dir}\n- langs: {self.langs}\n- n_jobs: {self.n_jobs}\n- cached: {self.cached}\n- sif: {self.sif}\n- probabilistic: {self.probabilistic}\n"
-        return _str
-

 def load_MUSEs(langs, l_vocab, dir_path, cached=False):
     dir_path = expanduser(dir_path)
@@ -6,20 +6,50 @@ from collections import defaultdict

 import numpy as np
 import torch
+import torch.nn as nn
 import transformers
-# from sklearn.model_selection import train_test_split
-# from torch.optim import AdamW
-from torch.utils.data import Dataset
+from transformers import MT5EncoderModel
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers.modeling_outputs import ModelOutput

 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultilingualDatasetTorch

 transformers.logging.set_verbosity_error()


-# TODO: add support to loggers
+class MT5ForSequenceClassification(nn.Module):
+    def __init__(self, model_name, num_labels, output_hidden_states):
+        super().__init__()
+
+        self.output_hidden_states = output_hidden_states
+        self.mt5encoder = MT5EncoderModel.from_pretrained(
+            model_name, output_hidden_states=True
+        )
+        self.dropout = nn.Dropout(0.1)
+        self.linear = nn.Linear(512, num_labels)
+
+    def forward(self, input_ids):
+        embed = self.mt5encoder(input_ids=input_ids)
+        pooled = torch.mean(embed.last_hidden_state, dim=1)
+        outputs = self.dropout(pooled)
+        logits = self.linear(outputs)
+        if self.output_hidden_states:
+            return ModelOutput(
+                logits=logits,
+                pooled=pooled,
+            )
+        return ModelOutput(logits=logits)
+
+    def save_pretrained(self, checkpoint_dir):
+        torch.save(self.state_dict(), checkpoint_dir + ".pt")
+        return
+
+    def from_pretrained(self, checkpoint_dir):
+        checkpoint_dir += ".pt"
+        return self.load_state_dict(torch.load(checkpoint_dir))
+
+
 class TextualTransformerGen(ViewGen, TransformerGen):
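A quick shape check for the new MT5 wrapper above; mt5-small has hidden size 512, which is why the classification head is `nn.Linear(512, num_labels)`. The tokenizer call and example strings are assumptions, and running this requires downloading the `google/mt5-small` weights.

```python
# Illustrative only: mean-pooled MT5 encoder states feed a small linear head.
import torch
from transformers import AutoTokenizer

from gfun.vgfs.textualTransformerGen import MT5ForSequenceClassification

tok = AutoTokenizer.from_pretrained("google/mt5-small")
model = MT5ForSequenceClassification("google/mt5-small", num_labels=10, output_hidden_states=True)

batch = tok(["a t-shirt", "una maglietta"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = model(batch.input_ids)
print(out.logits.shape, out.pooled.shape)  # torch.Size([2, 10]) torch.Size([2, 512])
```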
@@ -39,23 +69,27 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         evaluate_step=10,
         verbose=False,
         patience=5,
+        classification_type="multilabel",
+        scheduler="ReduceLROnPlateau",
     ):
         super().__init__(
             self._validate_model_name(model_name),
             dataset_name,
-            epochs,
-            lr,
-            batch_size,
-            batch_size_eval,
-            max_length,
-            print_steps,
-            device,
-            probabilistic,
-            n_jobs,
-            evaluate_step,
-            verbose,
-            patience,
+            epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+            probabilistic=probabilistic,
+            max_length=max_length,
+            print_steps=print_steps,
+            n_jobs=n_jobs,
+            verbose=verbose,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"

@@ -66,15 +100,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             return "bert-base-uncased"
         elif "mbert" == model_name:
             return "bert-base-multilingual-uncased"
-        elif "xlm" == model_name:
+        elif "xlm-roberta" == model_name:
             return "xlm-roberta-base"
+        elif "mt5" == model_name:
+            return "google/mt5-small"
         else:
             raise NotImplementedError

     def load_pretrained_model(self, model_name, num_labels):
-        return AutoModelForSequenceClassification.from_pretrained(
-            model_name, num_labels=num_labels, output_hidden_states=True
-        )
+        if model_name == "google/mt5-small":
+            return MT5ForSequenceClassification(
+                model_name, num_labels=num_labels, output_hidden_states=True
+            )
+        else:
+            return AutoModelForSequenceClassification.from_pretrained(
+                model_name, num_labels=num_labels, output_hidden_states=True
+            )

     def load_tokenizer(self, model_name):
         return AutoTokenizer.from_pretrained(model_name)

@@ -127,9 +168,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )

-        experiment_name = (
-            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-        )
+        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",

@@ -140,7 +180,16 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             evaluate_step=self.evaluate_step,
             patience=self.patience,
             experiment_name=experiment_name,
-            checkpoint_path="models/vgfs/transformer",
+            checkpoint_path=os.path.join(
+                "models",
+                "vgfs",
+                "transformer",
+                self._format_model_name(self.model_name),
+            ),
+            vgf_name="textual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
+            scheduler_name=self.scheduler,
         )
         trainer.train(
             train_dataloader=tra_dataloader,

@@ -175,8 +224,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                out = self.model(input_ids).hidden_states[-1]
-                batch_embeddings = out[:, 0, :].cpu().numpy()
+                # TODO: check this
+                if isinstance(self.model, MT5ForSequenceClassification):
+                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
+                else:
+                    out = self.model(input_ids).hidden_states[-1]
+                    batch_embeddings = out[:, 0, :].cpu().numpy()
                 _embeds.append((batch_embeddings, lang))

         for embed, lang in _embeds:

@@ -206,39 +259,22 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self

-    def __str__(self):
-        str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
-
-class MultilingualDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data.input_ids) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-        return self
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
+    def freeze_model(self):
+        # TODO: up to n-layers? or all? avoid freezing head ovb...
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    def _format_model_name(self, model_name):
+        if "mt5" in model_name:
+            return "google-mt5"
+        elif "bert" in model_name:
+            if "multilingual" in model_name:
+                return "mbert"
+            elif "xlm-roberta" in model_name:
+                return "xlm-roberta"
+        else:
+            return model_name
+
+    def get_config(self):
+        c = super().get_config()
+        return {"textual_trf": c}
|
||||||
evaluate_step=10,
|
evaluate_step=10,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
patience=5,
|
patience=5,
|
||||||
|
scheduler=None,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.dataset_name = dataset_name
|
self.dataset_name = dataset_name
|
||||||
|
|
@ -46,6 +47,7 @@ class TransformerGen:
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.patience = patience
|
self.patience = patience
|
||||||
self.datasets = {}
|
self.datasets = {}
|
||||||
|
self.scheduler = scheduler
|
||||||
self.feature2posterior_projector = (
|
self.feature2posterior_projector = (
|
||||||
self.make_probabilistic() if probabilistic else None
|
self.make_probabilistic() if probabilistic else None
|
||||||
)
|
)
|
||||||
|
|
@ -94,3 +96,22 @@ class TransformerGen:
|
||||||
val_lY[lang] = val_Y
|
val_lY[lang] = val_Y
|
||||||
|
|
||||||
return tr_lX, tr_lY, val_lX, val_lY
|
return tr_lX, tr_lY, val_lX, val_lY
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return {
|
||||||
|
"model_name": self.model_name,
|
||||||
|
"dataset_name": self.dataset_name,
|
||||||
|
"epochs": self.epochs,
|
||||||
|
"lr": self.lr,
|
||||||
|
"scheduler": self.scheduler,
|
||||||
|
"batch_size": self.batch_size,
|
||||||
|
"batch_size_eval": self.batch_size_eval,
|
||||||
|
"max_length": self.max_length,
|
||||||
|
"print_steps": self.print_steps,
|
||||||
|
"device": self.device,
|
||||||
|
"probabilistic": self.probabilistic,
|
||||||
|
"n_jobs": self.n_jobs,
|
||||||
|
"evaluate_step": self.evaluate_step,
|
||||||
|
"verbose": self.verbose,
|
||||||
|
"patience": self.patience,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@@ -65,8 +65,3 @@ class VanillaFunGen(ViewGen):
         with open(_path, "wb") as f:
             pickle.dump(self, f)
         return self
-
-    def __str__(self):
-        _str = f"[VanillaFunGen (-p)]\n- base learner: {self.learners}\n- n_jobs: {self.n_jobs}\n"
-        # - parameters: {self.first_tier_parameters}
-        return _str
@@ -4,12 +4,12 @@ import numpy as np
 import torch
 import transformers
 from PIL import Image
-from torch.utils.data import Dataset
 from transformers import AutoImageProcessor, AutoModelForImageClassification
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
 from gfun.vgfs.viewGen import ViewGen
+from dataManager.torchDataset import MultimodalDatasetTorch
 
 transformers.logging.set_verbosity_error()
@@ -20,6 +20,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         model_name,
         dataset_name,
         lr=1e-5,
+        scheduler="ReduceLROnPlateau",
         epochs=10,
         batch_size=32,
         batch_size_eval=128,
@@ -27,12 +28,14 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         device="cpu",
         probabilistic=False,
         patience=5,
+        classification_type="multilabel",
     ):
         super().__init__(
             model_name,
             dataset_name,
-            lr=lr,
             epochs=epochs,
+            lr=lr,
+            scheduler=scheduler,
             batch_size=batch_size,
             batch_size_eval=batch_size_eval,
             device=device,
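The scheduler="ReduceLROnPlateau" default is only a string forwarded to the trainer; how the scheduler is actually built is not shown in this diff. A minimal plain-PyTorch sketch of the standard ReduceLROnPlateau wiring, assuming (as the name suggests) that it is stepped on a validation loss:

# Minimal sketch of ReduceLROnPlateau in plain PyTorch (not the project's Trainer).
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2
)

for epoch in range(5):
    val_loss = 1.0 / (epoch + 1)   # placeholder validation loss
    scheduler.step(val_loss)       # lr shrinks when val_loss stops improving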
@@ -40,6 +43,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=patience,
             probabilistic=probabilistic,
         )
+        self.clf_type = classification_type
         self.fitted = False
         print(
             f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
@@ -97,7 +101,10 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             shuffle=False,
         )
 
-        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
+        experiment_name = (
+            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        )
 
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
@@ -109,6 +116,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             patience=self.patience,
             experiment_name=experiment_name,
             checkpoint_path="models/vgfs/transformer",
+            vgf_name="visual_trf",
+            classification_type=self.clf_type,
+            n_jobs=self.n_jobs,
         )
 
         trainer.train(
@@ -175,66 +185,5 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self
 
-    def __str__(self):
-        str = f"[Visual Transformer VGF (v)]\n- model_name: {self.model_name}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
-        return str
-
-
-class MultimodalDatasetTorch(Dataset):
-    def __init__(self, lX, lY, split="train"):
-        self.lX = lX
-        self.lY = lY
-        self.split = split
-        self.langs = []
-        self.init()
-
-    def init(self):
-        self.X = torch.vstack([imgs for imgs in self.lX.values()])
-        if self.split != "whole":
-            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
-        self.langs = sum(
-            [
-                v
-                for v in {
-                    lang: [lang] * len(data) for lang, data in self.lX.items()
-                }.values()
-            ],
-            [],
-        )
-
-    def __len__(self):
-        return len(self.X)
-
-    def __getitem__(self, index):
-        if self.split == "whole":
-            return self.X[index], self.langs[index]
-        return self.X[index], self.Y[index], self.langs[index]
-
-
-if __name__ == "__main__":
-    from os.path import expanduser
-
-    from dataManager.gFunDataset import gFunDataset
-
-    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")
-    dataset = gFunDataset(
-        dataset_dir=GLAMI_DATAPATH,
-        is_textual=True,
-        is_visual=True,
-        is_multilabel=False,
-        nrows=50,
-    )
-
-    vg = VisualTransformerGen(
-        dataset_name=dataset.dataset_name,
-        model_name="vit",
-        device="cuda",
-        epochs=5,
-        evaluate_step=10,
-        patience=10,
-        probabilistic=True,
-    )
-    lX, lY = dataset.training()
-    vg.fit(lX, lY)
-    out = vg.transform(lX)
-    exit(0)
+    def get_config(self):
+        return {"visual_trf": super().get_config()}
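The MultimodalDatasetTorch removed above now lives in dataManager.torchDataset (see the new import earlier in this file). A small self-contained sketch of how a dataset with this (X, Y, lang) interface is typically consumed; the tensors below are synthetic placeholders, not GLAMI data, and the class is a simplified stand-in rather than the project's implementation:

# Sketch of feeding (X, Y, lang) triples to a torch DataLoader.
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class TinyMultimodalDataset(Dataset):
    # simplified stand-in for dataManager.torchDataset.MultimodalDatasetTorch
    def __init__(self, lX, lY):
        self.X = torch.vstack(list(lX.values()))
        self.Y = torch.vstack([torch.Tensor(y) for y in lY.values()])
        self.langs = sum([[lang] * len(x) for lang, x in lX.items()], [])

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.Y[i], self.langs[i]

lX = {"en": torch.rand(4, 3, 8, 8), "it": torch.rand(4, 3, 8, 8)}  # fake image batches
lY = {"en": np.eye(4), "it": np.eye(4)}                            # fake one-hot labels
loader = DataLoader(TinyMultimodalDataset(lX, lY), batch_size=2, shuffle=True)
for X, Y, langs in loader:
    print(X.shape, Y.shape, langs)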
@@ -40,10 +40,6 @@ class WceGen(ViewGen):
             "sif": self.sif,
         }
 
-    def __str__(self):
-        _str = f"[WordClass VGF (w)]\n- sif: {self.sif}\n- n_jobs: {self.n_jobs}\n"
-        return _str
-
     def save_vgf(self, model_id):
         import pickle
         from os.path import join
main.py (107 changed lines)
@@ -1,3 +1,8 @@
+import os
+import wandb
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
 from argparse import ArgumentParser
 from time import time
 
@@ -7,18 +12,36 @@ from gfun.generalizedFunnelling import GeneralizedFunnelling
 
 """
 TODO:
-    - [!] add support for Binary Datasets (e.g. cls) - NB: CLS dataset is loading only "books" domain data
-    - [!] documents should be trimmed to the same length (?)
-    - [!] logging
-    - add documentations sphinx
-    - [!] zero-shot setup
-    - FFNN posterior-probabilities' dependent
-    - re-init langs when loading VGFs?
-    - [!] loss of Attention-aggregator seems to be uncorrelated with Macro-F1 on the validation set!
-    - [!] experiment with weight init of Attention-aggregator
+    - Transformers VGFs:
+        - scheduler with warmup and cosine
+        - freeze params method
+    - General:
+        [!] zero-shot setup
+        - CLS dataset is loading only "books" domain data
+        - documents should be trimmed to the same length (for SVMs we are using way too long tokens)
+    - Attention Aggregator:
+        - experiment with weight init of Attention-aggregator
+        - FFNN posterior-probabilities' dependent
+    - Docs:
+        - add documentations sphinx
 """
 
 
+def get_config_name(args):
+    config_name = ""
+    if args.posteriors:
+        config_name += "P+"
+    if args.wce:
+        config_name += "W+"
+    if args.multilingual:
+        config_name += "M+"
+    if args.textual_transformer:
+        config_name += f"TT_{args.textual_trf_name}+"
+    if args.visual_transformer:
+        config_name += f"VT_{args.visual_trf_name}+"
+    return config_name.rstrip("+")
+
+
 def main(args):
     dataset = get_dataset(args.dataset, args)
     lX, lY = dataset.training()
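To make the naming scheme concrete, here is one hypothetical flag combination and the run name get_config_name would derive from it (the argument values are invented for illustration):

# Hypothetical arguments: posteriors + WCE + textual transformer enabled.
from types import SimpleNamespace

args = SimpleNamespace(
    posteriors=True,
    wce=True,
    multilingual=False,
    textual_transformer=True,
    textual_trf_name="mbert",
    visual_transformer=False,
    visual_trf_name="vit",
)
# get_config_name(args) -> "P+W+TT_mbert"   (the trailing "+" is stripped)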
@@ -43,6 +66,7 @@ def main(args):
         dataset_name=args.dataset,
         langs=dataset.langs(),
         num_labels=dataset.num_labels(),
+        classification_type=args.clf_type,
         # Posterior VGF params ----------------
         posterior=args.posteriors,
         # Multilingual VGF params -------------
@@ -52,24 +76,26 @@ def main(args):
         wce=args.wce,
         # Transformer VGF params --------------
         textual_transformer=args.textual_transformer,
-        textual_transformer_name=args.transformer_name,
+        textual_transformer_name=args.textual_trf_name,
         batch_size=args.batch_size,
+        eval_batch_size=args.eval_batch_size,
         epochs=args.epochs,
-        lr=args.lr,
+        textual_lr=args.textual_lr,
+        visual_lr=args.visual_lr,
         max_length=args.max_length,
         patience=args.patience,
         evaluate_step=args.evaluate_step,
         device=args.device,
         # Visual Transformer VGF params --------------
         visual_transformer=args.visual_transformer,
-        visual_transformer_name=args.visual_transformer_name,
+        visual_transformer_name=args.visual_trf_name,
         # batch_size=args.batch_size,
         # epochs=args.epochs,
         # lr=args.lr,
         # patience=args.patience,
         # evaluate_step=args.evaluate_step,
         # device="cuda",
-        # General params ----------------------
+        # General params ---------------------
         probabilistic=args.features,
         aggfunc=args.aggfunc,
         optimc=args.optimc,
@@ -78,27 +104,54 @@ def main(args):
         n_jobs=args.n_jobs,
     )
 
-    # gfun.get_config()
+    config = gfun.get_config()
+
+    wandb.init(project="gfun", name=f"gFun-{get_config_name(args)}", config=config)
 
     gfun.fit(lX, lY)
 
     if args.load_trained is None and not args.nosave:
         gfun.save(save_first_tier=True, save_meta=True)
 
-    # print("- Computing evaluation on training set")
-    # preds = gfun.transform(lX)
-    # train_eval = evaluate(lY, preds)
-    # log_eval(train_eval, phase="train")
-
     timetr = time()
     print(f"- training completed in {timetr - tinit:.2f} seconds")
 
     gfun_preds = gfun.transform(lX_te)
-    test_eval = evaluate(lY_te, gfun_preds)
-    log_eval(test_eval, phase="test")
+    test_eval = evaluate(lY_te, gfun_preds, clf_type=args.clf_type, n_jobs=args.n_jobs)
+    avg_metrics_gfun, lang_metrics_gfun = log_eval(
+        test_eval, phase="test", clf_type=args.clf_type
+    )
 
     timeval = time()
     print(f"- testing completed in {timeval - timetr:.2f} seconds")
 
+    def log_barplot_wandb(gfun_res, title_affix="per langauge"):
+        if title_affix == "per language":
+            for metric, lang_values in gfun_res.items():
+                data = [[lang, v] for lang, v in lang_values.items()]
+                table = wandb.Table(data=data, columns=["lang", f"{metric}"])
+                wandb.log(
+                    {
+                        f"gFun/language {metric}": wandb.plot.bar(
+                            table, "lang", metric, title=f"{metric} {title_affix}"
+                        )
+                    }
+                )
+        else:
+            data = [[metric, value] for metric, value in gfun_res.items()]
+            table = wandb.Table(data=data, columns=["metric", "value"])
+            wandb.log(
+                {
+                    f"gFun/average metric": wandb.plot.bar(
+                        table, "metric", "value", title=f"metric {title_affix}"
+                    )
+                }
+            )
+        wandb.log(gfun_res)
+
+    log_barplot_wandb(lang_metrics_gfun, title_affix="per language")
+    log_barplot_wandb(avg_metrics_gfun, title_affix="averages")
 
 
 if __name__ == "__main__":
     parser = ArgumentParser()
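For reference, the wandb.Table / wandb.plot.bar pattern used by log_barplot_wandb, reduced to a standalone sketch; the project name and metric values are invented, and running it requires a wandb login:

# Standalone sketch of logging a per-language bar chart to Weights & Biases.
import wandb

run = wandb.init(project="gfun-demo")                    # assumed project name
lang_f1 = {"en": 0.81, "it": 0.74, "fr": 0.77}           # fake macro-F1 values
table = wandb.Table(
    data=[[lang, v] for lang, v in lang_f1.items()], columns=["lang", "macro-f1"]
)
run.log(
    {"gFun/language macro-f1": wandb.plot.bar(table, "lang", "macro-f1", title="macro-f1 per language")}
)
run.finish()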
@@ -112,6 +165,8 @@ if __name__ == "__main__":
     parser.add_argument("--nrows", type=int, default=None)
     parser.add_argument("--min_count", type=int, default=10)
     parser.add_argument("--max_labels", type=int, default=50)
+    parser.add_argument("--clf_type", type=str, default="multilabel")
+    parser.add_argument("--save_dataset", action="store_true")
     # gFUN parameters ----------------------
     parser.add_argument("-p", "--posteriors", action="store_true")
     parser.add_argument("-m", "--multilingual", action="store_true")
@@ -123,15 +178,17 @@ if __name__ == "__main__":
     parser.add_argument("--features", action="store_false")
     parser.add_argument("--aggfunc", type=str, default="mean")
     # transformer parameters ---------------
-    parser.add_argument("--transformer_name", type=str, default="mbert")
-    parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--epochs", type=int, default=100)
-    parser.add_argument("--lr", type=float, default=1e-5)
+    parser.add_argument("--textual_trf_name", type=str, default="mbert")
+    parser.add_argument("--batch_size", type=int, default=32)
+    parser.add_argument("--eval_batch_size", type=int, default=128)
+    parser.add_argument("--textual_lr", type=float, default=1e-4)
     parser.add_argument("--max_length", type=int, default=128)
     parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--evaluate_step", type=int, default=10)
     # Visual Transformer parameters --------------
-    parser.add_argument("--visual_transformer_name", type=str, default="vit")
+    parser.add_argument("--visual_trf_name", type=str, default="vit")
+    parser.add_argument("--visual_lr", type=float, default=1e-4)
 
     args = parser.parse_args()
 
@@ -1,13 +1,13 @@
 beautifulsoup4==4.11.2
 joblib==1.2.0
-matplotlib==3.7.1
-numpy==1.24.2
+matplotlib==3.6.3
+numpy==1.24.1
 pandas==1.5.3
 Pillow==9.4.0
 requests==2.28.2
-scikit_learn==1.2.1
+scikit_learn==1.2.2
 scipy==1.10.1
 torch==1.13.1
 torchtext==0.14.1
-tqdm==4.65.0
-transformers==4.26.1
+tqdm==4.64.1
+transformers==4.26.0