"""Textual transformer View Generating Function (VGF) for gFun.

Fine-tunes a Hugging Face sequence-classification transformer on the "text"
modality of a multilingual dataset and exposes first-token embeddings from
the last hidden layer as the view (optionally projected to posteriors).
"""

import os

# Set before transformers is imported; presumably so the HF `tokenizers`
# library sees it when it initialises.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from collections import defaultdict

import numpy as np
import torch
import transformers
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()


# TODO: add support to loggers
class TextualTransformerGen(ViewGen, TransformerGen):
    """VGF that fine-tunes a text transformer and embeds documents with it.

    ``model_name`` is a short alias ("bert", "mbert" or "xlm") that is mapped
    to a Hugging Face checkpoint name. All other constructor arguments are
    forwarded unchanged to the ``TransformerGen``/``ViewGen`` base
    constructor, which is not visible here.
    """

    # Short alias accepted by __init__ -> Hugging Face checkpoint name.
    _MODEL_ALIASES = {
        "bert": "bert-base-uncased",
        "mbert": "bert-base-multilingual-uncased",
        "xlm": "xlm-roberta-base",
    }

    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
    ):
        super().__init__(
            self._validate_model_name(model_name),
            dataset_name,
            epochs,
            lr,
            batch_size,
            batch_size_eval,
            max_length,
            print_steps,
            device,
            probabilistic,
            n_jobs,
            evaluate_step,
            verbose,
            patience,
        )
        # fit() is a no-op once this is True; transform() also uses it to
        # decide whether to post-process its output.
        self.fitted = False
        print(
            f"- init Textual TransformerModel model_name: {self.model_name}, device: {self.device}]"
        )

    def _validate_model_name(self, model_name):
        """Return the Hugging Face checkpoint name for a short model alias.

        Raises:
            NotImplementedError: if ``model_name`` is not one of the known
                aliases.
        """
        try:
            return self._MODEL_ALIASES[model_name]
        except KeyError:
            raise NotImplementedError(
                f"Unsupported model_name {model_name!r}; expected one of "
                f"{sorted(self._MODEL_ALIASES)}"
            ) from None

    def load_pretrained_model(self, model_name, num_labels):
        """Load a sequence-classification head with ``num_labels`` outputs.

        ``output_hidden_states=True`` is required because ``transform`` reads
        ``hidden_states[-1]`` from the model output.
        """
        return AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )

    def load_tokenizer(self, model_name):
        """Load the tokenizer that matches ``model_name``."""
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        """Return a fresh ``(model, tokenizer)`` pair for ``model_name``."""
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        """Tokenize ``X`` into PyTorch tensors.

        Every sequence is padded or truncated to exactly ``self.max_length``
        tokens.
        """
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def fit(self, lX, lY):
        """Fine-tune the transformer on ``lX``/``lY`` and return ``self``.

        Does nothing and returns ``self`` if the VGF is already fitted.

        Args:
            lX: dict mapping language -> per-language input data. The text
                modality is selected via ``get_train_val_data(...,
                modality="text")``, an inherited helper not shown here.
            lY: dict mapping language -> label array. The number of labels is
                read from the last axis of the first language's array.

        Returns:
            self
        """
        if self.fitted:
            return self
        print("- fitting Textual Transformer View Generating Function")
        first_lang = list(lX.keys())[0]
        self.num_labels = lY[first_lang].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )

        # Hold out 20% of the data (fixed seed) for early-stopping evaluation.
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="text"
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = (
            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        )
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
        )
        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            # self.fitted is still False here, so transform() returns raw
            # per-language lists of embeddings (no projection applied).
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True
        return self

    def transform(self, lX):
        """Embed each document with the fine-tuned model, grouped by language.

        The embedding of a document is the first-token vector of the last
        hidden layer.

        Args:
            lX: dict mapping language -> data. Only ``data["text"]`` is used.

        Returns:
            - If fitted and probabilistic: the output of
              ``feature2posterior_projector.transform``.
            - If fitted and not probabilistic: a dict mapping language to a
              ``np.ndarray`` that stacks that language's embeddings.
            - If not fitted: a ``defaultdict`` mapping language to a list of
              per-document embedding arrays.
        """
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, batch_langs in dataloader:
                input_ids = input_ids.to(self.device)
                # NOTE(review): only input_ids reach the model. _tokenize pads
                # to max_length, and no attention_mask is forwarded, so padding
                # positions are presumably attended to. Confirm this is
                # intended.
                last_hidden = self.model(input_ids).hidden_states[-1]
                batch_embeddings = last_hidden[:, 0, :].cpu().numpy()
                for sample_embed, sample_lang in zip(batch_embeddings, batch_langs):
                    l_embeds[sample_lang].append(sample_embed)

        if self.fitted:
            if self.probabilistic:
                return self.feature2posterior_projector.transform(l_embeds)
            return {lang: np.array(embeds) for lang, embeds in l_embeds.items()}
        return l_embeds

    def fit_transform(self, lX, lY):
        """Fit on ``lX``/``lY`` (no-op if already fitted), then transform ``lX``."""
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        """Pickle this VGF.

        The file is written to
        ``models/vgfs/textual_transformer/textualTransformerGen_<model_id>.pkl``.
        The directory is created if it is missing.

        Returns:
            self
        """
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        """Return a multi-line summary of the VGF's hyper-parameters."""
        description = (
            "[Transformer VGF (t)]\n"
            f"- model_name: {self.model_name}\n"
            f"- max_length: {self.max_length}\n"
            f"- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n"
            f"- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n"
            f"- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )
        return description


class MultilingualDatasetTorch(Dataset):
    """Torch ``Dataset`` that concatenates per-language inputs into one index space.

    Args:
        lX: dict mapping language -> an object exposing ``.input_ids``.
            Presumably this is the tokenizer output with a 2-D tensor of
            shape (n_docs, max_length) — TODO confirm.
        lY: dict mapping language -> label matrix. It must contain every
            language in ``lX``. It is ignored when ``split == "whole"``.
        split: with ``"whole"``, items are ``(input_ids, lang)``. With any
            other value, items are ``(input_ids, labels, lang)``.
    """

    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        """Stack inputs (and labels) across languages and record each row's language."""
        langs = list(self.lX)
        self.X = torch.vstack([self.lX[lang].input_ids for lang in langs])
        if self.split != "whole":
            # Look up labels in lX's language order. Iterating lY on its own
            # would misalign rows of X and Y whenever the key orders differed.
            self.Y = torch.vstack([torch.Tensor(self.lY[lang]) for lang in langs])
        # One language tag per row of X, in the same order.
        self.langs = [
            lang for lang in langs for _ in range(len(self.lX[lang].input_ids))
        ]
        return self

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]