avoid training transformers
parent 2554c58fac · commit 60171c1b5e
@@ -114,6 +114,7 @@ class TextualTransformerGen(ViewGen, TransformerGen):
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
         else:
+            model_name = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"  # TODO: hardcoded to pre-trained mbert
             return AutoModelForSequenceClassification.from_pretrained(
                 model_name, num_labels=num_labels, output_hidden_states=True
             )
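For context: the added line re-points from_pretrained at a local fine-tuned checkpoint directory rather than a Hub model id, so no further fine-tuning is run. A minimal sketch of that load in isolation (the path mirrors the diff; the num_labels value and the tokenizer reload are illustrative assumptions, not taken from the commit):

# Sketch only: loading the locally fine-tuned checkpoint the same way the diff does.
# num_labels and the tokenizer reload are assumptions for illustration.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "models/vgfs/trained_transformer/mbert-sentiment/checkpoint-8500"
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=3,                # assumed label count
    output_hidden_states=True,   # hidden states are needed later for feature extraction
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)  # assumes tokenizer files were saved with the checkpoint
model.eval()  # the commit skips training, so the model is used for inference only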
@@ -145,58 +146,60 @@ class TextualTransformerGen(ViewGen, TransformerGen):
             self.model_name, num_labels=self.num_labels
         )

-        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
-            lX, lY, split=0.2, seed=42, modality="text"
-        )
-
-        tra_dataloader = self.build_dataloader(
-            tr_lX,
-            tr_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size,
-            split="train",
-            shuffle=True,
-        )
-
-        val_dataloader = self.build_dataloader(
-            val_lX,
-            val_lY,
-            processor_fn=self._tokenize,
-            torchDataset=MultilingualDatasetTorch,
-            batch_size=self.batch_size_eval,
-            split="val",
-            shuffle=False,
-        )
-
-        experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
-
-        trainer = Trainer(
-            model=self.model,
-            optimizer_name="adamW",
-            lr=self.lr,
-            device=self.device,
-            loss_fn=torch.nn.CrossEntropyLoss(),
-            print_steps=self.print_steps,
-            evaluate_step=self.evaluate_step,
-            patience=self.patience,
-            experiment_name=experiment_name,
-            checkpoint_path=os.path.join(
-                "models",
-                "vgfs",
-                "trained_transformer",
-                self._format_model_name(self.model_name),
-            ),
-            vgf_name="textual_trf",
-            classification_type=self.clf_type,
-            n_jobs=self.n_jobs,
-            scheduler_name=self.scheduler,
-        )
-        trainer.train(
-            train_dataloader=tra_dataloader,
-            eval_dataloader=val_dataloader,
-            epochs=self.epochs,
-        )
-
+        self.model.to("cuda")
+
+        # tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
+        #     lX, lY, split=0.2, seed=42, modality="text"
+        # )
+        #
+        # tra_dataloader = self.build_dataloader(
+        #     tr_lX,
+        #     tr_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size,
+        #     split="train",
+        #     shuffle=True,
+        # )
+        #
+        # val_dataloader = self.build_dataloader(
+        #     val_lX,
+        #     val_lY,
+        #     processor_fn=self._tokenize,
+        #     torchDataset=MultilingualDatasetTorch,
+        #     batch_size=self.batch_size_eval,
+        #     split="val",
+        #     shuffle=False,
+        # )
+        #
+        # experiment_name = f"{self.model_name.replace('/', '-')}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
+        #
+        # trainer = Trainer(
+        #     model=self.model,
+        #     optimizer_name="adamW",
+        #     lr=self.lr,
+        #     device=self.device,
+        #     loss_fn=torch.nn.CrossEntropyLoss(),
+        #     print_steps=self.print_steps,
+        #     evaluate_step=self.evaluate_step,
+        #     patience=self.patience,
+        #     experiment_name=experiment_name,
+        #     checkpoint_path=os.path.join(
+        #         "models",
+        #         "vgfs",
+        #         "trained_transformer",
+        #         self._format_model_name(self.model_name),
+        #     ),
+        #     vgf_name="textual_trf",
+        #     classification_type=self.clf_type,
+        #     n_jobs=self.n_jobs,
+        #     scheduler_name=self.scheduler,
+        # )
+        # trainer.train(
+        #     train_dataloader=tra_dataloader,
+        #     eval_dataloader=val_dataloader,
+        #     epochs=self.epochs,
+        # )

         if self.probabilistic:
             self.feature2posterior_projector.fit(self.transform(lX), lY)
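The hunk above disables the train/val split, both dataloaders, and the Trainer by commenting them out, and instead just moves the already fine-tuned model to the GPU before fitting the posterior projector. A hedged alternative sketch, assuming a hypothetical train_transformer flag on the class (not the repository's actual API), that keeps the same behaviour without commented-out code:

# Sketch, not the commit's implementation: gate training behind a flag instead of
# commenting it out. `train_transformer` is a hypothetical constructor argument.
def fit(self, lX, lY):
    if self.train_transformer:
        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="text"
        )
        # ... build tra_dataloader / val_dataloader and run Trainer,
        #     exactly as in the commented block above ...
    else:
        self.model.to("cuda")  # reuse the pre-trained checkpoint as-is

    if self.probabilistic:
        self.feature2posterior_projector.fit(self.transform(lX), lY)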
@@ -225,7 +228,6 @@ class TextualTransformerGen(ViewGen, TransformerGen):
         with torch.no_grad():
             for input_ids, lang in dataloader:
                 input_ids = input_ids.to(self.device)
-                # TODO: check this
                 if isinstance(self.model, MT5ForSequenceClassification):
                     batch_embeddings = self.model(input_ids).pooled.cpu().numpy()
                 else:
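The only change in this hunk is dropping the stale "# TODO: check this" marker above the embedding-extraction branch. The non-MT5 else branch is not shown in the diff; a common pattern for a model loaded with output_hidden_states=True is to take the first-token ([CLS]) vector of the last hidden layer, sketched below under that assumption (model and input_ids stand for the objects in the surrounding loop; the repository's actual branch may differ):

# Sketch only: generic embedding extraction for an AutoModelForSequenceClassification
# loaded with output_hidden_states=True.
import torch

with torch.no_grad():
    out = model(input_ids)                 # SequenceClassifierOutput with .hidden_states
    last_hidden = out.hidden_states[-1]    # shape: (batch, seq_len, hidden_dim)
    batch_embeddings = last_hidden[:, 0, :].cpu().numpy()  # first-token ([CLS]) embedding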