implemented MT5ForSequenceClassification
This commit is contained in:
parent
a3e183d7fc
commit
5e41b4517a
|
@ -6,7 +6,9 @@ from collections import defaultdict
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers import MT5EncoderModel
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
|
@ -17,6 +19,24 @@ from gfun.vgfs.viewGen import ViewGen
|
|||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
class MT5ForSequenceClassification(nn.Module):
|
||||
def __init__(self, model_name, num_labels, output_hidden_states):
|
||||
super().__init__()
|
||||
|
||||
self.mt5encoder = MT5EncoderModel.from_pretrained(
|
||||
model_name, output_hidden_states=output_hidden_states
|
||||
)
|
||||
self.dropout = nn.Dropout(0.1)
|
||||
self.linear = nn.Linear(512, num_labels)
|
||||
|
||||
def forward(self, input_ids, attn_mask):
|
||||
# TODO: output hidden states
|
||||
outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
|
||||
outputs = self.dropout(outputs[0])
|
||||
outputs = self.linear(outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -65,13 +85,20 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
|||
return "bert-base-multilingual-uncased"
|
||||
elif "xlm" == model_name:
|
||||
return "xlm-roberta-base"
|
||||
elif "mt5" == model_name:
|
||||
return "google/mt5-small"
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_pretrained_model(self, model_name, num_labels):
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
if model_name == "google/mt5-small":
|
||||
return MT5ForSequenceClassification(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
else:
|
||||
return AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name, num_labels=num_labels, output_hidden_states=True
|
||||
)
|
||||
|
||||
def load_tokenizer(self, model_name):
|
||||
return AutoTokenizer.from_pretrained(model_name)
|
||||
|
|
Loading…
Reference in New Issue