From 9c2c43dafba37b8f9e8e177b4c12fb539760c20b Mon Sep 17 00:00:00 2001
From: andreapdr
Date: Thu, 9 Feb 2023 18:42:27 +0100
Subject: [PATCH] Visual VGF + MultiNewsDataset, working from data loading to
 testing

---
 dataManager/multiNewsDataset.py    |  34 +++--
 gfun/generalizedFunnelling.py      |   1 -
 gfun/vgfs/textualTransformerGen.py | 223 +++++------------------------
 gfun/vgfs/transformerGen.py        |  34 ++++-
 gfun/vgfs/vanillaFun.py            |   1 -
 gfun/vgfs/visualTransformerGen.py  |  45 +++---
 main.py                            |  46 ++++--
 7 files changed, 156 insertions(+), 228 deletions(-)

diff --git a/dataManager/multiNewsDataset.py b/dataManager/multiNewsDataset.py
index 749403a..9693937 100644
--- a/dataManager/multiNewsDataset.py
+++ b/dataManager/multiNewsDataset.py
@@ -27,11 +27,11 @@ class MultiNewsDataset:
     def __init__(self, data_dir, excluded_langs=[], debug=False):
         self.debug = debug
         self.data_dir = data_dir
-        self.langs = self.get_langs()
+        self.dataset_langs = self.get_langs()
         self.excluded_langs = excluded_langs
         self.lang_multiModalDataset = {}
         print(
-            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {self.langs}]"
+            f"[{'DEBUG MODE: ' if debug else ''}Loaded MultiNewsDataset - langs: {[l for l in self.dataset_langs if l not in self.excluded_langs]}]"
         )
         self.load_data()
         self.all_labels = self.get_labels()
@@ -39,12 +39,16 @@ class MultiNewsDataset:
         self.print_stats()
 
     def load_data(self):
-        for lang in self.langs:
+        for lang in self.dataset_langs:
             if lang not in self.excluded_langs:
                 self.lang_multiModalDataset[lang] = MultiModalDataset(
                     lang, join(self.data_dir, lang)
                 )
 
+    def langs(self):
+        return [l for l in self.dataset_langs if l not in self.excluded_langs]
+        return self.get_langs()
+
     def get_langs(self):
         from os import listdir
 
@@ -56,13 +60,14 @@ class MultiNewsDataset:
     def print_stats(self):
         print(f"[MultiNewsDataset stats]")
         total_docs = 0
-        for lang in self.langs:
-            _len = len(self.lang_multiModalDataset[lang].data)
-            total_docs += _len
-            print(
-                f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
-            )
-        print(f" - total docs: {total_docs}")
+        for lang in self.dataset_langs:
+            if lang not in self.excluded_langs:
+                _len = len(self.lang_multiModalDataset[lang].data)
+                total_docs += _len
+                print(
+                    f" - {lang} docs: {_len}\t- labels: {self._count_lang_labels(self.lang_multiModalDataset[lang].labels)}"
+                )
+        print(f" - total docs: {total_docs}\n")
 
     def _count_lang_labels(self, labels):
         lang_labels = set()
@@ -77,11 +82,16 @@ class MultiNewsDataset:
         raise NotImplementedError
 
     def training(self):
+        # TODO: this is a (working) mess - clean this up
        lXtr = {}
        lYtr = {}
         for lang, data in self.lang_multiModalDataset.items():
-            lXtr[lang] = data.data
-            lYtr[lang] = self.label_binarizer.transform(data.labels)
+            _data = [clean_text for _, clean_text, _, _ in data.data]
+            lXtr[lang] = _data
+            lYtr = {
+                lang: self.label_binarizer.transform(data.labels)
+                for lang, data in self.lang_multiModalDataset.items()
+            }
         return lXtr, lYtr

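For context, a minimal usage sketch of the refactored MultiNewsDataset as it ends up being used by this patch (the module path, data directory and excluded-language list are illustrative, copied from the hard-coded values that appear later in main.py; note that langs() is now a method replacing the former self.langs attribute, so it must be called, and its second return statement is unreachable):

    from os.path import expanduser

    from dataManager.multiNewsDataset import MultiNewsDataset

    dataset = MultiNewsDataset(
        expanduser("~/datasets/MultiNews/20110730/"),
        excluded_langs=["ar", "pe", "pl", "tr", "ua"],
    )

    lXtr, lYtr = dataset.training()
    for lang in dataset.langs():        # only the non-excluded languages
        print(lang, len(lXtr[lang]))    # lXtr[lang]: list of cleaned article texts
        print(lYtr[lang].shape)         # lYtr[lang]: binarized label matrix (docs x labels)
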
diff --git a/gfun/generalizedFunnelling.py b/gfun/generalizedFunnelling.py
index 0d713e3..091e9d4 100644
--- a/gfun/generalizedFunnelling.py
+++ b/gfun/generalizedFunnelling.py
@@ -78,7 +78,6 @@ class GeneralizedFunnelling:
         if self.posteriors_vgf:
             fun = VanillaFunGen(
                 base_learner=get_learner(calibrate=True),
-                first_tier_parameters=None,
                 n_jobs=self.n_jobs,
             )
             self.first_tier_learners.append(fun)
diff --git a/gfun/vgfs/textualTransformerGen.py b/gfun/vgfs/textualTransformerGen.py
index c705160..7648fb9 100644
--- a/gfun/vgfs/textualTransformerGen.py
+++ b/gfun/vgfs/textualTransformerGen.py
@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader, Dataset
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from vgfs.learners.svms import FeatureSet2Posteriors
 from vgfs.viewGen import ViewGen
-
-from evaluation.evaluate import evaluate, log_eval
+from vgfs.transformerGen import TransformerGen
+from vgfs.commons import Trainer, predict
 
 transformers.logging.set_verbosity_error()
 
@@ -23,7 +23,7 @@ transformers.logging.set_verbosity_error()
 # TODO: multiple inheritance - maybe define a superclass for TransformerGenerator, whether it is a Textual or a Visual one, implementing dataset creation functions
 
 
-class TextualTransformerGen(ViewGen):
+class TextualTransformerGen(ViewGen, TransformerGen):
     def __init__(
         self,
         model_name,
@@ -40,23 +40,22 @@ class TextualTransformerGen(ViewGen):
         verbose=False,
         patience=5,
     ):
-        self.model_name = model_name
-        self.device = device
-        self.model = None
-        self.lr = lr
-        self.epochs = epochs
-        self.tokenizer = None
-        self.max_length = max_length
-        self.batch_size = batch_size
-        self.batch_size_eval = batch_size_eval
-        self.print_steps = print_steps
-        self.probabilistic = probabilistic
-        self.n_jobs = n_jobs
+        super().__init__(
+            model_name,
+            epochs,
+            lr,
+            batch_size,
+            batch_size_eval,
+            max_length,
+            print_steps,
+            device,
+            probabilistic,
+            n_jobs,
+            evaluate_step,
+            verbose,
+            patience,
+        )
         self.fitted = False
-        self.datasets = {}
-        self.evaluate_step = evaluate_step
-        self.verbose = verbose
-        self.patience = patience
         self._init()
 
     def _init(self):
@@ -93,25 +92,6 @@ class TextualTransformerGen(ViewGen):
             model_name
         )
 
-    def get_train_val_data(self, lX, lY, split=0.2, seed=42):
-        tr_lX, tr_lY, val_lX, val_lY = {}, {}, {}, {}
-
-        for lang in lX.keys():
-            tr_X, val_X, tr_Y, val_Y = train_test_split(
-                lX[lang], lY[lang], test_size=split, random_state=seed, shuffle=False
-            )
-            tr_lX[lang] = tr_X
-            tr_lY[lang] = tr_Y
-            val_lX[lang] = val_X
-            val_lY[lang] = val_Y
-
-        return tr_lX, tr_lY, val_lX, val_lY
-
-    def build_dataloader(self, lX, lY, batch_size, split="train", shuffle=False):
-        l_tokenized = {lang: self._tokenize(data) for lang, data in lX.items()}
-        self.datasets[split] = MultilingualDatasetTorch(l_tokenized, lY, split=split)
-        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
-
     def _tokenize(self, X):
         return self.tokenizer(
             X,
@@ -136,11 +116,23 @@ class TextualTransformerGen(ViewGen):
         )
 
         tra_dataloader = self.build_dataloader(
-            tr_lX, tr_lY, self.batch_size, split="train", shuffle=True
+            tr_lX,
+            tr_lY,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size,
+            split="train",
+            shuffle=True,
         )
 
         val_dataloader = self.build_dataloader(
-            val_lX, val_lY, self.batch_size_eval, split="val", shuffle=False
+            val_lX,
+            val_lY,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size_eval,
+            split="val",
+            shuffle=False,
         )
 
         experiment_name = f"{self.model_name}-{self.epochs}-{self.batch_size}"
@@ -173,7 +165,13 @@ class TextualTransformerGen(ViewGen):
         l_embeds = defaultdict(list)
 
         dataloader = self.build_dataloader(
-            lX, lY=None, batch_size=self.batch_size_eval, split="whole", shuffle=False
+            lX,
+            lY=None,
+            processor_fn=self._tokenize,
+            torchDataset=MultilingualDatasetTorch,
+            batch_size=self.batch_size_eval,
+            split="whole",
+            shuffle=False,
         )
 
         self.model.eval()
self.split == "whole": return self.X[index], self.langs[index] return self.X[index], self.Y[index], self.langs[index] - - -class Trainer: - def __init__( - self, - model, - optimizer_name, - device, - loss_fn, - lr, - print_steps, - evaluate_step, - patience, - experiment_name, - ): - self.device = device - self.model = model.to(device) - self.optimizer = self.init_optimizer(optimizer_name, lr) - self.evaluate_steps = evaluate_step - self.loss_fn = loss_fn.to(device) - self.print_steps = print_steps - self.earlystopping = EarlyStopping( - patience=patience, - checkpoint_path="models/vgfs/transformers/", - verbose=True, - experiment_name=experiment_name, - ) - - def init_optimizer(self, optimizer_name, lr): - if optimizer_name.lower() == "adamw": - return AdamW(self.model.parameters(), lr=lr) - else: - raise ValueError(f"Optimizer {optimizer_name} not supported") - - def train(self, train_dataloader, eval_dataloader, epochs=10): - print( - f"""- Training params: - - epochs: {epochs} - - learning rate: {self.optimizer.defaults['lr']} - - train batch size: {train_dataloader.batch_size} - - eval batch size: {eval_dataloader.batch_size} - - max len: {train_dataloader.dataset.X.shape[-1]}\n""", - ) - for epoch in range(epochs): - self.train_epoch(train_dataloader, epoch) - if (epoch + 1) % self.evaluate_steps == 0: - metric_watcher = self.evaluate(eval_dataloader) - stop = self.earlystopping(metric_watcher, self.model, epoch + 1) - if stop: - break - return self.model - - def train_epoch(self, dataloader, epoch): - self.model.train() - for b_idx, (x, y, lang) in enumerate(dataloader): - self.optimizer.zero_grad() - y_hat = self.model(x.to(self.device)) - loss = self.loss_fn(y_hat.logits, y.to(self.device)) - loss.backward() - self.optimizer.step() - if b_idx % self.print_steps == 0: - print(f"Epoch: {epoch+1} Step: {b_idx+1} Loss: {loss:.4f}") - return self - - def evaluate(self, dataloader): - self.model.eval() - - lY = defaultdict(list) - lY_hat = defaultdict(list) - - for b_idx, (x, y, lang) in enumerate(dataloader): - y_hat = self.model(x.to(self.device)) - loss = self.loss_fn(y_hat.logits, y.to(self.device)) - predictions = predict(y_hat.logits, classification_type="multilabel") - - for l, _true, _pred in zip(lang, y, predictions): - lY[l].append(_true.detach().cpu().numpy()) - lY_hat[l].append(_pred) - - for lang in lY: - lY[lang] = np.vstack(lY[lang]) - lY_hat[lang] = np.vstack(lY_hat[lang]) - - l_eval = evaluate(lY, lY_hat) - average_metrics = log_eval(l_eval, phase="validation") - return average_metrics[0] # macro-F1 - - -class EarlyStopping: - def __init__( - self, - patience=5, - min_delta=0, - verbose=True, - checkpoint_path="checkpoint.pt", - experiment_name="experiment", - ): - self.patience = patience - self.min_delta = min_delta - self.counter = 0 - self.best_score = 0 - self.best_epoch = None - self.verbose = verbose - self.checkpoint_path = checkpoint_path - self.experiment_name = experiment_name - - def __call__(self, validation, model, epoch): - if validation > self.best_score: - print( - f"- earlystopping: Validation score improved from {self.best_score:.3f} to {validation:.3f}" - ) - self.best_score = validation - self.counter = 0 - # self.save_model(model) - elif validation < (self.best_score + self.min_delta): - self.counter += 1 - print( - f"- earlystopping: Validation score decreased from {self.best_score:.3f} to {validation:.3f}, current patience: {self.patience - self.counter}" - ) - if self.counter >= self.patience: - if self.verbose: - print(f"- earlystopping: Early 
diff --git a/gfun/vgfs/transformerGen.py b/gfun/vgfs/transformerGen.py
index 3dc3814..9c56451 100644
--- a/gfun/vgfs/transformerGen.py
+++ b/gfun/vgfs/transformerGen.py
@@ -9,7 +9,39 @@ class TransformerGen:
     form of dictioanries {lang: data}
     """
 
-    def __init__(self):
+    def __init__(
+        self,
+        model_name,
+        epochs=10,
+        lr=1e-5,
+        batch_size=4,
+        batch_size_eval=32,
+        max_length=512,
+        print_steps=50,
+        device="cpu",
+        probabilistic=False,
+        n_jobs=-1,
+        evaluate_step=10,
+        verbose=False,
+        patience=5,
+    ):
+        self.model_name = model_name
+        self.device = device
+        self.model = None
+        self.lr = lr
+        self.epochs = epochs
+        self.tokenizer = None
+        self.max_length = max_length
+        self.batch_size = batch_size
+        self.batch_size_eval = batch_size_eval
+        self.print_steps = print_steps
+        self.probabilistic = probabilistic
+        self.n_jobs = n_jobs
+        self.fitted = False
+        self.datasets = {}
+        self.evaluate_step = evaluate_step
+        self.verbose = verbose
+        self.patience = patience
         self.datasets = {}
 
     def build_dataloader(
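Both generators now delegate their hyper-parameter bookkeeping to this shared base class and call its generic build_dataloader() with a view-specific preprocessing function and torch Dataset class. The body of build_dataloader() lies outside this hunk, but from the new call sites in textualTransformerGen.py and the per-view helper removed there, it presumably generalizes the old per-language logic along these lines (a method sketch under those assumptions, not the actual implementation):

    from torch.utils.data import DataLoader


    def build_dataloader(
        self, lX, lY, processor_fn, torchDataset, batch_size, split="train", shuffle=False
    ):
        # apply the view-specific preprocessing (tokenization, image transforms, ...) per language
        l_processed = {lang: processor_fn(data) for lang, data in lX.items()}
        # wrap the processed {lang: data} dictionaries in the view-specific torch Dataset
        self.datasets[split] = torchDataset(l_processed, lY, split=split)
        return DataLoader(self.datasets[split], batch_size=batch_size, shuffle=shuffle)
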
diff --git a/gfun/vgfs/vanillaFun.py b/gfun/vgfs/vanillaFun.py
index a4842da..3416766 100644
--- a/gfun/vgfs/vanillaFun.py
+++ b/gfun/vgfs/vanillaFun.py
@@ -22,7 +22,6 @@ class VanillaFunGen(ViewGen):
         self.n_jobs = n_jobs
         self.doc_projector = NaivePolylingualClassifier(
             base_learner=self.learners,
-            parameters=self.first_tier_parameters,
             n_jobs=self.n_jobs,
         )
         self.vectorizer = None
diff --git a/gfun/vgfs/visualTransformerGen.py b/gfun/vgfs/visualTransformerGen.py
index aaa7651..80f3682 100644
--- a/gfun/vgfs/visualTransformerGen.py
+++ b/gfun/vgfs/visualTransformerGen.py
@@ -10,21 +10,33 @@ from torch.utils.data import DataLoader, Dataset
 from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
 from gfun.vgfs.commons import Trainer, predict
 from gfun.vgfs.transformerGen import TransformerGen
-from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
+from transformers import AutoModelForImageClassification
 
 transformers.logging.set_verbosity_error()
 
 
 class VisualTransformerGen(ViewGen, TransformerGen):
     def __init__(
-        self, model_name, lr=1e-5, epochs=10, batch_size=32, batch_size_eval=128
+        self,
+        model_name,
+        lr=1e-5,
+        epochs=10,
+        batch_size=32,
+        batch_size_eval=128,
+        evaluate_step=10,
+        device="cpu",
+        patience=5,
     ):
-        self.model_name = model_name
-        self.datasets = {}
-        self.lr = lr
-        self.epochs = epochs
-        self.batch_size = batch_size
-        self.batch_size_eval = batch_size_eval
+        super().__init__(
+            model_name,
+            lr=lr,
+            epochs=epochs,
+            batch_size=batch_size,
+            batch_size_eval=batch_size_eval,
+            device=device,
+            evaluate_step=evaluate_step,
+            patience=patience,
+        )
 
     def _validate_model_name(self, model_name):
         if "vit" == model_name:
@@ -33,10 +45,8 @@ class VisualTransformerGen(ViewGen, TransformerGen):
             raise NotImplementedError
 
     def init_model(self, model_name, num_labels):
-        model = (
-            AutoModelForImageClassification.from_pretrained(
-                model_name, num_labels=num_labels
-            ),
+        model = AutoModelForImageClassification.from_pretrained(
+            model_name, num_labels=num_labels
         )
         image_processor = AutoImageProcessor.from_pretrained(model_name)
         transforms = self.init_preprocessor(image_processor)
@@ -100,9 +110,9 @@ class VisualTransformerGen(ViewGen, TransformerGen):
         trainer = Trainer(
             model=self.model,
             optimizer_name="adamW",
-            lr=self.lr,
             device=self.device,
             loss_fn=torch.nn.CrossEntropyLoss(),
+            lr=self.lr,
             print_steps=self.print_steps,
             evaluate_step=self.evaluate_step,
             patience=self.patience,
@@ -111,7 +121,7 @@ class VisualTransformerGen(ViewGen, TransformerGen):
 
         trainer.train(
             train_dataloader=tra_dataloader,
-            val_dataloader=val_dataloader,
+            eval_dataloader=val_dataloader,
             epochs=self.epochs,
         )
 
@@ -149,7 +159,7 @@ class MultimodalDatasetTorch(Dataset):
             ],
             [],
         )
-        print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
+        # print(f"- lX has shape: {self.X.shape}\n- lY has shape: {self.Y.shape}")
 
     def __len__(self):
         return len(self.X)
@@ -169,7 +179,8 @@ if __name__ == "__main__":
     dataset = MultiNewsDataset(expanduser(_dataset_path_hardcoded), debug=True)
     lXtr, lYtr = dataset.training()
 
-    vg = VisualTransformerGen(model_name="vit")
+    vg = VisualTransformerGen(
+        model_name="vit", device="cuda", epochs=1000, evaluate_step=10, patience=100
+    )
     lX, lY = dataset.training()
     vg.fit(lX, lY)
-    print("lel")
diff --git a/main.py b/main.py
index 3890676..783e58a 100644
--- a/main.py
+++ b/main.py
@@ -16,26 +16,46 @@ TODO:
 """
 
 
-def main(args):
-    # Loading dataset ------------------------
+def get_dataset(datasetname):
+    assert datasetname in ["multinews", "amazon", "rcv1-2"], "dataset not supported"
     RCV_DATAPATH = expanduser(
         "~/datasets/rcv1-2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle"
     )
-    # dataset = MultiNewsDataset(expanduser(args.dataset_path))
-    # dataset = AmazonDataset(domains=args.domains,nrows=args.nrows,min_count=args.min_count,max_labels=args.max_labels)
-    dataset = (
-        MultilingualDataset(dataset_name="rcv1-2")
-        .load(RCV_DATAPATH)
-        .reduce_data(langs=["en", "it", "fr"], maxn=100)
-    )
+    MULTINEWS_DATAPATH = expanduser("~/datasets/MultiNews/20110730/")
+    if datasetname == "multinews":
+        dataset = MultiNewsDataset(
+            expanduser(MULTINEWS_DATAPATH),
+            excluded_langs=["ar", "pe", "pl", "tr", "ua"],
+        )
+    elif datasetname == "amazon":
+        dataset = AmazonDataset(
+            domains=args.domains,
+            nrows=args.nrows,
+            min_count=args.min_count,
+            max_labels=args.max_labels,
+        )
+    elif datasetname == "rcv1-2":
+        dataset = (
+            MultilingualDataset(dataset_name="rcv1-2")
+            .load(RCV_DATAPATH)
+            .reduce_data(langs=["en", "it", "fr"], maxn=100)
+        )
+    else:
+        raise NotImplementedError
+    return dataset
 
-    if isinstance(dataset, MultilingualDataset):
+
+def main(args):
+    dataset = get_dataset(args.dataset)
+    if isinstance(dataset, MultilingualDataset) or isinstance(
+        dataset, MultiNewsDataset
+    ):
         lX, lY = dataset.training()
-        lX_te, lY_te = dataset.test()
+        # lX_te, lY_te = dataset.test()
+        lX_te, lY_te = dataset.training()
     else:
         _lX = dataset.dX
         _lY = dataset.dY
-    # ----------------------------------------
 
     tinit = time()
 
@@ -74,6 +94,7 @@ def main(args):
     gfun.fit(lX, lY)
 
     if args.load_trained is None:
+        print("[NB: FORCE-SKIPPING MODEL SAVE]")
         gfun.save()
 
     # if not args.load_model:
@@ -98,6 +119,7 @@ if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("-l", "--load_trained", type=str, default=None)
     # Dataset parameters -------------------
+    parser.add_argument("-d", "--dataset", type=str, default="multinews")
     parser.add_argument("--domains", type=str, default="all")
     parser.add_argument("--nrows", type=int, default=10000)
     parser.add_argument("--min_count", type=int, default=10)
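With the new -d/--dataset flag the same entry point can be pointed at MultiNews, Amazon, or RCV1-2. A short usage sketch of the added selection logic (it assumes the repository root is on PYTHONPATH and that the MultiNews corpus sits at the hard-coded path inside get_dataset; the Amazon branch additionally needs the domain-related command-line arguments):

    from main import get_dataset

    # equivalent to: python main.py -d multinews
    dataset = get_dataset("multinews")   # MultiNewsDataset with the hard-coded exclusions
    lX, lY = dataset.training()
    lX_te, lY_te = dataset.training()    # temporary: MultiNews exposes no test split yet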