diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index 5ed9d16..93f86b8 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -114,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
         else:
+            model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500" # TODO hardcoded to pre-trained mbert
             return AutoModelForSequenceClassification.from_pretrained(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
@@ -145,58 +146,60 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             self.model_name, num_labels=self.num_labels
         )
 
-        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-            lX, lY, split=0.2, seed=42, modality="text"
-        )
+        self.model.to("cuda")
 
-        tra_dataloader = self.build_dataloader(
-            tr_lX,
-            tr_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size,
-            split="train",
-            shuffle=True,
-        )
-
-        val_dataloader = self.build_dataloader(
-            val_lX,
-            val_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size_eval,
-            split="val",
-            shuffle=False,
-        )
-
-        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-
-        trainer = Trainer(
-            model=self.model,
-            optimizer_name="adamW",
-            lr=self.lr,
-            device=self.device,
-            loss_fn=torch.nn.CrossEntropyLoss(),
-            print_steps=self.print_steps,
-            evaluate_step=self.evaluate_step,
-            patience=self.patience,
-            experiment_name=experiment_name,
-            checkpoint_path=os.path.join(
-                "models",
-                "vgfs",
-                "trained_transformer",
-                self._format_model_name(self.model_name),
-            ),
-            vgf_name="textual_trf",
-            classification_type=self.clf_type,
-            n_jobs=self.n_jobs,
-            scheduler_name=self.scheduler,
-        )
-        trainer.train(
-            train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
-            epochs=self.epochs,
-        )
+        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
+        #     lX, lY, split=0.2, seed=42, modality="text"
+        # )
+        #
+        # tra_dataloader = self.build_dataloader(
+        #     tr_lX,
+        #     tr_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size,
+        #     split="train",
+        #     shuffle=True,
+        # )
+        #
+        # val_dataloader = self.build_dataloader(
+        #     val_lX,
+        #     val_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size_eval,
+        #     split="val",
+        #     shuffle=False,
+        # )
+        #
+        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        #
+        # trainer = Trainer(
+        #     model=self.model,
+        #     optimizer_name="adamW",
+        #     lr=self.lr,
+        #     device=self.device,
+        #     loss_fn=torch.nn.CrossEntropyLoss(),
+        #     print_steps=self.print_steps,
+        #     evaluate_step=self.evaluate_step,
+        #     patience=self.patience,
+        #     experiment_name=experiment_name,
+        #     checkpoint_path=os.path.join(
+        #         "models",
+        #         "vgfs",
+        #         "trained_transformer",
+        #         self._format_model_name(self.model_name),
+        #     ),
+        #     vgf_name="textual_trf",
+        #     classification_type=self.clf_type,
+        #     n_jobs=self.n_jobs,
+        #     scheduler_name=self.scheduler,
+        # )
+        # trainer.train(
+        #     train_dataloader=tra_dataloader,
+        #     eval_dataloader=val_dataloader,
+        #     epochs=self.epochs,
+        # )
 
         if self.probabilistic:
             self.feature2posterior_projector.fit(self.transform(lX), lY)
@@ -225,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                # TODO: check this
                 if isinstance(self.model, MT5ForSequenceClassification):
                     batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                 else: