average pooling for MT5ForSequenceClassification and standardized return data

Andrea Pedrotti 2023-03-15 11:46:53 +01:00
parent fece8d059e
commit 26aa0b327a
1 changed file with 35 additions and 11 deletions


@@ -11,6 +11,7 @@ import transformers
 from transformers import MT5EncoderModel
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers.modeling_outputs import ModelOutput
 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
@@ -23,18 +24,31 @@ class MT5ForSequenceClassification(nn.Module):
     def __init__(self, model_name, num_labels, output_hidden_states):
         super().__init__()
         self.output_hidden_states = output_hidden_states
         self.mt5encoder = MT5EncoderModel.from_pretrained(
-            model_name, output_hidden_states=output_hidden_states
+            model_name, output_hidden_states=True
         )
         self.dropout = nn.Dropout(0.1)
         self.linear = nn.Linear(512, num_labels)

-    def forward(self, input_ids, attn_mask):
-        # TODO: output hidden states
-        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
-        outputs = self.dropout(outputs[0])
-        outputs = self.linear(outputs)
-        return outputs
+    def forward(self, input_ids):
+        embed = self.mt5encoder(input_ids=input_ids)
+        pooled = torch.mean(embed.last_hidden_state, dim=1)
+        outputs = self.dropout(pooled)
+        logits = self.linear(outputs)
+        if self.output_hidden_states:
+            return ModelOutput(
+                logits=logits,
+                pooled=pooled,
+            )
+        return ModelOutput(logits=logits)
+
+    def save_pretrained(self, checkpoint_dir):
+        pass  # TODO: implement
+
+    def from_pretrained(self, checkpoint_dir):
+        # TODO: implement
+        return self


 class TextualTransformerGen(ViewGen, TransformerGen):
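Note on the new forward(): torch.mean over dim=1 averages every token position, padding included, because the attention mask is no longer passed to the encoder. A mask-aware mean pooling would look roughly like the sketch below; it is not part of this commit, and the helper name masked_mean_pool is made up for illustration.

import torch

def masked_mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)   # zero out padding before summing
    counts = mask.sum(dim=1).clamp(min=1e-9)         # number of real tokens per example
    return summed / counts                           # (batch, hidden)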
@@ -169,11 +183,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name="ReduceLROnPlateau",
+            scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader,  # TODO: debug setting
             epochs=self.epochs,
         )
@@ -204,6 +219,10 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                out = self.model(input_ids).hidden_states[-1]
-                batch_embeddings = out[:, 0, :].cpu().numpy()
+                # TODO: check this
+                if isinstance(self.model, MT5ForSequenceClassification):
+                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
+                else:
+                    out = self.model(input_ids).hidden_states[-1]
+                    batch_embeddings = out[:, 0, :].cpu().numpy()
                 _embeds.append((batch_embeddings, lang))
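Both branches above yield a (batch, hidden) array: the MT5 path reads the mean-pooled vector returned as ModelOutput.pooled, while the default path takes the first-token ([CLS]-style) embedding of the last hidden state. A standalone sketch of that dispatch, using hasattr instead of the isinstance check; the helper name extract_embeddings is hypothetical, not from the repository:

import torch

@torch.no_grad()
def extract_embeddings(model, input_ids):
    out = model(input_ids)
    if hasattr(out, "pooled"):            # MT5ForSequenceClassification path
        return out.pooled.cpu().numpy()   # mean over token positions
    hidden = out.hidden_states[-1]        # (batch, seq_len, hidden)
    return hidden[:, 0, :].cpu().numpy()  # first-token embedding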
@@ -235,6 +254,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self

+    def freeze_model(self):
+        # TODO: up to n layers? or all? avoid freezing the head, obv...
+        for param in self.model.parameters():
+            param.requires_grad = False
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
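The freeze_model() TODO asks whether to freeze everything or only the first n layers while keeping the classification head trainable. A possible sketch, assuming the wrapped model is the MT5ForSequenceClassification above (an mt5encoder with .encoder.block plus a linear head); the function name freeze_encoder_layers is hypothetical:

def freeze_encoder_layers(model, n_layers):
    # Freeze the token embeddings and the first n_layers MT5 encoder blocks;
    # the pooling + classification head (model.linear) stays trainable.
    for param in model.mt5encoder.shared.parameters():
        param.requires_grad = False
    for block in model.mt5encoder.encoder.block[:n_layers]:
        for param in block.parameters():
            param.requires_grad = False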