diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 8f8d661..70e9909 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -6,7 +6,9 @@ from collections import defaultdict
 
 import numpy as np
 import torch
+import torch.nn as nn
 import transformers
+from transformers import MT5EncoderModel
 from torch.utils.data import Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
@@ -17,6 +19,24 @@ from gfun.vgfs.viewGen import ViewGen
 
 transformers.logging.set_verbosity_error()
 
+class MT5ForSequenceClassification(nn.Module):
+    def __init__(self, model_name, num_labels, output_hidden_states):
+        super().__init__()
+
+        self.mt5encoder = MT5EncoderModel.from_pretrained(
+            model_name, output_hidden_states=output_hidden_states
+        )
+        self.dropout = nn.Dropout(0.1)
+        self.linear = nn.Linear(512, num_labels)
+
+    def forward(self, input_ids, attn_mask):
+        # TODO: output hidden states
+        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
+        outputs = self.dropout(outputs[0])
+        outputs = self.linear(outputs)
+        return outputs
+
+
 class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
@@ -65,13 +85,20 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             return "bert-base-multilingual-uncased"
         elif "xlm" == model_name:
             return "xlm-roberta-base"
+        elif "mt5" == model_name:
+            return "google/mt5-small"
         else:
             raise NotImplementedError
 
     def load_pretrained_model(self, model_name, num_labels):
-        return AutoModelForSequenceClassification.from_pretrained(
-            model_name, num_labels=num_labels, output_hidden_states=True
-        )
+        if model_name == "google/mt5-small":
+            return MT5ForSequenceClassification(
+                model_name, num_labels=num_labels, output_hidden_states=True
+            )
+        else:
+            return AutoModelForSequenceClassification.from_pretrained(
+                model_name, num_labels=num_labels, output_hidden_states=True
+            )
 
     def load_tokenizer(self, model_name):
         return AutoTokenizer.from_pretrained(model_name)
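
Note on the new forward(): outputs[0] from MT5EncoderModel is the encoder's
last_hidden_state with shape (batch, seq_len, 512), so the linear head above
yields one logit vector per token rather than per document, and the hard-coded
512 only matches mt5-small. Below is a minimal sketch of a pooled variant; the
class name MT5ForSequenceClassificationPooled and the mean-pooling strategy are
illustrative assumptions, not part of this patch.

import torch.nn as nn
from transformers import MT5EncoderModel


class MT5ForSequenceClassificationPooled(nn.Module):  # hypothetical name, sketch only
    def __init__(self, model_name, num_labels, output_hidden_states):
        super().__init__()
        self.mt5encoder = MT5EncoderModel.from_pretrained(
            model_name, output_hidden_states=output_hidden_states
        )
        self.dropout = nn.Dropout(0.1)
        # read the hidden size from the config instead of hard-coding 512,
        # so mt5-base/-large checkpoints would work as well
        self.linear = nn.Linear(self.mt5encoder.config.d_model, num_labels)

    def forward(self, input_ids, attn_mask):
        outputs = self.mt5encoder(input_ids=input_ids, attention_mask=attn_mask)
        hidden = outputs.last_hidden_state              # (batch, seq_len, d_model)
        mask = attn_mask.unsqueeze(-1).type_as(hidden)  # zero out padding positions
        # mean-pool over non-padding tokens to get one vector per document
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.linear(self.dropout(pooled))        # (batch, num_labels)

This would leave load_pretrained_model() unchanged apart from swapping in the
pooled class, and mirrors the common way a sequence-level representation is
derived from an encoder-only checkpoint that ships without a classification
head.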