avoid training transformers

Andrea Pedrotti 2023-06-22 11:32:50 +02:00
parent 2554c58fac
commit 60171c1b5e
1 changed file with 54 additions and 52 deletions

@@ -114,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
         else:
+            model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO hardcoded to pre-trained mbert
             return AutoModelForSequenceClassification.from_pretrained(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
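
Note: AutoModelForSequenceClassification.from_pretrained accepts a local checkpoint directory exactly like a hub model id, which is what the hard-coded branch above relies on. A minimal sketch of that loading step (the num_labels value here is an assumed placeholder; in the real code it is forwarded by the caller):

    from transformers import AutoModelForSequenceClassification

    # Illustration only: checkpoint path taken from the diff above; num_labels=3 is assumed.
    checkpoint = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint,
        num_labels=3,               # assumed value; the real call forwards num_labels
        output_hidden_states=True,  # keep hidden states for later embedding extraction
    )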
@@ -145,58 +146,60 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             self.model_name, num_labels=self.num_labels
         )
-        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-            lX, lY, split=0.2, seed=42, modality="text"
-        )
-
-        tra_dataloader = self.build_dataloader(
-            tr_lX,
-            tr_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size,
-            split="train",
-            shuffle=True,
-        )
-
-        val_dataloader = self.build_dataloader(
-            val_lX,
-            val_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size_eval,
-            split="val",
-            shuffle=False,
-        )
-
-        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-
-        trainer = Trainer(
-            model=self.model,
-            optimizer_name="adamW",
-            lr=self.lr,
-            device=self.device,
-            loss_fn=torch.nn.CrossEntropyLoss(),
-            print_steps=self.print_steps,
-            evaluate_step=self.evaluate_step,
-            patience=self.patience,
-            experiment_name=experiment_name,
-            checkpoint_path=os.path.join(
-                "models",
-                "vgfs",
-                "trained_transformer",
-                self._format_model_name(self.model_name),
-            ),
-            vgf_name="textual_trf",
-            classification_type=self.clf_type,
-            n_jobs=self.n_jobs,
-            scheduler_name=self.scheduler,
-        )
-        trainer.train(
-            train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
-            epochs=self.epochs,
-        )
+        self.model.to("cuda")
+
+        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
+        #     lX, lY, split=0.2, seed=42, modality="text"
+        # )
+        #
+        # tra_dataloader = self.build_dataloader(
+        #     tr_lX,
+        #     tr_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size,
+        #     split="train",
+        #     shuffle=True,
+        # )
+        #
+        # val_dataloader = self.build_dataloader(
+        #     val_lX,
+        #     val_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size_eval,
+        #     split="val",
+        #     shuffle=False,
+        # )
+        #
+        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        #
+        # trainer = Trainer(
+        #     model=self.model,
+        #     optimizer_name="adamW",
+        #     lr=self.lr,
+        #     device=self.device,
+        #     loss_fn=torch.nn.CrossEntropyLoss(),
+        #     print_steps=self.print_steps,
+        #     evaluate_step=self.evaluate_step,
+        #     patience=self.patience,
+        #     experiment_name=experiment_name,
+        #     checkpoint_path=os.path.join(
+        #         "models",
+        #         "vgfs",
+        #         "trained_transformer",
+        #         self._format_model_name(self.model_name),
+        #     ),
+        #     vgf_name="textual_trf",
+        #     classification_type=self.clf_type,
+        #     n_jobs=self.n_jobs,
+        #     scheduler_name=self.scheduler,
+        # )
+        # trainer.train(
+        #     train_dataloader=tra_dataloader,
+        #     eval_dataloader=val_dataloader,
+        #     epochs=self.epochs,
+        # )
         if self.probabilistic:
             self.feature2posterior_projector.fit(self.transform(lX), lY)
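
Note: the net effect of this hunk is that the fit step no longer builds dataloaders or a Trainer; it just moves the already fine-tuned checkpoint to the GPU and, when probabilistic is set, fits the posterior projector. A condensed sketch of the resulting flow, assuming the method is the generator's fit() and using a hypothetical init_model() stand-in for the loading call whose closing arguments appear as context at the top of the hunk:

    def fit(self, lX, lY):
        # init_model is a hypothetical name for the checkpoint-loading call shown above.
        self.model = self.init_model(self.model_name, num_labels=self.num_labels)
        self.model.to("cuda")  # device is hard-coded here instead of using self.device

        # Fine-tuning (dataloaders + Trainer) is commented out by this commit.

        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)
        return self  # assumed return value; not shown in the diff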
@@ -225,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                # TODO: check this
                 if isinstance(self.model, MT5ForSequenceClassification):
                     batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                 else:
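
Note: the else branch is outside this hunk, but since the models are loaded with output_hidden_states=True, a common way to turn them into document embeddings is to pool the last hidden layer. A minimal, hypothetical sketch under that assumption (not the repository's actual code):

    import torch

    def pool_last_hidden(model, input_ids):
        # Mean-pool the final hidden layer into one vector per document (illustration only).
        with torch.no_grad():
            out = model(input_ids)               # requires output_hidden_states=True
            last_hidden = out.hidden_states[-1]  # shape: (batch, seq_len, hidden_dim)
            return last_hidden.mean(dim=1).cpu().numpy()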