implemented MT5ForSequenceClassification

This commit is contained in:
Andrea Pedrotti 2023-03-14 11:53:50 +01:00
parent a3e183d7fc
commit 5e41b4517a
1 changed files with 30 additions and 3 deletions

View File

@ -6,7 +6,9 @@ from collections import defaultdict
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
import transformers import transformers
from transformers import MT5EncoderModel
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer from transformers import AutoModelForSequenceClassification, AutoTokenizer
@ -17,6 +19,24 @@ from gfun.vgfs.viewGen import ViewGen
transformers.logging.set_verbosity_error() transformers.logging.set_verbosity_error()
class MT5ForSequenceClassification(nn.Module):
    """Classification head on top of a pretrained mT5 encoder.

    The module chains three stages: the ``MT5EncoderModel`` loaded from
    ``model_name``, a dropout layer (p=0.1), and a linear projection onto
    ``num_labels`` outputs.

    Args:
        model_name: Checkpoint identifier or path forwarded to
            ``MT5EncoderModel.from_pretrained``.
        num_labels: Output dimension of the final linear layer.
        output_hidden_states: Forwarded to ``from_pretrained``; the hidden
            states are not yet returned by ``forward`` (see TODO there).
    """

    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=output_hidden_states
        )
        self.dropout = nn.Dropout(0.1)
        # Size the head from the loaded checkpoint instead of hard-coding 512,
        # which is only correct for checkpoints whose encoder width is 512
        # (e.g. google/mt5-small); larger mT5 variants would otherwise fail
        # with a shape mismatch in the linear layer.
        encoder_dim = self.mt5encoder.config.d_model
        self.linear = nn.Linear(encoder_dim, num_labels)

    def forward(self, input_ids, attn_mask):
        """Run the encoder, apply dropout and project onto the label space.

        Args:
            input_ids: Token ids, passed to the encoder as ``input_ids``.
            attn_mask: Attention mask, passed to the encoder as
                ``attention_mask``.

        Returns:
            The linear layer's output computed from the first element of the
            encoder output.
        """
        # TODO: output hidden states
        encoded = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
        # NOTE(review): element 0 is presumably the last hidden state with a
        # sequence axis; no pooling is applied here, so the result keeps that
        # axis — confirm callers expect per-token outputs.
        hidden = self.dropout(encoded[0])
        return self.linear(hidden)
class TextualTransformerGen(ViewGen, TransformerGen): class TextualTransformerGen(ViewGen, TransformerGen):
def __init__( def __init__(
self, self,
@ -65,10 +85,17 @@ class TextualTransformerGen(ViewGen, TransformerGen):
return "bert-base-multilingual-uncased" return "bert-base-multilingual-uncased"
elif "xlm" == model_name: elif "xlm" == model_name:
return "xlm-roberta-base" return "xlm-roberta-base"
elif "mt5" == model_name:
return "google/mt5-small"
else: else:
raise NotImplementedError raise NotImplementedError
def load_pretrained_model(self, model_name, num_labels): def load_pretrained_model(self, model_name, num_labels):
if model_name == "google/mt5-small":
return MT5ForSequenceClassification(
model_name, num_labels=num_labels, output_hidden_states=True
)
else:
return AutoModelForSequenceClassification.from_pretrained( return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True model_name, num_labels=num_labels, output_hidden_states=True
) )