This commit is contained in:
andrea 2021-01-19 15:30:15 +01:00
parent 34676167e8
commit 294d7c3be7
6 changed files with 42 additions and 85 deletions

View File

@ -103,7 +103,6 @@ class GfunDataModule(pl.LightningDataModule):
pass pass
def setup(self, stage=None): def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
l_train_index, l_train_target = self.multilingualIndex.l_train() l_train_index, l_train_target = self.multilingualIndex.l_train()
self.training_dataset = RecurrentDataset(l_train_index, l_train_target, self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@ -111,9 +110,8 @@ class GfunDataModule(pl.LightningDataModule):
l_val_index, l_val_target = self.multilingualIndex.l_val() l_val_index, l_val_target = self.multilingualIndex.l_val()
self.val_dataset = RecurrentDataset(l_val_index, l_val_target, self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
l_test_index, l_test_target = self.multilingualIndex.l_val() l_test_index, l_test_target = self.multilingualIndex.l_test()
self.test_dataset = RecurrentDataset(l_test_index, l_test_target, self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
@ -136,7 +134,6 @@ class BertDataModule(GfunDataModule):
self.max_len = max_len self.max_len = max_len
def setup(self, stage=None): def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None: if stage == 'fit' or stage is None:
l_train_raw, l_train_target = self.multilingualIndex.l_train_raw() l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
l_train_index = self.tokenize(l_train_raw, max_len=self.max_len) l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
@ -146,12 +143,11 @@ class BertDataModule(GfunDataModule):
l_val_index = self.tokenize(l_val_raw, max_len=self.max_len) l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
self.val_dataset = RecurrentDataset(l_val_index, l_val_target, self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
# Assign test dataset for use in dataloader(s)
# TODO # TODO
if stage == 'test' or stage is None: if stage == 'test' or stage is None:
l_val_raw, l_val_target = self.multilingualIndex.l_test_raw() l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
l_val_index = self.tokenize(l_val_raw) l_test_index = self.tokenize(l_val_raw, max_len=self.max_len)
self.test_dataset = RecurrentDataset(l_val_index, l_val_target, self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
lPad_index=self.multilingualIndex.l_pad()) lPad_index=self.multilingualIndex.l_pad())
@staticmethod @staticmethod

View File

@ -7,29 +7,28 @@ from util.common import MultilingualIndex
def main(args): def main(args):
N_JOBS = 8 N_JOBS = 8
print('Running...') print('Running refactored...')
# _DATASET = '/homenfs/a.pedrotti1/datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' # _DATASET = '/homenfs/a.pedrotti1/datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
# EMBEDDINGS_PATH = '/homenfs/a.pedrotti1/embeddings/MUSE' # EMBEDDINGS_PATH = '/homenfs/a.pedrotti1/embeddings/MUSE'
_DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle' _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
EMBEDDINGS_PATH = '/home/andreapdr/funneling_pdr/embeddings' EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
data = MultilingualDataset.load(_DATASET) data = MultilingualDataset.load(_DATASET)
# data.set_view(languages=['it']) data.set_view(languages=['it'], categories=[0,1])
lX, ly = data.training() lX, ly = data.training()
lXte, lyte = data.test() lXte, lyte = data.test()
# Init multilingualIndex - mandatory when deploying Neural View Generators... # Init multilingualIndex - mandatory when deploying Neural View Generators...
multilingualIndex = MultilingualIndex() multilingualIndex = MultilingualIndex()
# lMuse = MuseLoader(langs=sorted(lX.keys()), cache=) # lMuse = MuseLoader(langs=sorted(lX.keys()), cache=)
lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH) lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
multilingualIndex.index(lX, ly, lXte, l_pretrained_vocabulary=lMuse.vocabulary()) multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
# gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS) # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
# gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS) # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
# gFun = WordClassGen(n_jobs=N_JOBS) # gFun = WordClassGen(n_jobs=N_JOBS)
gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, gpus=args.gpus, n_jobs=N_JOBS, gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
stored_path='/home/andreapdr/gfun_refactor/tb_logs/gfun_rnn_dev/version_19/checkpoints/epoch=0-step=14.ckpt')
# gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS) # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)
gFun.fit(lX, ly) gFun.fit(lX, ly)

View File

@ -3,25 +3,29 @@ import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
def init_embeddings(pretrained, vocab_size, learnable_length):
def init_embeddings(pretrained, vocab_size, learnable_length, device='cuda'): """
Compute the embedding matrix
:param pretrained:
:param vocab_size:
:param learnable_length:
:return:
"""
pretrained_embeddings = None pretrained_embeddings = None
pretrained_length = 0 pretrained_length = 0
if pretrained is not None: if pretrained is not None:
pretrained_length = pretrained.shape[1] pretrained_length = pretrained.shape[1]
assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size' assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length) pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
# requires_grad=False sets the embedding layer as NOT trainable
pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False) pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
# pretrained_embeddings.to(device)
learnable_embeddings = None learnable_embeddings = None
if learnable_length > 0: if learnable_length > 0:
learnable_embeddings = nn.Embedding(vocab_size, learnable_length) learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
# learnable_embeddings.to(device)
embedding_length = learnable_length + pretrained_length embedding_length = learnable_length + pretrained_length
assert embedding_length > 0, '0-size embeddings' assert embedding_length > 0, '0-size embeddings'
return pretrained_embeddings, learnable_embeddings, embedding_length return pretrained_embeddings, learnable_embeddings, embedding_length

