import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import MT5EncoderModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import ModelOutput

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen
from dataManager.torchDataset import MultilingualDatasetTorch

transformers.logging.set_verbosity_error()


class MT5ForSequenceClassification(nn.Module):
    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.output_hidden_states = output_hidden_states
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.dropout = nn.Dropout(0.1)
        self.linear = nn.Linear(512, num_labels)

    def forward(self, input_ids):
        embed = self.mt5encoder(input_ids=input_ids)
        pooled = torch.mean(embed.last_hidden_state, dim=1)
        outputs = self.dropout(pooled)
        logits = self.linear(outputs)
        if self.output_hidden_states:
            return ModelOutput(
                logits=logits,
                pooled=pooled,
            )
        return ModelOutput(logits=logits)

    def save_pretrained(self, checkpoint_dir):
        torch.save(self.state_dict(), checkpoint_dir + ".pt")
        return self

    def from_pretrained(self, checkpoint_dir):
        checkpoint_dir += ".pt"
        self.load_state_dict(torch.load(checkpoint_dir))
        return self
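
# Illustrative sketch (not part of the pipeline): a minimal smoke test of the
# mean-pooled MT5 classification head defined above. The inputs, the helper
# name, and num_labels=3 are hypothetical demo values; it assumes
# "google/mt5-small" (and the sentencepiece tokenizer it requires) can be
# downloaded.
def _mt5_head_smoke_test():
    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    model = MT5ForSequenceClassification(
        "google/mt5-small", num_labels=3, output_hidden_states=True
    )
    model.eval()
    batch = tokenizer(
        ["a short example", "ein kurzes Beispiel"],
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        out = model(batch["input_ids"])
    # Expected shapes: logits (2, 3); pooled (2, 512) for mt5-small.
    return out.logits.shape, out.pooled.shape
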
class TextualTransformerGen(ViewGen, TransformerGen):
    def __init__(
        self,
        model_name,
        dataset_name,
        epochs=10,
        lr=1e-5,
        batch_size=4,
        batch_size_eval=32,
        max_length=512,
        print_steps=50,
        device="cpu",
        probabilistic=False,
        n_jobs=-1,
        evaluate_step=10,
        verbose=False,
        patience=5,
        classification_type="multilabel",
        scheduler="ReduceLROnPlateau",
    ):
        super().__init__(
            self._validate_model_name(model_name),
            dataset_name,
            epochs=epochs,
            lr=lr,
            scheduler=scheduler,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
            max_length=max_length,
            print_steps=print_steps,
            n_jobs=n_jobs,
            verbose=verbose,
        )
        self.clf_type = classification_type
        self.fitted = False
        print(
            f"- init Textual TransformerModel [model_name: {self.model_name}, device: {self.device}]"
        )

    def _validate_model_name(self, model_name):
        if "bert" == model_name:
            return "bert-base-uncased"
        elif "mbert" == model_name:
            return "bert-base-multilingual-cased"
        elif "xlm-roberta" == model_name:
            return "xlm-roberta-base"
        elif "mt5" == model_name:
            return "google/mt5-small"
        else:
            raise NotImplementedError

    def load_pretrained_model(self, model_name, num_labels):
        if model_name == "google/mt5-small":
            return MT5ForSequenceClassification(
                model_name, num_labels=num_labels, output_hidden_states=True
            )
        else:
            # model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO: hardcoded to pre-trained mbert
            model_name = "mbert-rai-multi-2000/checkpoint-1500"  # TODO: hardcoded to pre-trained mbert
            return AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_labels, output_hidden_states=True
            )

    def load_tokenizer(self, model_name):
        # model_name = "mbert-rai-multi-2000/checkpoint-1500"  # TODO: hardcoded to pre-trained mbert
        return AutoTokenizer.from_pretrained(model_name)

    def init_model(self, model_name, num_labels):
        return self.load_pretrained_model(model_name, num_labels), self.load_tokenizer(
            model_name
        )

    def _tokenize(self, X):
        return self.tokenizer(
            X,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )

    def fit(self, lX, lY):
        if self.fitted:
            return self
        print("- fitting Textual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.tokenizer = self.init_model(
            self.model_name, num_labels=self.num_labels
        )
        self.model.to(self.device)
        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
        #     lX, lY, split=0.2, seed=42, modality="text"
        # )
        #
        # tra_dataloader = self.build_dataloader(
        #     tr_lX,
        #     tr_lY,
        #     processor_fn=self._tokenize,
        #     torchDataset=MultilingualDatasetTorch,
        #     batch_size=self.batch_size,
        #     split="train",
        #     shuffle=True,
        # )
        #
        # val_dataloader = self.build_dataloader(
        #     val_lX,
        #     val_lY,
        #     processor_fn=self._tokenize,
        #     torchDataset=MultilingualDatasetTorch,
        #     batch_size=self.batch_size_eval,
        #     split="val",
        #     shuffle=False,
        # )
        #
        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        #
        # trainer = Trainer(
        #     model=self.model,
        #     optimizer_name="adamW",
        #     lr=self.lr,
        #     device=self.device,
        #     loss_fn=torch.nn.CrossEntropyLoss(),
        #     print_steps=self.print_steps,
        #     evaluate_step=self.evaluate_step,
        #     patience=self.patience,
        #     experiment_name=experiment_name,
        #     checkpoint_path=os.path.join(
        #         "models",
        #         "vgfs",
        #         "trained_transformer",
        #         self._format_model_name(self.model_name),
        #     ),
        #     vgf_name="textual_trf",
        #     classification_type=self.clf_type,
        #     n_jobs=self.n_jobs,
        #     scheduler_name=self.scheduler,
        # )
        # trainer.train(
        #     train_dataloader=tra_dataloader,
        #     eval_dataloader=val_dataloader,
        #     epochs=self.epochs,
        # )

        if self.probabilistic:
            transformed = self.transform(lX)
            self.feature2posterior_projector.fit(transformed, lY)

        self.fitted = True
        return self

    def transform(self, lX):
        # forcing to only text modality
        lX = {lang: data["text"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self._tokenize,
            torchDataset=MultilingualDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for input_ids, lang in dataloader:
                input_ids = input_ids.to(self.device)
                if isinstance(self.model, MT5ForSequenceClassification):
                    # MT5 head: use the mean-pooled encoder representation
                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                else:
                    # BERT-style models: use the first ([CLS]) token of the last hidden layer
                    out = self.model(input_ids).hidden_states[-1]
                    batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, lang))

        for embed, lang in _embeds:
            for sample_embed, sample_lang in zip(embed, lang):
                l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "textualTransformerGen"
        _basedir = join("models", "vgfs", "textual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self
    def freeze_model(self):
        # TODO: up to n layers? or all? avoid freezing the head, obviously...
        for param in self.model.parameters():
            param.requires_grad = False

    def _format_model_name(self, model_name):
        if "mt5" in model_name:
            return "google-mt5"
        elif "bert" in model_name:
            if "multilingual" in model_name:
                return "mbert"
            elif "xlm-roberta" in model_name:
                return "xlm-roberta"
            else:
                return model_name

    def get_config(self):
        c = super().get_config()
        return {"name": "textual-transformer VGF", "textual_trf": c}
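
# Illustrative usage sketch, kept behind a __main__ guard so it never runs on
# import. The lX/lY structures are mocked to match what fit()/transform()
# expect ({lang: {"text": [...]}} and {lang: one-hot label matrix}); the toy
# documents, label matrices, and dataset_name are invented for demonstration.
# Note that load_pretrained_model() is currently hard-coded to a local
# fine-tuned mBERT checkpoint (see the TODO above), so this only runs where
# that checkpoint is available.
if __name__ == "__main__":
    lX = {
        "en": {"text": ["a first english document", "a second english document"]},
        "it": {"text": ["un primo documento italiano", "un secondo documento"]},
    }
    lY = {
        "en": np.array([[1, 0], [0, 1]]),
        "it": np.array([[0, 1], [1, 0]]),
    }

    vgf = TextualTransformerGen(
        model_name="mbert",
        dataset_name="toy",
        device="cuda" if torch.cuda.is_available() else "cpu",
        batch_size_eval=2,
    )
    l_embeds = vgf.fit_transform(lX, lY)
    for lang, embeds in l_embeds.items():
        print(lang, embeds.shape)  # one (n_docs, hidden_size) matrix per language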