From 294d7c3be72415c054e3ea899f6fe159ed83cfc7 Mon Sep 17 00:00:00 2001
From: andrea
Date: Tue, 19 Jan 2021 15:30:15 +0100
Subject: [PATCH] refactor

---
 refactor/data/datamodule.py | 12 ++----
 refactor/main.py            | 13 +++----
 refactor/models/helpers.py  | 14 ++++---
 refactor/models/pl_gru.py   | 73 ++++++------------------------------
 refactor/util/common.py     | 13 +++++--
 refactor/view_generators.py |  2 +-
 6 files changed, 42 insertions(+), 85 deletions(-)

diff --git a/refactor/data/datamodule.py b/refactor/data/datamodule.py
index bbb7cc1..67a83d6 100644
--- a/refactor/data/datamodule.py
+++ b/refactor/data/datamodule.py
@@ -103,7 +103,6 @@ class GfunDataModule(pl.LightningDataModule):
         pass
 
     def setup(self, stage=None):
-        # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -111,9 +110,8 @@ class GfunDataModule(pl.LightningDataModule):
             l_val_index, l_val_target = self.multilingualIndex.l_val()
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
-        # Assign test dataset for use in dataloader(s)
         if stage == 'test' or stage is None:
-            l_test_index, l_test_target = self.multilingualIndex.l_val()
+            l_test_index, l_test_target = self.multilingualIndex.l_test()
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
 
@@ -136,7 +134,6 @@ class BertDataModule(GfunDataModule):
         self.max_len = max_len
 
     def setup(self, stage=None):
-        # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
@@ -146,12 +143,11 @@ class BertDataModule(GfunDataModule):
             l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
-        # Assign test dataset for use in dataloader(s) # TODO
         if stage == 'test' or stage is None:
-            l_val_raw, l_val_target = self.multilingualIndex.l_test_raw()
-            l_val_index = self.tokenize(l_val_raw)
-            self.test_dataset = RecurrentDataset(l_val_index, l_val_target,
+            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
+            l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
+            self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
 
     @staticmethod
diff --git a/refactor/main.py b/refactor/main.py
index 76c5e54..42ef9c9 100644
--- a/refactor/main.py
+++ b/refactor/main.py
@@ -7,29 +7,28 @@ from util.common import MultilingualIndex
 
 def main(args):
     N_JOBS = 8
-    print('Running...')
+    print('Running refactored...')
 
     # _DATASET = '/homenfs/a.pedrotti1/datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     # EMBEDDINGS_PATH = '/homenfs/a.pedrotti1/embeddings/MUSE'
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
-    EMBEDDINGS_PATH = '/home/andreapdr/funneling_pdr/embeddings'
+    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
 
     data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it'])
+    data.set_view(languages=['it'], categories=[0,1])
     lX, ly = data.training()
     lXte, lyte = data.test()
 
-    # Init multilingualIndex - mandatory when deploying Neural View Generators...
+    # Init multilingualIndex - mandatory when deploying Neural View Generators...
     multilingualIndex = MultilingualIndex()
     # lMuse = MuseLoader(langs=sorted(lX.keys()), cache=)
     lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
-    multilingualIndex.index(lX, ly, lXte, l_pretrained_vocabulary=lMuse.vocabulary())
+    multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
 
     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, gpus=args.gpus, n_jobs=N_JOBS,
-                        stored_path='/home/andreapdr/gfun_refactor/tb_logs/gfun_rnn_dev/version_19/checkpoints/epoch=0-step=14.ckpt')
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
 
     gFun.fit(lX, ly)
diff --git a/refactor/models/helpers.py b/refactor/models/helpers.py
index 93e5805..b466f28 100755
--- a/refactor/models/helpers.py
+++ b/refactor/models/helpers.py
@@ -3,25 +3,29 @@ import torch.nn as nn
 from torch.nn import functional as F
 
-
-def init_embeddings(pretrained, vocab_size, learnable_length, device='cuda'):
+def init_embeddings(pretrained, vocab_size, learnable_length):
+    """
+    Compute the embedding matrix
+    :param pretrained: pre-trained embedding matrix (Tensor of shape [vocab_size, dim]) or None
+    :param vocab_size: size of the vocabulary
+    :param learnable_length: dimension of the additional trainable embeddings (0 to disable them)
+    :return: (pretrained_embeddings, learnable_embeddings, embedding_length)
+    """
     pretrained_embeddings = None
     pretrained_length = 0
     if pretrained is not None:
         pretrained_length = pretrained.shape[1]
         assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
         pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
+        # requires_grad=False sets the embedding layer as NOT trainable
         pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-        # pretrained_embeddings.to(device)
 
     learnable_embeddings = None
     if learnable_length > 0:
         learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-        # learnable_embeddings.to(device)
 
     embedding_length = learnable_length + pretrained_length
     assert embedding_length > 0, '0-size embeddings'
-
     return pretrained_embeddings, learnable_embeddings, embedding_length
diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py
index 7987156..268a694 100644
--- a/refactor/models/pl_gru.py
+++ b/refactor/models/pl_gru.py
@@ -1,43 +1,17 @@
+# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
 import torch
 from torch import nn
-from torch.optim import Adam
 from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
 from pytorch_lightning.metrics import F1, Accuracy, Metric
 from torch.optim.lr_scheduler import StepLR
-
-from util.evaluation import evaluate
 from typing import Any, Optional, Tuple
 from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
+from models.helpers import init_embeddings
 import numpy as np
-
-
-def init_embeddings(pretrained, vocab_size, learnable_length):
-    """
-    Compute the embedding matrix
-    :param pretrained:
-    :param vocab_size:
-    :param learnable_length:
-    :return:
-    """
-    pretrained_embeddings = None
-    pretrained_length = 0
-    if pretrained is not None:
-        pretrained_length = pretrained.shape[1]
-        assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
-        pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
-        # requires_grad=False sets the embedding layer as NOT trainable
-        pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-
-    learnable_embeddings = None
-    if learnable_length > 0:
-        learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-
-    embedding_length = learnable_length + pretrained_length
-    assert embedding_length > 0, '0-size embeddings'
-    return pretrained_embeddings, learnable_embeddings, embedding_length
+from util.evaluation import evaluate
 
 
 class RecurrentModel(pl.LightningModule):
@@ -97,7 +71,7 @@ class RecurrentModel(pl.LightningModule):
         self.label = nn.Linear(ff2, self.output_size)
         lPretrained = None  # TODO: setting lPretrained to None, letting it to its original value will bug first
-        # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
+        # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
         self.save_hyperparameters()
 
     def forward(self, lX):
@@ -124,7 +98,6 @@ class RecurrentModel(pl.LightningModule):
         return output
 
     def training_step(self, train_batch, batch_idx):
-        # TODO: double check StepLR scheduler...
         lX, ly = train_batch
         logits = self.forward(lX)
         _ly = []
@@ -132,20 +105,14 @@ class RecurrentModel(pl.LightningModule):
             _ly.append(ly[lang])
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)
-        # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
-
-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
 
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: ly.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        custom = self.customMetrics(predictions, ly)
+        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return loss
+        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        return {'loss': loss}
 
     def validation_step(self, val_batch, batch_idx):
         lX, ly = val_batch
@@ -156,17 +123,10 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)
         predictions = torch.sigmoid(logits) > 0.5
-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
-
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: y.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return
+        return {'loss': loss}
 
     def test_step(self, test_batch, batch_idx):
         lX, ly = test_batch
@@ -177,18 +137,9 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        custom_metric = self.customMetrics(logits, ly)  # TODO
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        self.log('test-custom', custom_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        return {'pred': predictions, 'target': ly}
-
-    def test_epoch_end(self, outputs):
-        # all_pred = torch.vstack([out['pred'] for out in outputs])  # TODO
-        # all_y = torch.vstack([out['target'] for out in outputs])  # TODO
-        # r = eval(all_y, all_pred)
-        # print(r)
-        # X = torch.cat(X).view([X[0].shape[0], len(X)])
         return
+        # return {'pred': predictions, 'target': ly}
 
     def embed(self, X, lang):
         input_list = []
@@ -308,5 +259,5 @@ def _fbeta_compute(
     new_den = 2 * true_positives + new_fp + new_fn
     if new_den.sum() == 0:
         # whats is the correct return type ? TODO
-        return 1.
+        return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
     return class_reduce(num, denom, weights=actual_positives, class_reduction=average)
diff --git a/refactor/util/common.py b/refactor/util/common.py
index 7792b1c..4bd0c20 100644
--- a/refactor/util/common.py
+++ b/refactor/util/common.py
@@ -52,7 +52,7 @@ class MultilingualIndex:
         self.l_index = {}
         self.l_vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
 
-    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_pretrained_vocabulary=None):
+    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_test_target, l_pretrained_vocabulary=None):
         self.langs = sorted(l_devel_raw.keys())
         self.l_vectorizer.fit(l_devel_raw)
         l_vocabulary = self.l_vectorizer.vocabulary()
@@ -62,7 +62,7 @@ class MultilingualIndex:
 
         for lang in self.langs:
             # Init monolingual Index
-            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], lang)
+            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], l_test_target[lang], lang)
             # call to index() function of monolingual Index
             self.l_index[lang].index(l_pretrained_vocabulary[lang], l_analyzer[lang], l_vocabulary[lang])
 
@@ -163,6 +163,9 @@ class MultilingualIndex:
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}
 
+    def l_test_target(self):
+        return {l: index.test_target for l, index in self.l_index.items()}
+
     def l_test_index(self):
         return {l: index.test_index for l, index in self.l_index.items()}
 
@@ -182,6 +185,9 @@ class MultilingualIndex:
     def l_val(self):
         return self.l_val_index(), self.l_val_target()
 
+    def l_test(self):
+        return self.l_test_index(), self.l_test_target()
+
     def l_train_raw(self):
         return self.l_train_raw_index(), self.l_train_target()
 
@@ -193,7 +199,7 @@ class MultilingualIndex:
 
 
 class Index:
-    def __init__(self, devel_raw, devel_target, test_raw, lang):
+    def __init__(self, devel_raw, devel_target, test_raw, test_target, lang):
         """
         Monolingual Index, takes care of tokenizing raw data, converting strings to ids, splitting the data into
         training and validation.
@@ -206,6 +212,7 @@ class Index:
         self.devel_raw = devel_raw
         self.devel_target = devel_target
         self.test_raw = test_raw
+        self.test_target = test_target
 
     def index(self, pretrained_vocabulary, analyzer, vocabulary):
         self.word2index = dict(vocabulary)
diff --git a/refactor/view_generators.py b/refactor/view_generators.py
index 0ea3323..abe2442 100644
--- a/refactor/view_generators.py
+++ b/refactor/view_generators.py
@@ -205,7 +205,7 @@ class RecurrentGen(ViewGen):
         :return:
         """
         recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
-        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50)
+        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)
 
         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
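
A minimal usage sketch of the new test-set plumbing follows. It is not part of the patch: it mirrors refactor/main.py and assumes the same surrounding imports (MultilingualDataset, MuseLoader) and constants (_DATASET, EMBEDDINGS_PATH), which are defined outside the hunks shown above; the datamodule import path is likewise an assumption based on the refactor/ package layout.

# Usage sketch (not part of the patch; imports/constants assumed as in refactor/main.py)
from data.datamodule import GfunDataModule
from util.common import MultilingualIndex

data = MultilingualDataset.load(_DATASET)
lX, ly = data.training()
lXte, lyte = data.test()

lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
multilingualIndex = MultilingualIndex()
# index() now also receives the test labels (lyte) and stores them in each monolingual Index
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())

# GfunDataModule.setup('test') can therefore build the test split from l_test()
datamodule = GfunDataModule(multilingualIndex, batchsize=512)
datamodule.setup(stage='test')
l_test_index, l_test_target = multilingualIndex.l_test()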