View File

@ -1,43 +1,17 @@
# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
import torch import torch
from torch import nn from torch import nn
from torch.optim import Adam
from transformers import AdamW from transformers import AdamW
import torch.nn.functional as F import torch.nn.functional as F
from torch.autograd import Variable from torch.autograd import Variable
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.metrics import F1, Accuracy, Metric from pytorch_lightning.metrics import F1, Accuracy, Metric
from torch.optim.lr_scheduler import StepLR from torch.optim.lr_scheduler import StepLR
from util.evaluation import evaluate
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
from models.helpers import init_embeddings
import numpy as np import numpy as np
from util.evaluation import evaluate
def init_embeddings(pretrained, vocab_size, learnable_length):
"""
Compute the embedding matrix
:param pretrained:
:param vocab_size:
:param learnable_length:
:return:
"""
pretrained_embeddings = None
pretrained_length = 0
if pretrained is not None:
pretrained_length = pretrained.shape[1]
assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
# requires_grad=False sets the embedding layer as NOT trainable
pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
learnable_embeddings = None
if learnable_length > 0:
learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
embedding_length = learnable_length + pretrained_length
assert embedding_length > 0, '0-size embeddings'
return pretrained_embeddings, learnable_embeddings, embedding_length
class RecurrentModel(pl.LightningModule): class RecurrentModel(pl.LightningModule):
@ -97,7 +71,7 @@ class RecurrentModel(pl.LightningModule):
self.label = nn.Linear(ff2, self.output_size) self.label = nn.Linear(ff2, self.output_size)
lPretrained = None # TODO: setting lPretrained to None, letting it to its original value will bug first lPretrained = None # TODO: setting lPretrained to None, letting it to its original value will bug first
# validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow) # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
self.save_hyperparameters() self.save_hyperparameters()
def forward(self, lX): def forward(self, lX):
@ -124,7 +98,6 @@ class RecurrentModel(pl.LightningModule):
return output return output
def training_step(self, train_batch, batch_idx): def training_step(self, train_batch, batch_idx):
# TODO: double check StepLR scheduler...
lX, ly = train_batch lX, ly = train_batch
logits = self.forward(lX) logits = self.forward(lX)
_ly = [] _ly = []
@ -132,20 +105,14 @@ class RecurrentModel(pl.LightningModule):
_ly.append(ly[lang]) _ly.append(ly[lang])
ly = torch.cat(_ly, dim=0) ly = torch.cat(_ly, dim=0)
loss = self.loss(logits, ly) loss = self.loss(logits, ly)
# Squashing logits through Sigmoid in order to get confidence score # Squashing logits through Sigmoid in order to get confidence score
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
# microf1 = self.microf1(predictions, ly)
# macrof1 = self.macrof1(predictions, ly)
accuracy = self.accuracy(predictions, ly) accuracy = self.accuracy(predictions, ly)
# l_pred = {lang: predictions.detach().cpu().numpy()} custom = self.customMetrics(predictions, ly)
# l_labels = {lang: ly.detach().cpu().numpy()} self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
# l_eval = evaluate(l_labels, l_pred, n_jobs=1)
self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
return loss self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
return {'loss': loss}
def validation_step(self, val_batch, batch_idx): def validation_step(self, val_batch, batch_idx):
lX, ly = val_batch lX, ly = val_batch
@ -156,17 +123,10 @@ class RecurrentModel(pl.LightningModule):
ly = torch.cat(_ly, dim=0) ly = torch.cat(_ly, dim=0)
loss = self.loss(logits, ly) loss = self.loss(logits, ly)
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
# microf1 = self.microf1(predictions, ly)
# macrof1 = self.macrof1(predictions, ly)
accuracy = self.accuracy(predictions, ly) accuracy = self.accuracy(predictions, ly)
self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
# l_pred = {lang: predictions.detach().cpu().numpy()}
# l_labels = {lang: y.detach().cpu().numpy()}
# l_eval = evaluate(l_labels, l_pred, n_jobs=1)
self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True) self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
return return {'loss': loss}
def test_step(self, test_batch, batch_idx): def test_step(self, test_batch, batch_idx):
lX, ly = test_batch lX, ly = test_batch
@ -177,18 +137,9 @@ class RecurrentModel(pl.LightningModule):
ly = torch.cat(_ly, dim=0) ly = torch.cat(_ly, dim=0)
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
accuracy = self.accuracy(predictions, ly) accuracy = self.accuracy(predictions, ly)
custom_metric = self.customMetrics(logits, ly) # TODO
self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True) self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
self.log('test-custom', custom_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
return {'pred': predictions, 'target': ly}
def test_epoch_end(self, outputs):
# all_pred = torch.vstack([out['pred'] for out in outputs]) # TODO
# all_y = torch.vstack([out['target'] for out in outputs]) # TODO
# r = eval(all_y, all_pred)
# print(r)
# X = torch.cat(X).view([X[0].shape[0], len(X)])
return return
# return {'pred': predictions, 'target': ly}
def embed(self, X, lang): def embed(self, X, lang):
input_list = [] input_list = []
@ -308,5 +259,5 @@ def _fbeta_compute(
new_den = 2 * true_positives + new_fp + new_fn new_den = 2 * true_positives + new_fp + new_fn
if new_den.sum() == 0: if new_den.sum() == 0:
# whats is the correct return type ? TODO # whats is the correct return type ? TODO
return 1. return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
return class_reduce(num, denom, weights=actual_positives, class_reduction=average) return class_reduce(num, denom, weights=actual_positives, class_reduction=average)

View File

@ -52,7 +52,7 @@ class MultilingualIndex:
self.l_index = {} self.l_index = {}
self.l_vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) self.l_vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
def index(self, l_devel_raw, l_devel_target, l_test_raw, l_pretrained_vocabulary=None): def index(self, l_devel_raw, l_devel_target, l_test_raw, l_test_target, l_pretrained_vocabulary=None):
self.langs = sorted(l_devel_raw.keys()) self.langs = sorted(l_devel_raw.keys())
self.l_vectorizer.fit(l_devel_raw) self.l_vectorizer.fit(l_devel_raw)
l_vocabulary = self.l_vectorizer.vocabulary() l_vocabulary = self.l_vectorizer.vocabulary()
@ -62,7 +62,7 @@ class MultilingualIndex:
for lang in self.langs: for lang in self.langs:
# Init monolingual Index # Init monolingual Index
self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], lang) self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], l_test_target[lang], lang)
# call to index() function of monolingual Index # call to index() function of monolingual Index
self.l_index[lang].index(l_pretrained_vocabulary[lang], l_analyzer[lang], l_vocabulary[lang]) self.l_index[lang].index(l_pretrained_vocabulary[lang], l_analyzer[lang], l_vocabulary[lang])
@ -163,6 +163,9 @@ class MultilingualIndex:
def l_val_target(self): def l_val_target(self):
return {l: index.val_target for l, index in self.l_index.items()} return {l: index.val_target for l, index in self.l_index.items()}
def l_test_target(self):
return {l: index.test_target for l, index in self.l_index.items()}
def l_test_index(self): def l_test_index(self):
return {l: index.test_index for l, index in self.l_index.items()} return {l: index.test_index for l, index in self.l_index.items()}
@ -182,6 +185,9 @@ class MultilingualIndex:
def l_val(self): def l_val(self):
return self.l_val_index(), self.l_val_target() return self.l_val_index(), self.l_val_target()
def l_test(self):
return self.l_test_index(), self.l_test_target()
def l_train_raw(self): def l_train_raw(self):
return self.l_train_raw_index(), self.l_train_target() return self.l_train_raw_index(), self.l_train_target()
@ -193,7 +199,7 @@ class MultilingualIndex:
class Index: class Index:
def __init__(self, devel_raw, devel_target, test_raw, lang): def __init__(self, devel_raw, devel_target, test_raw, test_target, lang):
""" """
Monolingual Index, takes care of tokenizing raw data, converting strings to ids, splitting the data into Monolingual Index, takes care of tokenizing raw data, converting strings to ids, splitting the data into
training and validation. training and validation.
@ -206,6 +212,7 @@ class Index:
self.devel_raw = devel_raw self.devel_raw = devel_raw
self.devel_target = devel_target self.devel_target = devel_target
self.test_raw = test_raw self.test_raw = test_raw
self.test_target = test_target
def index(self, pretrained_vocabulary, analyzer, vocabulary): def index(self, pretrained_vocabulary, analyzer, vocabulary):
self.word2index = dict(vocabulary) self.word2index = dict(vocabulary)

View File

@ -205,7 +205,7 @@ class RecurrentGen(ViewGen):
:return: :return:
""" """
recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size) recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50) trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)
# vanilla_torch_model = torch.load( # vanilla_torch_model = torch.load(
# '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle') # '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')