removed unused
This commit is contained in:
parent
22a36e5ddf
commit
875af6d362
|
|
@ -1,189 +0,0 @@
|
|||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image
|
||||
from transformers import AutoImageProcessor, AutoModelForImageClassification
|
||||
|
||||
from gfun.vgfs.commons import Trainer
|
||||
from gfun.vgfs.transformerGen import TransformerGen
|
||||
from gfun.vgfs.viewGen import ViewGen
|
||||
from dataManager.torchDataset import MultimodalDatasetTorch
|
||||
|
||||
# Restrict the transformers library's logger to error-level messages only.
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
||||
class VisualTransformerGen(ViewGen, TransformerGen):
    """View-generating function (VGF) that embeds images with a vision transformer.

    ``fit`` loads a pretrained image-classification checkpoint and fine-tunes it
    through ``Trainer``; ``transform`` returns, for every language, the vector at
    position 0 of the model's last hidden state for each image (presumably the
    [CLS] token of the ViT checkpoint -- confirm against the model docs).
    """

    def __init__(
        self,
        model_name,
        dataset_name,
        lr=1e-5,
        scheduler="ReduceLROnPlateau",
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        probabilistic=False,
        patience=5,
        classification_type="multilabel",
    ):
        """Store the training configuration.

        Every argument except ``classification_type`` is forwarded unchanged to
        the parent initializer; ``classification_type`` is kept on the instance
        as ``clf_type`` and later handed to ``Trainer`` in ``fit``.
        """
        super().__init__(
            model_name,
            dataset_name,
            epochs=epochs,
            lr=lr,
            scheduler=scheduler,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
        )
        self.clf_type = classification_type
        # Flipped to True at the very end of fit(); transform() only
        # post-processes its output once this flag is set.
        self.fitted = False
        print(
            f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]"
        )

    def _validate_model_name(self, model_name):
        """Translate a short model alias into a Hugging Face checkpoint id.

        Only the alias ``"vit"`` is recognised.

        Raises:
            NotImplementedError: if ``model_name`` is any other value.
        """
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        raise NotImplementedError(
            f"Unsupported visual model {model_name!r}: only 'vit' is available"
        )

    def init_model(self, model_name, num_labels):
        """Load the classification model and its image processor.

        The model is created with ``output_hidden_states=True`` because
        ``transform`` reads ``hidden_states`` from the forward output.

        Returns:
            A ``(model, image_processor)`` tuple.
        """
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        return model, image_processor

    def process_all(self, X):
        """Open every image in ``X``, convert it to RGB and preprocess the batch.

        Args:
            X: iterable of items accepted by ``PIL.Image.open``.

        Returns:
            The ``"pixel_values"`` entry produced by the image processor
            (requested with ``return_tensors="pt"``).
        """
        # TODO: should be moved as a collate_fn to avoid this overhead
        rgb_images = []
        for img in X:
            # convert() returns an independent image, so the source file can be
            # released as soon as the conversion is done.
            with Image.open(img) as opened:
                rgb_images.append(opened.convert("RGB"))
        processed = self.image_preprocessor(rgb_images, return_tensors="pt")
        return processed["pixel_values"]

    def fit(self, lX, lY):
        """Fine-tune the visual transformer on the given per-language data.

        Args:
            lX: per-language inputs; handed to ``get_train_val_data`` with
                ``modality="image"`` and, when ``probabilistic`` is set, to
                ``transform``.
            lY: per-language label matrices; the size of the last axis of the
                first language's matrix is used as the number of labels.

        Returns:
            ``self``, so calls can be chained.
        """
        print("- fitting Visual Transformer View Generating Function")
        first_lang = list(lX.keys())[0]
        self.num_labels = lY[first_lang].shape[-1]
        self.model, self.image_preprocessor = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="image"
        )

        train_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = (
            f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}"
        )

        # NOTE(review): CrossEntropyLoss is passed regardless of
        # ``classification_type`` (default "multilabel") -- confirm that
        # Trainer handles the multilabel case with this loss.
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            lr=self.lr,
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
            vgf_name="visual_trf",
            classification_type=self.clf_type,
            n_jobs=self.n_jobs,
        )

        trainer.train(
            train_dataloader=train_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            # transform() runs while self.fitted is still False, so the
            # projector is fitted on the raw per-language lists of embeddings.
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True

        return self

    def transform(self, lX):
        """Embed the images of every language with the current model.

        Args:
            lX: mapping ``lang -> data`` where ``data["image"]`` holds that
                language's images; any other entries are ignored.

        Returns:
            Once fitted: the projector output when ``probabilistic`` is set,
            otherwise a dict ``lang -> np.ndarray`` of stacked embeddings.
            Before fitting (e.g. when called from ``fit``): a ``defaultdict``
            mapping ``lang`` to a list of per-sample embedding arrays.
        """
        # forcing to only image modality
        lX = {lang: data["image"] for lang, data in lX.items()}
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for pixel_values, langs in dataloader:
                pixel_values = pixel_values.to(self.device)
                last_hidden = self.model(pixel_values).hidden_states[-1]
                # Keep only position 0 of the sequence axis for each sample.
                batch_embeddings = last_hidden[:, 0, :].cpu().numpy()
                # Group samples by language as they arrive instead of
                # buffering whole batches for a second pass.
                for sample_embed, sample_lang in zip(batch_embeddings, langs):
                    l_embeds[sample_lang].append(sample_embed)

        if self.fitted:
            if self.probabilistic:
                l_embeds = self.feature2posterior_projector.transform(l_embeds)
            else:
                l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def fit_transform(self, lX, lY):
        """Fit on ``(lX, lY)`` and return ``transform(lX)``."""
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        """Pickle this object to ``models/vgfs/visual_transformer``.

        The file is named ``visualTransformerGen_<model_id>.pkl``; the target
        directory is created if it does not exist.

        Returns:
            ``self``.
        """
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "visualTransformerGen"
        _basedir = join("models", "vgfs", "visual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def get_config(self):
        """Return this VGF's display name together with the parent's config."""
        return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}
|
||||
Loading…
Reference in New Issue