gfun_multimodal/gfun/vgfs/textualTransformerGen.py

import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import MT5EncoderModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import ModelOutput
from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultilingualDatasetTorch
transformers.logging.set_verbosity_error()


class MT5ForSequenceClassification(nn.Module):
    """MT5 encoder with a mean-pooling + linear classification head on top."""

    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.output_hidden_states = output_hidden_states
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(512, num_labels)  # 512 = hidden size (d_model) of google/mt5-small

    def forward(self, input_ids):
        embed = self.mt5encoder(input_ids=input_ids)
        # mean-pool the encoder token representations into a single document embedding
        pooled = torch.mean(embed.last_hidden_state, dim=1)
        outputs = self.dropout(pooled)
        logits = self.linear(outputs)
        if self.output_hidden_states:
            return ModelOutput(
                logits=logits,
                pooled=pooled,
            )
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return self

    def from_pretrained(self, checkpoint_dir):
        checkpoint_dir += ".pt"
        self.load_state_dict(torch.load(checkpoint_dir))
        return self
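
# Minimal usage sketch of the classification head above (illustrative only, not part of the
# pipeline). It assumes the "google/mt5-small" checkpoint is available and uses an arbitrary
# number of labels:
#
#   clf = MT5ForSequenceClassification("google/mt5-small", num_labels=3, output_hidden_states=True)
#   tok = AutoTokenizer.from_pretrained("google/mt5-small")
#   batch = tok(["a short document"], return_tensors="pt", padding=True, truncation=True)
#   out = clf(batch["input_ids"])  # ModelOutput with .logits (and .pooled, since output_hidden_states=True)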

class TextualTransformerGen(ViewGen, TransformerGen):
    """View Generating Function that embeds multilingual documents with a transformer encoder."""

    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
        classification_type="multilabel",
        scheduler="ReduceLROnPlateau",
    ):
        super().__init__(
            self._validate_model_name(model_name),
            dataset_name,
            epochs=epochs,
            lr=lr,
            scheduler=scheduler,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
            max_length=max_length,
            print_steps=print_steps,
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.clf_type = classification_type
        self.fitted = False
        print(
            f"- init Textual Transformer VGF (model_name: {self.model_name}, device: {self.device})"
        )

    def _validate_model_name(self, model_name):
        if "bert" == model_name:
            return "bert-base-uncased"
        elif "mbert" == model_name:
            return "bert-base-multilingual-cased"
        elif "xlm-roberta" == model_name:
            return "xlm-roberta-base"
        elif "mt5" == model_name:
            return "google/mt5-small"
        else:
            raise NotImplementedError(f"Unsupported model name: {model_name}")

    def load_pretrained_model(self, model_name, num_labels):
        if model_name == "google/mt5-small":
            return MT5ForSequenceClassification(
                model_name, num_labels=num_labels, output_hidden_states=True
            )
        else:
            # model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO: hardcoded to pre-trained mbert
            model_name = "mbert-rai-multi-2000/checkpoint-1500"  # TODO: hardcoded to pre-trained mbert
            return AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels, output_hidden_states=True
            )

    def load_tokenizer(self, model_name):
        # model_name = "mbert-rai-multi-2000/checkpoint-1500"  # TODO: hardcoded to pre-trained mbert
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
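
    # _tokenize returns a BatchEncoding; only the input_ids are used downstream, e.g. (sketch):
    #   enc = self._tokenize(["a document", "another one"])
    #   enc["input_ids"].shape  # -> (2, self.max_length), padded/truncated to max_length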

    def fit(self, lX, lY):
        if self.fitted:
            return self
        print("- fitting Textual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )
        self.model.to(self.device)
        # NOTE: the fine-tuning step below is currently commented out
        # (see the hardcoded checkpoint in load_pretrained_model)
        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
        #     lX, lY, split=0.2, seed=42, modality="text"
        # )
        #
        # tra_dataloader = self.build_dataloader(
        #     tr_lX,
        #     tr_lY,
        #     processor_fn=self._tokenize,
        #     torchDataset=MultilingualDatasetTorch,
        #     batch_size=self.batch_size,
        #     split="train",
        #     shuffle=True,
        # )
        #
        # val_dataloader = self.build_dataloader(
        #     val_lX,
        #     val_lY,
        #     processor_fn=self._tokenize,
        #     torchDataset=MultilingualDatasetTorch,
        #     batch_size=self.batch_size_eval,
        #     split="val",
        #     shuffle=False,
        # )
        #
        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        #
        # trainer = Trainer(
        #     model=self.model,
        #     optimizer_name="adamW",
        #     lr=self.lr,
        #     device=self.device,
        #     loss_fn=torch.nn.CrossEntropyLoss(),
        #     print_steps=self.print_steps,
        #     evaluate_step=self.evaluate_step,
        #     patience=self.patience,
        #     experiment_name=experiment_name,
        #     checkpoint_path=os.path.join(
        #         "models",
        #         "vgfs",
        #         "trained_transformer",
        #         self._format_model_name(self.model_name),
        #     ),
        #     vgf_name="textual_trf",
        #     classification_type=self.clf_type,
        #     n_jobs=self.n_jobs,
        #     scheduler_name=self.scheduler,
        # )
        # trainer.train(
        #     train_dataloader=tra_dataloader,
        #     eval_dataloader=val_dataloader,
        #     epochs=self.epochs,
        # )

        if self.probabilistic:
            transformed = self.transform(lX)
            self.feature2posterior_projector.fit(transformed, lY)
        self.fitted = True
        return self

    def transform(self, lX):
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)
        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )
        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                if isinstance(self.model, MT5ForSequenceClassification):
                    # mean-pooled encoder output (see MT5ForSequenceClassification.forward)
                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                else:
                    # first-token ([CLS]) embedding of the last hidden layer
                    out = self.model(input_ids).hidden_states[-1]
                    batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))
        # regroup sample embeddings by language
        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)
        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}
        return l_embeds
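
    # Sketch of the expected transform() I/O (inferred from the code above, names illustrative):
    #   lX = {"en": {"text": ["doc 1", "doc 2"]}, "it": {"text": ["documento"]}}
    #   l_embeds = vgf.transform(lX)
    #   # -> {"en": array of shape (2, hidden_dim), "it": array of shape (1, hidden_dim)}
    # (or language-wise posterior probabilities when probabilistic=True and the VGF is fitted)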

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def freeze_model(self):
        # TODO: up to n-layers? or all? avoid freezing head ovb...
        for param in self.model.parameters():
            param.requires_grad = False

    def _format_model_name(self, model_name):
        if "mt5" in model_name:
            return "google-mt5"
        # NOTE: check "xlm-roberta" before "bert" ("roberta" also contains the substring "bert")
        elif "xlm-roberta" in model_name:
            return "xlm-roberta"
        elif "multilingual" in model_name:
            return "mbert"
        else:
            return model_name

    def get_config(self):
        c = super().get_config()
        return {"name": "textual-transformer VGF", "textual_trf": c}
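
# End-to-end usage sketch (hypothetical data, illustrative values only): fit the VGF on a
# language-indexed corpus and produce per-language document representations.
#
#   lX = {"en": {"text": ["some text", "more text"]}, "it": {"text": ["del testo", "altro testo"]}}
#   lY = {"en": np.zeros((2, num_labels)), "it": np.zeros((2, num_labels))}  # multilabel targets
#   vgf = TextualTransformerGen(model_name="mbert", dataset_name="example", device="cuda")
#   l_embeds = vgf.fit_transform(lX, lY)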