switched from mbert uncased to cased version

This commit is contained in:
Andrea Pedrotti 2023-07-03 19:04:26 +02:00
parent 6995854e3d
commit 8354d76513
1 changed file with 9 additions and 6 deletions

View File

@ -100,7 +100,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
if "bert" == model_name:
return "bert-base-uncased"
elif "mbert" == model_name:
return "bert-base-multilingual-uncased"
return "bert-base-multilingual-cased"
elif "xlm-roberta" == model_name:
return "xlm-roberta-base"
elif "mt5" == model_name:
@ -114,12 +114,14 @@ class TextualTransformerGen(ViewGen, TransformerGen):
model_name, num_labels=num_labels, output_hidden_states=True
)
else:
model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
# model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-traiend mbert
model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert
return AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels, output_hidden_states=True
)
def load_tokenizer(self, model_name):
# model_name = "mbert-rai-multi-2000/checkpoint-1500" # TODO hardcoded to pre-traiend mbert
return AutoTokenizer.from_pretrained(model_name)
def init_model(self, model_name, num_labels):
@ -161,7 +163,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
# split="train",
# shuffle=True,
# )
#
#
# val_dataloader = self.build_dataloader(
# val_lX,
# val_lY,
@ -171,9 +173,9 @@ class TextualTransformerGen(ViewGen, TransformerGen):
# split="val",
# shuffle=False,
# )
#
#
# experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
#
#
# trainer = Trainer(
# model=self.model,
# optimizer_name="adamW",
@ -202,7 +204,8 @@ class TextualTransformerGen(ViewGen, TransformerGen):
# )
if self.probabilistic:
self.feature2posterior_projector.fit(self.transform(lX), lY)
transformed = self.transform(lX)
self.feature2posterior_projector.fit(transformed, lY)
self.fitted = True