diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 70e9909..ed5a92c 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -11,6 +11,7 @@ import transformers
 from transformers import MT5EncoderModel
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers.modeling_outputs import ModelOutput
 
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
@@ -23,18 +24,31 @@ class MT5ForSequenceClassification(nn.Module):
 
     def __init__(self, model_name, num_labels, output_hidden_states):
         super().__init__()
+        self.output_hidden_states = output_hidden_states
         self.mt5encoder = MT5EncoderModel.from_pretrained(
-            model_name, output_hidden_states=output_hidden_states
+            model_name, output_hidden_states=True
         )
         self.dropout = nn.Dropout(0.1)
         self.linear = nn.Linear(512, num_labels)
 
-    def forward(self, input_ids, attn_mask):
-        # TODO: output hidden states
-        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
-        outputs = self.dropout(outputs[0])
-        outputs = self.linear(outputs)
-        return outputs
+    def forward(self, input_ids):
+        embed = self.mt5encoder(input_ids=input_ids)
+        pooled = torch.mean(embed.last_hidden_state, dim=1)
+        outputs = self.dropout(pooled)
+        logits = self.linear(outputs)
+        if self.output_hidden_states:
+            return ModelOutput(
+                logits=logits,
+                pooled=pooled,
+            )
+        return ModelOutput(logits=logits)
+
+    def save_pretrained(self, checkpoint_dir):
+        pass # TODO: implement
+
+    def from_pretrained(self, checkpoint_dir):
+        # TODO: implement
+        return self
 
 
 class TextualTransformerGen(ViewGen, TransformerGen):
@@ -169,11 +183,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name="ReduceLROnPlateau",
+            scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader, # TODO: debug setting
             epochs=self.epochs,
         )
 
@@ -204,8 +219,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                out = self.model(input_ids).hidden_states[-1]
-                batch_embeddings = out[:, 0, :].cpu().numpy()
+                # TODO: check this
+                if isinstance(self.model, MT5ForSequenceClassification):
+                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
+                else:
+                    out = self.model(input_ids).hidden_states[-1]
+                    batch_embeddings = out[:, 0, :].cpu().numpy()
                 _embeds.append((batch_embeddings, lang))
 
         for embed, lang in _embeds:
@@ -235,6 +254,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self
 
+    def freeze_model(self):
+        # TODO: up to n-layers? or all? avoid freezing head ovb...
+        for param in self.model.parameters():
+            param.requires_grad = False
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
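
Note on the two stubs added above: `save_pretrained` and `from_pretrained` are left as TODOs in this patch. Below is a minimal sketch of one way they could be filled in with a plain `state_dict` checkpoint; the free-standing function form, the `WEIGHTS_NAME` constant, and the toy `nn.Linear` used to exercise the round trip are illustrative assumptions, not part of the patch.

```python
import os

import torch
import torch.nn as nn

WEIGHTS_NAME = "model_state.pt"  # hypothetical file name, not defined in the patch


def save_pretrained(model: nn.Module, checkpoint_dir: str) -> None:
    """Persist all weights (encoder + classification head) as a single state_dict."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, WEIGHTS_NAME))


def from_pretrained(model: nn.Module, checkpoint_dir: str) -> nn.Module:
    """Restore weights in place and return the model, mirroring the stub's return-self contract."""
    state_dict = torch.load(os.path.join(checkpoint_dir, WEIGHTS_NAME), map_location="cpu")
    model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    # Toy stand-in for MT5ForSequenceClassification, just to check the save/load round trip.
    model = nn.Linear(512, 3)
    save_pretrained(model, "checkpoints/toy")
    restored = from_pretrained(nn.Linear(512, 3), "checkpoints/toy")
```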