import os
import sys

sys.path.append(os.getcwd())

import torch
import transformers
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, RandomResizedCrop, ToTensor
from transformers import AutoImageProcessor, AutoModelForImageClassification

from gfun.vgfs.commons import Trainer, predict
from gfun.vgfs.transformerGen import TransformerGen
from gfun.vgfs.viewGen import ViewGen

transformers.logging.set_verbosity_error()


class VisualTransformerGen(ViewGen, TransformerGen):
    def __init__(
        self,
        model_name,
        lr=1e-5,
        epochs=10,
        batch_size=32,
        batch_size_eval=128,
        device="cpu",
        patience=5,
        evaluate_step=10,
        print_steps=100,
    ):
        self.model_name = model_name
        self.datasets = {}
        self.lr = lr
        self.epochs = epochs
        self.batch_size = batch_size
        self.batch_size_eval = batch_size_eval
        # fit() relies on the attributes below; the defaults here are assumptions
        # and may instead be provided by TransformerGen
        self.device = device
        self.patience = patience
        self.evaluate_step = evaluate_step
        self.print_steps = print_steps

    def _validate_model_name(self, model_name):
        if model_name == "vit":
            return "google/vit-base-patch16-224-in21k"
        else:
            raise NotImplementedError

    def init_model(self, model_name, num_labels):
        model = AutoModelForImageClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        image_processor = AutoImageProcessor.from_pretrained(model_name)
        transforms = self.init_preprocessor(image_processor)
        return model, image_processor, transforms

    def init_preprocessor(self, image_processor):
        normalize = Normalize(
            mean=image_processor.image_mean, std=image_processor.image_std
        )
        size = (
            image_processor.size["shortest_edge"]
            if "shortest_edge" in image_processor.size
            else (image_processor.size["height"], image_processor.size["width"])
        )
        # these are the transformations that we are applying to the images
        transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
        return transforms

    def preprocess(self, images, transforms):
        # apply the torchvision pipeline to each PIL image individually
        processed = [transforms(img.convert("RGB")) for img in images]
        return processed

    def process_all(self, X):
        # TODO: every element in X is a tuple (doc_id, clean_text, text, Pil.Image),
        # so we're taking just the last element for processing
        processed = torch.stack([self.transforms(img[-1]) for img in X])
        return processed

    def fit(self, lX, lY):
        print("- fitting Visual Transformer View Generating Function")
        _l = list(lX.keys())[0]
        self.num_labels = lY[_l].shape[-1]
        self.model, self.image_preprocessor, self.transforms = self.init_model(
            self._validate_model_name(self.model_name), num_labels=self.num_labels
        )

        tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data(
            lX, lY, split=0.2, seed=42
        )

        tra_dataloader = self.build_dataloader(
            tr_lX,
            tr_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size,
            split="train",
            shuffle=True,
        )

        val_dataloader = self.build_dataloader(
            val_lX,
            val_lY,
            processor_fn=self.process_all,
            torchDataset=MultimodalDatasetTorch,
            batch_size=self.batch_size_eval,
            split="val",
            shuffle=False,
        )

        experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"

        trainer = Trainer(
            model=self.model,
            optimizer_name="adamW",
            lr=self.lr,
            device=self.device,
            loss_fn=torch.nn.CrossEntropyLoss(),
            print_steps=self.print_steps,
            evaluate_step=self.evaluate_step,
            patience=self.patience,
            experiment_name=experiment_name,
        )

        trainer.train(
            train_dataloader=tra_dataloader,
            val_dataloader=val_dataloader,
            epochs=self.epochs,
        )

    def transform(self, lX):
        raise NotImplementedError

    def fit_transform(self, lX, lY):
        raise NotImplementedError

    def save_vgf(self, model_id):
        raise NotImplementedError


class MultimodalDatasetTorch(Dataset):
    def __init__(self, lX, lY, split="train"):
        self.lX = lX
        self.lY = lY
        self.split = split
        self.langs = []
        self.init()

    def init(self):
        # stack the per-language image tensors into a single batch tensor
        self.X = torch.vstack([imgs for imgs in self.lX.values()])
        if self.split != "whole":
            self.Y = torch.vstack([torch.Tensor(data) for data in self.lY.values()])
        # one language tag per sample, in the same order as the stacked tensors
        self.langs = sum([[lang] * len(data) for lang, data in self.lX.items()], [])
        print(f"- lX has shape: {self.X.shape}")
        if self.split != "whole":
            print(f"- lY has shape: {self.Y.shape}")

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        if self.split == "whole":
            return self.X[index], self.langs[index]
        return self.X[index], self.Y[index], self.langs[index]


if __name__ == "__main__":
    from os.path import expanduser

    from dataManager.multiNewsDataset import MultiNewsDataset

    _dataset_path_hardcoded = "~/datasets/MultiNews/20110730/"

    dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
    lX, lY = dataset.training()

    vg = VisualTransformerGen(model_name="vit")
    vg.fit(lX, lY)