from collections import defaultdict

import numpy as np
import torch
import transformers
from PIL import Image
from torch.utils.data import Dataset
from transformers import AutoImageProcessor, AutoModelForImageClassification

from gfun.vgfs.commons import Trainer
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()


class VisualTransformerGen(ViewGen, TransformerGen):
    def __init__(
        self,
        model_name,
        dataset_name,
        lr=1e-5,
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        evaluate_step=10,
        device="cpu",
        probabilistic=False,
        patience=5,
    ):
        super().__init__(
            model_name,
            dataset_name,
            lr=lr,
            epochs=epochs,
            batch_size=batch_size,
            batch_size_eval=batch_size_eval,
            device=device,
            evaluate_step=evaluate_step,
            patience=patience,
            probabilistic=probabilistic,
        )
        self.fitted = False
        print(
            f"- init Visual Transformer model_name: {self.model_name}, device: {self.device}"
        )

    def _validate_model_name(self, model_name):
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        else:
            raise NotImplementedError(f"Unsupported model name: {model_name}")

    def init_model(self, model_name, num_labels):
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels, output_hidden_states=True
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        return model, image_processor

    def process_all(self, X):
        # TODO: should be moved to a collate_fn to avoid this overhead
        processed = self.image_preprocessor(
            [Image.open(img).convert("RGB") for img in X], return_tensors="pt"
        )
        return processed["pixel_values"]

    def fit(self, lX, lY):
        print("- fitting Visual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42, modality="image"
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            lr=self.lr,
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
            checkpoint_path="models/vgfs/transformer",
        )

        trainer.train(
            train_dataloader=tra_dataloader,
            eval_dataloader=val_dataloader,
            epochs=self.epochs,
        )

        if self.probabilistic:
            self.feature2posterior_projector.fit(self.transform(lX), lY)

        self.fitted = True

        return self

    def transform(self, lX):
        # forcing to image modality only
        lX = {lang: data["image"] for lang, data in lX.items()}
        _embeds = []
        l_embeds = defaultdict(list)

        dataloader = self.build_dataloader(
            lX,
            lY=None,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="whole",
            shuffle=False,
        )

        self.model.eval()
        with torch.no_grad():
            for pixel_values, langs in dataloader:
                pixel_values = pixel_values.to(self.device)
                out = self.model(pixel_values).hidden_states[-1]
                # use the [CLS] token of the last hidden layer as the image embedding
                batch_embeddings = out[:, 0, :].cpu().numpy()
                _embeds.append((batch_embeddings, langs))

        # regroup the batched embeddings by language
        for embeds, langs in _embeds:
            for sample_embed, sample_lang in zip(embeds, langs):
                l_embeds[sample_lang].append(sample_embed)

        if self.probabilistic and self.fitted:
            l_embeds = self.feature2posterior_projector.transform(l_embeds)
        elif not self.probabilistic and self.fitted:
            l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()}

        return l_embeds

    def fit_transform(self, lX, lY):
        return self.fit(lX, lY).transform(lX)

    def save_vgf(self, model_id):
        import pickle
        from os import makedirs
        from os.path import join

        vgf_name = "visualTransformerGen"
        _basedir = join("models", "vgfs", "visual_transformer")
        makedirs(_basedir, exist_ok=True)
        _path = join(_basedir, f"{vgf_name}_{model_id}.pkl")
        with open(_path, "wb") as f:
            pickle.dump(self, f)
        return self

    def __str__(self):
        return (
            f"[Visual Transformer VGF (v)]\n"
            f"- model_name: {self.model_name}\n"
            f"- batch_size: {self.batch_size}\n"
            f"- batch_size_eval: {self.batch_size_eval}\n"
            f"- lr: {self.lr}\n"
            f"- epochs: {self.epochs}\n"
            f"- device: {self.device}\n"
            f"- print_steps: {self.print_steps}\n"
            f"- evaluate_step: {self.evaluate_step}\n"
            f"- patience: {self.patience}\n"
            f"- probabilistic: {self.probabilistic}\n"
        )


class MultimodalDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        self.X = torch.vstack(list(self.lX.values()))
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language tag per sample, aligned with the vstacked X
        self.langs = [lang for lang, data in self.lX.items() for _ in range(len(data))]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]


if __name__ == "__main__":
    from os.path import expanduser

    from dataManager.gFunDataset import gFunDataset

    GLAMI_DATAPATH = expanduser("~/datasets/GLAMI-1M-dataset")

    dataset = gFunDataset(
        dataset_dir=GLAMI_DATAPATH,
        is_textual=True,
        is_visual=True,
        is_multilabel=False,
        nrows=50,
    )

    vg = VisualTransformerGen(
        dataset_name=dataset.dataset_name,
        model_name="vit",
        device="cuda",
        epochs=5,
        evaluate_step=10,
        patience=10,
        probabilistic=True,
    )

    lX, lY = dataset.training()
    vg.fit(lX, lY)
    out = vg.transform(lX)
    exit(0)
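# Usage sketch (hypothetical, kept as comments so the module stays importable):
# a VGF persisted with save_vgf() can be restored with pickle and reused for
# inference. The file name below assumes model_id=0 was passed to save_vgf();
# adjust it to whatever identifier was actually used.
#
#   import pickle
#   with open("models/vgfs/visual_transformer/visualTransformerGen_0.pkl", "rb") as f:
#       vg = pickle.load(f)
#   l_embeds = vg.transform(lX)  # lX: {lang: {"image": [img_paths]}}, as in transform()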