246 lines
7.4 KiB
Python
246 lines
7.4 KiB
Python
import os
|
|
|
|
# Configure HuggingFace tokenizers parallelism. Kept ahead of the `transformers`
# imports below so the setting is already in the environment when that library loads.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
|
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
import transformers
|
|
from sklearn.model_selection import train_test_split
|
|
from torch.optim import AdamW
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
from vgfs.learners.svms import FeatureSet2Posteriors
|
|
from vgfs.viewGen import ViewGen
|
|
from vgfs.transformerGen import TransformerGen
|
|
from vgfs.commons import Trainer, predict
|
|
|
|
# Only report errors from the transformers library (hides its warning/info messages).
transformers.logging.set_verbosity_error()
|
|
|
|
|
|
# TODO: add logging support (replace the print() calls with a logger)
|
|
# TODO: multiple inheritance - consider a common TransformerGenerator superclass, shared by the textual and visual variants, that implements the dataset-creation functions
|
|
|
|
|
|
class TextualTransformerGen(ViewGen, TransformerGen):
    """Textual view-generating function (VGF) backed by a fine-tuned transformer.

    ``fit`` fine-tunes a pretrained sequence-classification transformer on the
    multilingual training documents. ``transform`` then represents every
    document by the last-layer hidden state of its first token, grouped by
    language. When ``probabilistic`` is True, those embeddings are also mapped
    to posteriors by a ``FeatureSet2Posteriors`` projector fitted at the end
    of ``fit``.
    """

    # Supported short model names and the pretrained checkpoints they map to.
    _MODEL_ALIASES = {
        "bert": "bert-base-uncased",
        "mbert": "bert-base-multilingual-uncased",
        "xlm": "xlm-roberta-base",
    }

    def __init__(
        self,
        model_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
    ):
        """Configure the VGF. The model and tokenizer are only created in ``fit``.

        Args:
            model_name: short alias ("bert", "mbert", "xlm") or the checkpoint
                name such an alias resolves to.
            epochs: number of training epochs passed to ``Trainer.train``.
            lr: learning rate for the Trainer's AdamW optimizer.
            batch_size: batch size of the training dataloader.
            batch_size_eval: batch size used for validation and ``transform``.
            max_length: tokenizer max length (inputs are padded/truncated to it).
            print_steps: forwarded to ``Trainer``.
            device: torch device used by the Trainer and by ``transform``.
            probabilistic: if True, ``transform`` returns posteriors from a
                ``FeatureSet2Posteriors`` projector instead of raw embeddings.
            n_jobs: number of jobs for the posterior projector.
            evaluate_step: forwarded to ``Trainer``.
            verbose: forwarded to the base-class constructor.
            patience: forwarded to ``Trainer``.
        """
        # The base-class constructor is expected to store these as same-named
        # attributes (self.epochs, self.lr, ...), which the methods below read.
        super().__init__(
            model_name,
            epochs,
            lr,
            batch_size,
            batch_size_eval,
            max_length,
            print_steps,
            device,
            probabilistic,
            n_jobs,
            evaluate_step,
            verbose,
            patience,
        )
        self.fitted = False
        self._init()

    def _init(self):
        """Create the posterior projector (if probabilistic) and resolve ``model_name``."""
        if self.probabilistic:
            self.feature2posterior_projector = FeatureSet2Posteriors(
                n_jobs=self.n_jobs, verbose=False
            )
        self.model_name = self._get_model_name(self.model_name)
        print(
            f"- init TransformerModel model_name: {self.model_name}, device: {self.device}]"
        )

    def _get_model_name(self, name):
        """Resolve a short model alias to its pretrained checkpoint name.

        Checkpoint names of the supported models are returned unchanged, so
        resolving an already-resolved name (e.g. re-running ``_init``) works.

        Raises:
            NotImplementedError: if ``name`` is not a supported model.
        """
        if name in self._MODEL_ALIASES:
            return self._MODEL_ALIASES[name]
        if name in self._MODEL_ALIASES.values():
            return name
        raise NotImplementedError(
            f"Unsupported model name {name!r}; expected one of "
            f"{sorted(self._MODEL_ALIASES)}"
        )

    def load_pretrained_model(self, model_name, num_labels):
        """Load a sequence-classification model that also returns all hidden states."""
        # output_hidden_states=True is required by ``transform``, which reads
        # ``hidden_states[-1]`` from the model output.
        return AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )

    def load_tokenizer(self, model_name):
        """Load the tokenizer matching ``model_name``."""
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        """Return a ``(model, tokenizer)`` pair for ``model_name``."""
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        """Tokenize ``X`` into PyTorch tensors padded/truncated to ``self.max_length``."""
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def fit(self, lX, lY):
        """Fine-tune the transformer on ``lX``/``lY``. Does nothing if already fitted.

        Args:
            lX: mapping language -> documents (tokenized via ``_tokenize``).
            lY: mapping language -> label matrix. The size of its last axis is
                taken as the number of labels.

        Returns:
            self
        """
        if self.fitted:
            return self

        print("- fitting Textual Transformer View Generating Function")
        first_lang = list(lX.keys())[0]
        self.num_labels = lY[first_lang].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )

        # Held-out validation data for the Trainer. split=0.2 presumably
        # reserves 20% per language - see get_train_val_data.
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
        )

        train_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )
        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
        )
        trainer.train(
            train_dataloader=train_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            # self.fitted is still False here, so transform() returns raw
            # embeddings, which are what the projector is trained on.
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True
        return self

    def transform(self, lX):
        """Embed every document and group the embeddings by language.

        Requires the model/tokenizer created by ``fit``.

        Args:
            lX: mapping language -> documents.

        Returns:
            A ``defaultdict(list)`` mapping language -> list of first-token
            last-layer embeddings (numpy arrays), in input order. If
            ``probabilistic`` and fitted, it returns the posterior projector's
            output for that mapping instead.
        """
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,  # keep input order so per-language lists follow lX
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, batch_langs in dataloader:
                # NOTE(review): no attention_mask is passed, so padding tokens
                # take part in attention. Presumably this matches how the
                # Trainer feeds the model - confirm before changing only one side.
                hidden = self.model(input_ids.to(self.device)).hidden_states[-1]
                batch_embeddings = hidden[:, 0, :].cpu().numpy()
                for sample_embed, sample_lang in zip(batch_embeddings, batch_langs):
                    l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)

        return l_embeds

    def fit_transform(self, lX, lY):
        """Fit on ``lX``/``lY`` (unless already fitted), then transform ``lX``."""
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        """Pickle the whole VGF to ``models/vgfs/transformer/transformerGen_<model_id>.pkl``.

        Returns:
            self
        """
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "transformerGen"
        _basedir = join("models", "vgfs", "transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        # Unpickling runs arbitrary code: only load files produced by trusted runs.
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        """Return a human-readable, multi-line summary of the configuration."""
        return (
            "[Transformer VGF (t)]\n"
            f"- model_name: {self.model_name}\n"
            f"- max_length: {self.max_length}\n"
            f"- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n"
            f"- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n"
            f"- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )
|
|
|
|
|
|
class MultilingualDatasetTorch(Dataset):
    """Torch dataset concatenating tokenized documents from several languages.

    Items are ``(input_ids, label, lang)``. When ``split == "whole"`` no labels
    are loaded and items are ``(input_ids, lang)``.

    Args:
        lX: mapping language -> tokenizer output exposing ``input_ids``
            (assumed 2-D, one row per document).
        lY: mapping language -> label matrix whose rows align with
            ``lX[lang]``. Ignored (may be None) when ``split == "whole"``.
        split: e.g. ``"train"``, ``"val"`` or ``"whole"``. Only ``"whole"``
            changes behavior.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """Stack all languages' inputs (and labels) and record each row's language.

        Returns:
            self
        """
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            # Index lY by lX's keys (not lY.values()) so label rows stay aligned
            # with input rows even if the two dicts have different key orders.
            self.Y = torch.vstack([torch.Tensor(self.lY[lang]) for lang in self.lX])
        # One language tag per row of X, in the same order used to stack X.
        # Linear-time replacement for the previous sum(list_of_lists, []).
        self.langs = [
            lang
            for lang, data in self.lX.items()
            for _ in range(len(data.input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
|