gfun_multimodal/gfun/vgfs/textualTransformerGen.py

import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import transformers
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    MT5EncoderModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()

class MT5ForSequenceClassification(nn.Module):
    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=output_hidden_states
        )
        self.dropout = nn.Dropout(0.1)
        # 512 is the encoder hidden size of google/mt5-small
        self.linear = nn.Linear(512, num_labels)

    def forward(self, input_ids, attn_mask=None):
        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
        # Pool the first-token representation (mirroring the pooling used in
        # TextualTransformerGen.transform) before classifying
        pooled = self.dropout(outputs.last_hidden_state[:, 0, :])
        logits = self.linear(pooled)
        # Expose .logits and .hidden_states so callers can treat this model
        # like an AutoModelForSequenceClassification
        return SequenceClassifierOutput(
            logits=logits, hidden_states=outputs.hidden_states
        )

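# Shape sketch (an assumption based on google/mt5-small, whose encoder hidden
# size is 512): for input_ids of shape (batch, seq_len), forward returns
# logits of shape (batch, num_labels), and hidden_states[-1] has shape
# (batch, seq_len, 512).
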
class TextualTransformerGen(ViewGen, TransformerGen):
    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
        classification_type="multilabel",
    ):
        super().__init__(
            self._validate_model_name(model_name),
            dataset_name,
            epochs,
            lr,
            batch_size,
            batch_size_eval,
            max_length,
            print_steps,
            device,
            probabilistic,
            n_jobs,
            evaluate_step,
            verbose,
            patience,
        )
        self.clf_type = classification_type
        self.fitted = False
        print(
            f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}"
        )

    def _validate_model_name(self, model_name):
        if model_name == "bert":
            return "bert-base-uncased"
        elif model_name == "mbert":
            return "bert-base-multilingual-uncased"
        elif model_name == "xlm":
            return "xlm-roberta-base"
        elif model_name == "mt5":
            return "google/mt5-small"
        else:
            raise NotImplementedError(f"Unsupported model alias: {model_name}")

    def load_pretrained_model(self, model_name, num_labels):
        if model_name == "google/mt5-small":
            return MT5ForSequenceClassification(
                model_name, num_labels=num_labels, output_hidden_states=True
            )
        else:
            return AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels, output_hidden_states=True
            )

    def load_tokenizer(self, model_name):
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

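    # Tokenization sketch: for X = ["a doc", "another doc"] this returns a
    # BatchEncoding whose .input_ids and .attention_mask both have shape
    # (2, self.max_length); padding="max_length" gives every language the
    # same fixed width, which the torch.vstack in MultilingualDatasetTorch
    # relies on.
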
    def fit(self, lX, lY):
        if self.fitted:
            return self
        print("- fitting Textual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="text"
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )
        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = (
            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        )
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
            vgf_name="textual_trf",
            classification_type=self.clf_type,
            n_jobs=self.n_jobs,
            scheduler_name="ReduceLROnPlateau",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)
        self.fitted = True
        return self

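    # Illustrative input format (an assumption inferred from the calls above,
    # not taken from the pipeline's docs):
    #   lX = {"en": {"text": ["a doc", ...]}, "it": {"text": ["un documento", ...]}}
    #   lY = {"en": binary label matrix of shape (n_en_docs, n_labels), ...}
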
    def transform(self, lX):
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)
        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )
        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                # take the first-token embedding of the last hidden layer as
                # the document representation
                out = self.model(input_ids).hidden_states[-1]
                batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)
        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds

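    # Output sketch: a {lang: np.ndarray} mapping with one row per document,
    # e.g. (n_docs, 768) for mBERT or (n_docs, 512) for mt5-small; with
    # probabilistic=True and a fitted projector, rows are class posteriors
    # instead of raw embeddings.
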
    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        return (
            f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n"
            f"- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
        )

class MultilingualDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # flatten [[lang] * n_lang_docs, ...] into one language tag per sample
        self.langs = sum(
            [[lang] * len(data.input_ids) for lang, data in self.lX.items()], []
        )
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
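
# Minimal usage sketch (hypothetical data; the real pipeline builds lX and lY
# from its multilingual dataset loaders):
#
#     vgf = TextualTransformerGen(model_name="mbert", dataset_name="demo", device="cpu")
#     l_embeds = vgf.fit_transform(lX, lY)
#     # l_embeds: {lang: np.ndarray of shape (n_docs, hidden_size)}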