average pooling for MT5ForSequenceClassification and standardized return data
parent fece8d059e
commit 26aa0b327a
@@ -11,6 +11,7 @@ import transformers
 from transformers import MT5EncoderModel
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers.modeling_outputs import ModelOutput

 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen
@@ -23,18 +24,31 @@ class MT5ForSequenceClassification(nn.Module):
     def __init__(self, model_name, num_labels, output_hidden_states):
         super().__init__()

+        self.output_hidden_states = output_hidden_states
         self.mt5encoder = MT5EncoderModel.from_pretrained(
-            model_name, output_hidden_states=output_hidden_states
+            model_name, output_hidden_states=True
         )
         self.dropout = nn.Dropout(0.1)
         self.linear = nn.Linear(512, num_labels)

-    def forward(self, input_ids, attn_mask):
-        # TODO: output hidden states
-        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
-        outputs = self.dropout(outputs[0])
-        outputs = self.linear(outputs)
-        return outputs
+    def forward(self, input_ids):
+        embed = self.mt5encoder(input_ids=input_ids)
+        pooled = torch.mean(embed.last_hidden_state, dim=1)
+        outputs = self.dropout(pooled)
+        logits = self.linear(outputs)
+        if self.output_hidden_states:
+            return ModelOutput(
+                logits=logits,
+                pooled=pooled,
+            )
+        return ModelOutput(logits=logits)
+
+    def save_pretrained(self, checkpoint_dir):
+        pass # TODO: implement
+
+    def from_pretrained(self, checkpoint_dir):
+        # TODO: implement
+        return self


 class TextualTransformerGen(ViewGen, TransformerGen):
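Note on the new forward pass above: it mean-pools over every encoder position and no longer receives an attention mask, so padding tokens are included in the average. A masked mean is a common alternative; the helper below is only an illustrative sketch (its name and the 1 = real-token mask convention are assumptions, not part of this commit):

import torch

def masked_mean_pool(last_hidden_state, attention_mask):
    # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                  # sum over non-padding positions
    counts = mask.sum(dim=1).clamp(min=1e-9)                        # number of real tokens per sample
    return summed / counts                                          # (batch, hidden) pooled embedding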
@@ -169,11 +183,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name="ReduceLROnPlateau",
+            scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader, # TODO: debug setting
            epochs=self.epochs,
         )

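The hunk above switches the learning-rate scheduler off via scheduler_name=None. For context, a minimal stand-alone sketch of the plateau scheduler being disabled, written against plain PyTorch rather than the project's Trainer (the module, optimizer, and losses below are dummies):

import torch

model = torch.nn.Linear(512, 10)  # dummy module for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=2  # shrink lr when the monitored loss plateaus
)

for epoch in range(5):
    val_loss = 1.0 / (epoch + 1)  # dummy validation loss
    optimizer.step()
    scheduler.step(val_loss)      # plateau schedulers step on a monitored metric, not per batch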
@@ -204,6 +219,10 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
+                # TODO: check this
+                if isinstance(self.model, MT5ForSequenceClassification):
+                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
+                else:
                     out = self.model(input_ids).hidden_states[-1]
                     batch_embeddings = out[:, 0, :].cpu().numpy()
                 _embeds.append((batch_embeddings, lang))
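The isinstance branch above relies on the standardized return type: MT5ForSequenceClassification now returns a ModelOutput exposing .logits, plus .pooled when output_hidden_states is set. A usage sketch, assuming google/mt5-small as the checkpoint (hidden size 512, matching nn.Linear(512, num_labels) above) and dummy token ids:

import torch

model = MT5ForSequenceClassification(
    "google/mt5-small", num_labels=10, output_hidden_states=True
)
model.eval()

input_ids = torch.randint(0, 1000, (2, 16))  # dummy token ids, shape (batch, seq_len)
with torch.no_grad():
    out = model(input_ids)

logits = out.logits      # (2, 10) classification scores
embeddings = out.pooled  # (2, 512) mean-pooled encoder states, reused as document embeddings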
@@ -235,6 +254,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self

+    def freeze_model(self):
+        # TODO: up to n-layers? or all? avoid freezing head ovb...
+        for param in self.model.parameters():
+            param.requires_grad = False
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
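On the TODO in freeze_model above: freezing self.model.parameters() wholesale also freezes the classification head. One alternative, sketched here using only the attributes this commit defines (mt5encoder and linear), keeps the head trainable:

def freeze_encoder_only(model):
    # model is assumed to be the MT5ForSequenceClassification defined above
    for param in model.mt5encoder.parameters():  # freeze the pretrained encoder
        param.requires_grad = False
    for param in model.linear.parameters():      # keep the classification head trainable
        param.requires_grad = True

Freezing only the first n encoder blocks would additionally require indexing into the MT5 encoder's internal layer list, which is model-specific and left out of this sketch.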