import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import transformers
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    MT5EncoderModel,
)
from transformers.modeling_outputs import ModelOutput

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()


class MT5ForSequenceClassification(nn.Module):
    """Minimal classification head on top of the mT5 encoder: the last hidden
    states are mean-pooled over the sequence, then passed through dropout and
    a linear layer."""

    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.output_hidden_states = output_hidden_states
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(512, num_labels)  # 512 = d_model of google/mt5-small

    def forward(self, input_ids):
        embed = self.mt5encoder(input_ids=input_ids)
        pooled = torch.mean(embed.last_hidden_state, dim=1)
        outputs = self.dropout(pooled)
        logits = self.linear(outputs)
        if self.output_hidden_states:
            return ModelOutput(
                logits=logits,
                pooled=pooled,
            )
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
        pass  # TODO: implement

    def from_pretrained(self, checkpoint_dir):
        # TODO: implement
        return self


class TextualTransformerGen(ViewGen, TransformerGen):
    """View-generating function based on a (multilingual) transformer encoder
    that is fine-tuned on the labelled data and then used to embed documents."""

    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
        classification_type="multilabel",
    ):
        super().__init__(
            self._validate_model_name(model_name),
            dataset_name,
            epochs,
            lr,
            batch_size,
            batch_size_eval,
            max_length,
            print_steps,
            device,
            probabilistic,
            n_jobs,
            evaluate_step,
            verbose,
            patience,
        )
        self.clf_type = classification_type
        self.fitted = False
        print(
            f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}"
        )

    def _validate_model_name(self, model_name):
        # Map short aliases to the corresponding Hugging Face model identifiers
        if model_name == "bert":
            return "bert-base-uncased"
        elif model_name == "mbert":
            return "bert-base-multilingual-uncased"
        elif model_name == "xlm":
            return "xlm-roberta-base"
        elif model_name == "mt5":
            return "google/mt5-small"
        else:
            raise NotImplementedError

    def load_pretrained_model(self, model_name, num_labels):
        if model_name == "google/mt5-small":
            return MT5ForSequenceClassification(
                model_name, num_labels=num_labels, output_hidden_states=True
            )
        else:
            return AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels, output_hidden_states=True
            )

    def load_tokenizer(self, model_name):
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def fit(self, lX, lY):
        if self.fitted:
            return self
        print("- fitting Textual Transformer View Generating Function")

        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="text"
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = (
            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        )

        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
            vgf_name="textual_trf",
            classification_type=self.clf_type,
            n_jobs=self.n_jobs,
            # scheduler_name="ReduceLROnPlateau",
            scheduler_name=None,
        )

        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,  # TODO: debug setting
            epochs=self.epochs,
        )

        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True
        return self

    def transform(self, lX):
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)  # TODO: check this
                if isinstance(self.model, MT5ForSequenceClassification):
                    # the mT5 head exposes the mean-pooled encoder states directly
                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                else:
                    # use the [CLS] token of the last hidden layer as document embedding
                    out = self.model(input_ids).hidden_states[-1]
                    batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))

        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def freeze_model(self):
        # TODO: up to n layers? or all? avoid freezing the classification head, obviously...
        for param in self.model.parameters():
            param.requires_grad = False

    def __str__(self):
        return (
            f"[Transformer VGF (t)]\n"
            f"- model_name: {self.model_name}\n"
            f"- max_length: {self.max_length}\n"
            f"- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n"
            f"- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n"
            f"- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )


class MultilingualDatasetTorch(Dataset):
    """Torch Dataset stacking the tokenized documents of all languages and
    keeping track of the language of each sample."""

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack([data.input_ids for data in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language tag per stacked sample, in the same order as self.X
        self.langs = [
            lang for lang, data in self.lX.items() for _ in range(len(data.input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]
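

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). The
# structure of lX/lY shown in the comments below — per-language dicts of
# {"text": [...]} and binary label matrices — is an assumption inferred from
# fit()/transform() above; real runs should use the project's dataset loaders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build the VGF for multilingual BERT and print its configuration.
    vgf = TextualTransformerGen(
        model_name="mbert", dataset_name="toy-example", epochs=1, device="cpu"
    )
    print(vgf)

    # A full run would fine-tune the encoder and return per-language embeddings:
    # lX = {"en": {"text": ["a short document", ...]},
    #       "it": {"text": ["un breve documento", ...]}}
    # lY = {"en": np.array([[1, 0], ...]), "it": np.array([[0, 1], ...])}
    # l_embeds = vgf.fit_transform(lX, lY)  # {"en": (n_docs, hidden_size), ...}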