implemented MT5ForSequenceClassification
This commit is contained in:
parent
a3e183d7fc
commit
5e41b4517a
|
|
@ -6,7 +6,9 @@ from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import transformers
|
import transformers
|
||||||
|
from transformers import MT5EncoderModel
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -17,6 +19,24 @@ from gfun.vgfs.viewGen import ViewGen
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
|
||||||
|
class MT5ForSequenceClassification(nn.Module):
|
||||||
|
def __init__(self, model_name, num_labels, output_hidden_states):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.mt5encoder = MT5EncoderModel.from_pretrained(
|
||||||
|
model_name, output_hidden_states=output_hidden_states
|
||||||
|
)
|
||||||
|
self.dropout = nn.Dropout(0.1)
|
||||||
|
self.linear = nn.Linear(512, num_labels)
|
||||||
|
|
||||||
|
def forward(self, input_ids, attn_mask):
|
||||||
|
# TODO: output hidden states
|
||||||
|
outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
|
||||||
|
outputs = self.dropout(outputs[0])
|
||||||
|
outputs = self.linear(outputs)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class TextualTransformerGen(ViewGen, TransformerGen):
|
class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -65,13 +85,20 @@ class TextualTransformerGen(ViewGen, TransformerGen):
|
||||||
return "bert-base-multilingual-uncased"
|
return "bert-base-multilingual-uncased"
|
||||||
elif "xlm" == model_name:
|
elif "xlm" == model_name:
|
||||||
return "xlm-roberta-base"
|
return "xlm-roberta-base"
|
||||||
|
elif "mt5" == model_name:
|
||||||
|
return "google/mt5-small"
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def load_pretrained_model(self, model_name, num_labels):
|
def load_pretrained_model(self, model_name, num_labels):
|
||||||
return AutoModelForSequenceClassification.from_pretrained(
|
if model_name == "google/mt5-small":
|
||||||
model_name, num_labels=num_labels, output_hidden_states=True
|
return MT5ForSequenceClassification(
|
||||||
)
|
model_name, num_labels=num_labels, output_hidden_states=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
model_name, num_labels=num_labels, output_hidden_states=True
|
||||||
|
)
|
||||||
|
|
||||||
def load_tokenizer(self, model_name):
|
def load_tokenizer(self, model_name):
|
||||||
return AutoTokenizer.from_pretrained(model_name)
|
return AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue