average pooling for MT5ForSequenceClassification and standardized return data
This commit is contained in:
parent fece8d059e
commit 26aa0b327a
@@ -11,6 +11,7 @@ import transformers
 from transformers import MT5EncoderModel
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers.modeling_outputs import ModelOutput

 from gfun.vgfs.commons import Trainer
 from gfun.vgfs.transformerGen import TransformerGen

@@ -23,18 +24,31 @@ class MT5ForSequenceClassification(nn.Module):
     def __init__(self, model_name, num_labels, output_hidden_states):
         super().__init__()

+        self.output_hidden_states = output_hidden_states
         self.mt5encoder = MT5EncoderModel.from_pretrained(
-            model_name, output_hidden_states=output_hidden_states
+            model_name, output_hidden_states=True
         )
         self.dropout = nn.Dropout(0.1)
         self.linear = nn.Linear(512, num_labels)

-    def forward(self, input_ids, attn_mask):
-        # TODO: output hidden states
-        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
-        outputs = self.dropout(outputs[0])
-        outputs = self.linear(outputs)
-        return outputs
+    def forward(self, input_ids):
+        embed = self.mt5encoder(input_ids=input_ids)
+        pooled = torch.mean(embed.last_hidden_state, dim=1)
+        outputs = self.dropout(pooled)
+        logits = self.linear(outputs)
+        if self.output_hidden_states:
+            return ModelOutput(
+                logits=logits,
+                pooled=pooled,
+            )
+        return ModelOutput(logits=logits)

+    def save_pretrained(self, checkpoint_dir):
+        pass  # TODO: implement
+
+    def from_pretrained(self, checkpoint_dir):
+        # TODO: implement
+        return self
+

 class TextualTransformerGen(ViewGen, TransformerGen):

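The new forward() averages the encoder's last hidden state over the whole sequence (torch.mean over dim=1). Because the attention mask was dropped from the signature, padding positions are included in that average; a mask-aware variant of the pooling would look like the sketch below (masked_mean_pool is a hypothetical helper, not part of this commit):

    import torch

    def masked_mean_pool(last_hidden_state, attention_mask):
        # last_hidden_state: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
        mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
        summed = (last_hidden_state * mask).sum(dim=1)   # sum over real tokens only
        counts = mask.sum(dim=1).clamp(min=1e-9)         # number of real tokens per example
        return summed / counts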
@@ -169,11 +183,12 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             vgf_name="textual_trf",
             classification_type=self.clf_type,
             n_jobs=self.n_jobs,
-            scheduler_name="ReduceLROnPlateau",
+            # scheduler_name="ReduceLROnPlateau",
+            scheduler_name=None,
         )
         trainer.train(
             train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader,  # TODO: debug setting
             epochs=self.epochs,
         )

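The Trainer is now built with scheduler_name=None, so no learning-rate schedule is applied in this run. For reference, the disabled ReduceLROnPlateau behaviour in plain PyTorch looks roughly like the sketch below; the model, optimizer, and loss here are placeholders for illustration, not the Trainer's actual internals:

    import torch

    model = torch.nn.Linear(512, 2)                      # stand-in model for illustration
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=2    # halve the LR after 2 stagnant evals
    )

    for epoch in range(5):
        loss = model(torch.randn(8, 512)).pow(2).mean()  # placeholder training loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())                      # LR drops only when the metric plateaus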
@@ -204,6 +219,10 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
+                # TODO: check this
+                if isinstance(self.model, MT5ForSequenceClassification):
+                    batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
+                else:
                     out = self.model(input_ids).hidden_states[-1]
                     batch_embeddings = out[:, 0, :].cpu().numpy()
                 _embeds.append((batch_embeddings, lang))

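The extraction loop now branches on the model class and reads the pooled field from the ModelOutput returned by MT5ForSequenceClassification, instead of slicing the first-token representation out of hidden_states. A minimal usage sketch of the standardized return, assuming the class defined above is in scope and using "google/mt5-small" purely as an example checkpoint:

    import torch
    from transformers import AutoTokenizer

    # assumes MT5ForSequenceClassification from this commit is importable
    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    model = MT5ForSequenceClassification(
        "google/mt5-small", num_labels=2, output_hidden_states=True
    )
    model.eval()

    batch = tokenizer(["a first example", "another example"], padding=True, return_tensors="pt")
    with torch.no_grad():
        out = model(batch["input_ids"])
    print(out.logits.shape)  # (2, 2): one score per label
    print(out.pooled.shape)  # (2, 512): mean-pooled encoder states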
@@ -235,6 +254,11 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             pickle.dump(self, f)
         return self

+    def freeze_model(self):
+        # TODO: up to n-layers? or all? avoid freezing head ovb...
+        for param in self.model.parameters():
+            param.requires_grad = False
+
     def __str__(self):
         str = f"[Transformer VGF (t)]\n- model_name: {self.model_name}\n- max_length: {self.max_length}\n- batch_size: {self.batch_size}\n- batch_size_eval: {self.batch_size_eval}\n- lr: {self.lr}\n- epochs: {self.epochs}\n- device: {self.device}\n- print_steps: {self.print_steps}\n- evaluate_step: {self.evaluate_step}\n- patience: {self.patience}\n- probabilistic: {self.probabilistic}\n"
         return str
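freeze_model() currently freezes every parameter of self.model, which would also lock the classification head; the TODO asks about freezing only the first n layers instead. One possible variant, assuming the mt5encoder and linear attribute names from this commit (freeze_encoder_layers is a hypothetical helper, not part of the codebase):

    def freeze_encoder_layers(model, n_layers):
        # Freeze the shared embeddings and the first n_layers encoder blocks of
        # model.mt5encoder; the remaining blocks and the classification head
        # (model.linear) stay trainable.
        for param in model.mt5encoder.shared.parameters():
            param.requires_grad = False
        for block in model.mt5encoder.encoder.block[:n_layers]:
            for param in block.parameters():
                param.requires_grad = False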