implemented MT5ForSequenceClassification

This commit is contained in:
Andrea Pedrotti 2023-03-14 11:53:50 +01:00
parent a3e183d7fc
commit 5e41b4517a
1 changed files with 30 additions and 3 deletions

View File

@ -6,7 +6,9 @@ from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import MT5EncoderModel
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
@ -17,6 +19,24 @@ from gfun.vgfs.viewGen import ViewGen
transformers.logging.set_verbosity_error()
class MT5ForSequenceClassification(nn.Module):
    """Classification head stacked on a pretrained mT5 encoder.

    The module wraps ``MT5EncoderModel`` and feeds its first output through
    a dropout layer and a linear projection to ``num_labels`` values.

    Args:
        model_name: Checkpoint name or path handed to
            ``MT5EncoderModel.from_pretrained``.
        num_labels: Output size of the linear classification layer.
        output_hidden_states: Forwarded unchanged to ``from_pretrained`` as
            the ``output_hidden_states`` option.
    """

    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=output_hidden_states
        )
        self.dropout = nn.Dropout(0.1)
        # Size the head from the loaded checkpoint instead of hard-coding
        # 512: that value only matches the encoder width of mt5-small, so
        # any other mT5 checkpoint would fail with a shape mismatch.
        encoder_width = self.mt5encoder.config.d_model
        self.linear = nn.Linear(encoder_width, num_labels)

    def forward(self, input_ids, attn_mask):
        """Run the encoder, then dropout and the linear head.

        Args:
            input_ids: Passed to the encoder as ``input_ids``.
            attn_mask: Passed to the encoder as ``attention_mask``.

        Returns:
            The linear layer applied to the (dropped-out) first element of
            the encoder output.
        """
        # TODO: output hidden states
        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
        # NOTE(review): outputs[0] is presumably the encoder's last hidden
        # state, (batch, seq_len, d_model); if so the head is applied per
        # token and no pooling happens here -- confirm this is intended.
        outputs = self.dropout(outputs[0])
        outputs = self.linear(outputs)
        return outputs
class TextualTransformerGen(ViewGen, TransformerGen):
def __init__(
self,
@ -65,13 +85,20 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return "bert-base-multilingual-uncased"
elif "xlm" == model_name:
return "xlm-roberta-base"
elif "mt5" == model_name:
return "google/mt5-small"
else:
raise NotImplementedError
def load_pretrained_model(self, model_name, num_labels):
    """Instantiate the pretrained classifier for ``model_name``.

    Args:
        model_name: Checkpoint identifier to load.
        num_labels: Number of output labels for the classification head.

    Returns:
        An ``MT5ForSequenceClassification`` when ``model_name`` is
        ``"google/mt5-small"``; otherwise the model returned by
        ``AutoModelForSequenceClassification.from_pretrained``. Both are
        created with ``output_hidden_states=True``.
    """
    # The mT5 check must come first: an unconditional return of the
    # Auto* model ahead of it would leave this branch unreachable.
    if model_name == "google/mt5-small":
        return MT5ForSequenceClassification(
            model_name, num_labels=num_labels, output_hidden_states=True
        )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels, output_hidden_states=True
    )
def load_tokenizer(self, model_name):
    """Return the pretrained tokenizer that ``AutoTokenizer`` resolves for ``model_name``."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer