From 875af6d362caf1ddd2e76304e7c22f2e6c69ef03 Mon Sep 17 00:00:00 2001 From: andreapdr Date: Thu, 5 Oct 2023 15:46:54 +0200 Subject: [PATCH] removed unused --- gfun/vgfs/visualTransformerGen.py | 189 ------------------------------ 1 file changed, 189 deletions(-) delete mode 100644 gfun/vgfs/visualTransformerGen.py diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py deleted file mode 100644 index 84514e4..0000000 --- a/gfun/vgfs/visualTransformerGen.py +++ /dev/null @@ -1,189 +0,0 @@ -from collections import defaultdict - -import numpy as np -import torch -import transformers -from PIL import Image -from transformers import AutoImageProcessor, AutoModelForImageClassification - -from gfun.vgfs.commons import Trainer -from gfun.vgfs.transformerGen import TransformerGen -from gfun.vgfs.viewGen import ViewGen -from dataManager.torchDataset import MultimodalDatasetTorch - -transformers.logging.set_verbosity_error() - - -class VisualTransformerGen(ViewGen, TransformerGen): - def __init__( - self, - model_name, - dataset_name, - lr=1e-5, - scheduler="ReduceLROnPlateau", - epochs=10, - batch_size=32, - batch_size_eval=128, - evaluate_step=10, - device="cpu", - probabilistic=False, - patience=5, - classification_type="multilabel", - ): - super().__init__( - model_name, - dataset_name, - epochs=epochs, - lr=lr, - scheduler=scheduler, - batch_size=batch_size, - batch_size_eval=batch_size_eval, - device=device, - evaluate_step=evaluate_step, - patience=patience, - probabilistic=probabilistic, - ) - self.clf_type = classification_type - self.fitted = False - print( - f"- init Visual TransformerModel model_name: {self.model_name}, device: {self.device}]" - ) - - def _validate_model_name(self, model_name): - if "vit" == model_name: - return "google/vit-base-patch16-224-in21k" - else: - raise NotImplementedError - - def init_model(self, model_name, num_labels): - model = AutoModelForImageClassification.from_pretrained( - model_name, num_labels=num_labels, output_hidden_states=True - ) - image_processor = AutoImageProcessor.from_pretrained(model_name) - return model, image_processor - - def process_all(self, X): - # TODO: should be moved as a collate_fn to avoid this overhead - processed = self.image_preprocessor( - [Image.open(img).convert("RGB") for img in X], return_tensors="pt" - ) - return processed["pixel_values"] - - def fit(self, lX, lY): - print("- fitting Visual Transformer View Generating Function") - _l = list(lX.keys())[0] - self.num_labels = lY[_l].shape[-1] - self.model, self.image_preprocessor = self.init_model( - self._validate_model_name(self.model_name), num_labels=self.num_labels - ) - - tr_lX, tr_lY, val_lX, val_lY = self.get_train_val_data( - lX, lY, split=0.2, seed=42, modality="image" - ) - - tra_dataloader = self.build_dataloader( - tr_lX, - tr_lY, - processor_fn=self.process_all, - torchDataset=MultimodalDatasetTorch, - batch_size=self.batch_size, - split="train", - shuffle=True, - ) - - val_dataloader = self.build_dataloader( - val_lX, - val_lY, - processor_fn=self.process_all, - torchDataset=MultimodalDatasetTorch, - batch_size=self.batch_size_eval, - split="val", - shuffle=False, - ) - - experiment_name = ( - f"{self.model_name}-{self.epochs}-{self.batch_size}-{self.dataset_name}" - ) - - trainer = Trainer( - model=self.model, - optimizer_name="adamW", - device=self.device, - loss_fn=torch.nn.CrossEntropyLoss(), - lr=self.lr, - print_steps=self.print_steps, - evaluate_step=self.evaluate_step, - patience=self.patience, - experiment_name=experiment_name, - checkpoint_path="models/vgfs/transformer", - vgf_name="visual_trf", - classification_type=self.clf_type, - n_jobs=self.n_jobs, - ) - - trainer.train( - train_dataloader=tra_dataloader, - eval_dataloader=val_dataloader, - epochs=self.epochs, - ) - - if self.probabilistic: - self.feature2posterior_projector.fit(self.transform(lX), lY) - - self.fitted = True - - return self - - def transform(self, lX): - # forcing to only image modality - lX = {lang: data["image"] for lang, data in lX.items()} - _embeds = [] - l_embeds = defaultdict(list) - - dataloader = self.build_dataloader( - lX, - lY=None, - processor_fn=self.process_all, - torchDataset=MultimodalDatasetTorch, - batch_size=self.batch_size_eval, - split="whole", - shuffle=False, - ) - - self.model.eval() - with torch.no_grad(): - for input_ids, lang in dataloader: - input_ids = input_ids.to(self.device) - out = self.model(input_ids).hidden_states[-1] - batch_embeddings = out[:, 0, :].cpu().numpy() - _embeds.append((batch_embeddings, lang)) - - for embed, lang in _embeds: - for sample_embed, sample_lang in zip(embed, lang): - l_embeds[sample_lang].append(sample_embed) - - if self.probabilistic and self.fitted: - l_embeds = self.feature2posterior_projector.transform(l_embeds) - elif not self.probabilistic and self.fitted: - l_embeds = {lang: np.array(preds) for lang, preds in l_embeds.items()} - - return l_embeds - - def fit_transform(self, lX, lY): - return self.fit(lX, lY).transform(lX) - - def save_vgf(self, model_id): - import pickle - from os import makedirs - from os.path import join - - vgf_name = "visualTransformerGen" - _basedir = join("models", "vgfs", "visual_transformer") - makedirs(_basedir, exist_ok=True) - _path = join(_basedir, f"{vgf_name}_{model_id}.pkl") - with open(_path, "wb") as f: - pickle.dump(self, f) - return self - - def get_config(self): - return {"name": "visual-transformer VGF", "visual_trf": super().get_config()}