refactor
parent 34676167e8
commit 294d7c3be7
@@ -103,7 +103,6 @@ class GfunDataModule(pl.LightningDataModule):
         pass
-
     def setup(self, stage=None):
         # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -111,9 +110,8 @@ class GfunDataModule(pl.LightningDataModule):
             l_val_index, l_val_target = self.multilingualIndex.l_val()
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         # Assign test dataset for use in dataloader(s)
         if stage == 'test' or stage is None:
-            l_test_index, l_test_target = self.multilingualIndex.l_val()
+            l_test_index, l_test_target = self.multilingualIndex.l_test()
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
-
@@ -136,7 +134,6 @@ class BertDataModule(GfunDataModule):
         self.max_len = max_len
-
     def setup(self, stage=None):
         # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
@@ -146,12 +143,11 @@ class BertDataModule(GfunDataModule):
             l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         # Assign test dataset for use in dataloader(s)
-        # TODO
         if stage == 'test' or stage is None:
-            l_val_raw, l_val_target = self.multilingualIndex.l_test_raw()
-            l_val_index = self.tokenize(l_val_raw)
-            self.test_dataset = RecurrentDataset(l_val_index, l_val_target,
+            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
+            l_test_index = self.tokenize(l_val_raw, max_len=self.max_len)
+            self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())

     @staticmethod
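For context, the datasets built in setup() above are what the data module's dataloader hooks hand to the Trainer. A minimal sketch of that wiring (not part of this commit; the batchsize attribute, the DataLoader arguments, and the absence of a custom collate_fn are assumptions):

    # Illustrative sketch only, not part of the commit: typical dataloader hooks for a
    # LightningDataModule whose setup() populates training/val/test datasets as above.
    from torch.utils.data import DataLoader
    import pytorch_lightning as pl

    class SketchDataModule(pl.LightningDataModule):
        def train_dataloader(self):
            return DataLoader(self.training_dataset, batch_size=self.batchsize, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.val_dataset, batch_size=self.batchsize)

        def test_dataloader(self):
            return DataLoader(self.test_dataset, batch_size=self.batchsize)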
@@ -7,29 +7,28 @@ from util.common import MultilingualIndex

 def main(args):
     N_JOBS = 8
-    print('Running...')
+    print('Running refactored...')

     # _DATASET = '/homenfs/a.pedrotti1/datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     # EMBEDDINGS_PATH = '/homenfs/a.pedrotti1/embeddings/MUSE'

     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
-    EMBEDDINGS_PATH = '/home/andreapdr/funneling_pdr/embeddings'
+    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it'])
+    data.set_view(languages=['it'], categories=[0,1])
     lX, ly = data.training()
     lXte, lyte = data.test()

-    # Init multilingualIndex - mandatory when deploying Neural View Generators...
+    # Init multilingualIndex - mandatory when deploying Neural View Generators...
     multilingualIndex = MultilingualIndex()
-    # lMuse = MuseLoader(langs=sorted(lX.keys()), cache=)
+    lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
-    multilingualIndex.index(lX, ly, lXte, l_pretrained_vocabulary=lMuse.vocabulary())
+    multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())

     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, gpus=args.gpus, n_jobs=N_JOBS,
-                        stored_path='/home/andreapdr/gfun_refactor/tb_logs/gfun_rnn_dev/version_19/checkpoints/epoch=0-step=14.ckpt')
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)

     gFun.fit(lX, ly)

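The script is driven by an args object that at least exposes args.gpus (used above). A minimal entry point consistent with that would look like the following sketch; apart from --gpus, every flag and default here is an assumption, not shown in this hunk:

    # Hypothetical entry point for main(args): only the gpus attribute is actually
    # referenced in the hunk above.
    import argparse

    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--gpus', type=int, default=1)
        main(parser.parse_args())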
@@ -3,25 +3,29 @@ import torch.nn as nn
 from torch.nn import functional as F


-def init_embeddings(pretrained, vocab_size, learnable_length, device='cuda'):
+def init_embeddings(pretrained, vocab_size, learnable_length):
+    """
+    Compute the embedding matrix
+    :param pretrained:
+    :param vocab_size:
+    :param learnable_length:
+    :return:
+    """
     pretrained_embeddings = None
     pretrained_length = 0
     if pretrained is not None:
         pretrained_length = pretrained.shape[1]
         assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
         pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
         # requires_grad=False sets the embedding layer as NOT trainable
         pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-        # pretrained_embeddings.to(device)

     learnable_embeddings = None
     if learnable_length > 0:
         learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-        # learnable_embeddings.to(device)

     embedding_length = learnable_length + pretrained_length
     assert embedding_length > 0, '0-size embeddings'

     return pretrained_embeddings, learnable_embeddings, embedding_length

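init_embeddings returns a frozen pretrained lookup, an optional trainable lookup, and their combined width. A small usage sketch, not part of the commit (the matrix shape, vocabulary size, and batch dimensions are made up), showing how the two lookups are typically concatenated along the feature dimension:

    # Illustrative sketch only: combine frozen pretrained embeddings with trainable ones.
    import torch

    pretrained = torch.randn(30000, 300)  # hypothetical MUSE-style matrix (vocab x dim)
    pretrained_emb, learnable_emb, emb_len = init_embeddings(pretrained, vocab_size=30000, learnable_length=50)

    ids = torch.randint(0, 30000, (8, 120))   # a batch of token ids
    parts = [pretrained_emb(ids)]             # frozen lookup (requires_grad=False)
    if learnable_emb is not None:
        parts.append(learnable_emb(ids))      # trainable lookup
    x = torch.cat(parts, dim=-1)              # shape (8, 120, emb_len) == (8, 120, 350)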
@@ -1,43 +1,17 @@
 # Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
 import torch
 from torch import nn
 from torch.optim import Adam
 from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
 from pytorch_lightning.metrics import F1, Accuracy, Metric
 from torch.optim.lr_scheduler import StepLR

 from util.evaluation import evaluate
 from typing import Any, Optional, Tuple
 from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
+from models.helpers import init_embeddings
 import numpy as np


-def init_embeddings(pretrained, vocab_size, learnable_length):
-    """
-    Compute the embedding matrix
-    :param pretrained:
-    :param vocab_size:
-    :param learnable_length:
-    :return:
-    """
-    pretrained_embeddings = None
-    pretrained_length = 0
-    if pretrained is not None:
-        pretrained_length = pretrained.shape[1]
-        assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
-        pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
-        # requires_grad=False sets the embedding layer as NOT trainable
-        pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-
-    learnable_embeddings = None
-    if learnable_length > 0:
-        learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-
-    embedding_length = learnable_length + pretrained_length
-    assert embedding_length > 0, '0-size embeddings'
-    return pretrained_embeddings, learnable_embeddings, embedding_length
-from util.evaluation import evaluate


 class RecurrentModel(pl.LightningModule):
@@ -97,7 +71,7 @@ class RecurrentModel(pl.LightningModule):
         self.label = nn.Linear(ff2, self.output_size)

         lPretrained = None  # TODO: setting lPretrained to None, letting it to its original value will bug first
-        # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
+        # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
         self.save_hyperparameters()

     def forward(self, lX):
@@ -124,7 +98,6 @@ class RecurrentModel(pl.LightningModule):
         return output
-
     def training_step(self, train_batch, batch_idx):
         # TODO: double check StepLR scheduler...
         lX, ly = train_batch
         logits = self.forward(lX)
         _ly = []
@@ -132,20 +105,14 @@ class RecurrentModel(pl.LightningModule):
             _ly.append(ly[lang])
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)

-        # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5
-
-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: ly.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        custom = self.customMetrics(predictions, ly)
+        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return loss
+        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        return {'loss': loss}

     def validation_step(self, val_batch, batch_idx):
         lX, ly = val_batch
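The thresholding above assumes a multi-label setup: the loss is computed on raw logits, while the predictions fed to the metrics come from a 0.5 sigmoid cut. A self-contained sketch of that pattern (that self.loss is BCEWithLogitsLoss is an assumption, not stated in the hunk):

    # Multi-label sketch, not part of the commit: loss on logits, hard decisions for metrics.
    import torch
    import torch.nn as nn

    logits = torch.randn(4, 6)                     # 4 documents, 6 labels
    targets = torch.randint(0, 2, (4, 6)).float()  # multi-hot targets

    loss = nn.BCEWithLogitsLoss()(logits, targets)  # applies the sigmoid internally
    preds = torch.sigmoid(logits) > 0.5             # boolean per-label predictions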
@@ -156,17 +123,10 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)
         predictions = torch.sigmoid(logits) > 0.5
-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
-
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: y.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return
+        return {'loss': loss}

     def test_step(self, test_batch, batch_idx):
         lX, ly = test_batch
@@ -177,18 +137,9 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
         custom_metric = self.customMetrics(logits, ly)  # TODO
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
+        self.log('test-custom', custom_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
         return {'pred': predictions, 'target': ly}
-
-    def test_epoch_end(self, outputs):
-        # all_pred = torch.vstack([out['pred'] for out in outputs]) # TODO
-        # all_y = torch.vstack([out['target'] for out in outputs]) # TODO
-        # r = eval(all_y, all_pred)
-        # print(r)
-        # X = torch.cat(X).view([X[0].shape[0], len(X)])
-        return
-        # return {'pred': predictions, 'target': ly}
-
     def embed(self, X, lang):
         input_list = []
@@ -308,5 +259,5 @@ def _fbeta_compute(
     new_den = 2 * true_positives + new_fp + new_fn
     if new_den.sum() == 0:
         # whats is the correct return type ? TODO
         return 1.
-    return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
+    return class_reduce(num, denom, weights=actual_positives, class_reduction=average)

@@ -52,7 +52,7 @@ class MultilingualIndex:
         self.l_index = {}
         self.l_vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)

-    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_pretrained_vocabulary=None):
+    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_test_target, l_pretrained_vocabulary=None):
         self.langs = sorted(l_devel_raw.keys())
         self.l_vectorizer.fit(l_devel_raw)
         l_vocabulary = self.l_vectorizer.vocabulary()
@@ -62,7 +62,7 @@ class MultilingualIndex:

         for lang in self.langs:
             # Init monolingual Index
-            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], lang)
+            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], l_test_target[lang], lang)
             # call to index() function of monolingual Index
             self.l_index[lang].index(l_pretrained_vocabulary[lang], l_analyzer[lang], l_vocabulary[lang])

@@ -163,6 +163,9 @@ class MultilingualIndex:
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}

+    def l_test_target(self):
+        return {l: index.test_target for l, index in self.l_index.items()}
+
     def l_test_index(self):
         return {l: index.test_index for l, index in self.l_index.items()}

@@ -182,6 +185,9 @@ class MultilingualIndex:
     def l_val(self):
         return self.l_val_index(), self.l_val_target()

+    def l_test(self):
+        return self.l_test_index(), self.l_test_target()
+
     def l_train_raw(self):
         return self.l_train_raw_index(), self.l_train_target()

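The new l_test() accessor mirrors l_train() and l_val(): it returns a pair of per-language dictionaries keyed by language code. A tiny usage sketch (illustrative only, not part of the commit):

    # Illustrative only: iterate the per-language test split exposed by the new accessor.
    l_test_index, l_test_target = multilingualIndex.l_test()
    for lang in sorted(l_test_index.keys()):
        print(f'{lang}: {len(l_test_index[lang])} test documents')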
@@ -193,7 +199,7 @@ class MultilingualIndex:


 class Index:
-    def __init__(self, devel_raw, devel_target, test_raw, lang):
+    def __init__(self, devel_raw, devel_target, test_raw, test_target, lang):
         """
         Monolingual Index, takes care of tokenizing raw data, converting strings to ids, splitting the data into
         training and validation.
@@ -206,6 +212,7 @@ class Index:
         self.devel_raw = devel_raw
         self.devel_target = devel_target
         self.test_raw = test_raw
+        self.test_target = test_target

     def index(self, pretrained_vocabulary, analyzer, vocabulary):
         self.word2index = dict(vocabulary)
@@ -205,7 +205,7 @@ class RecurrentGen(ViewGen):
         :return:
         """
         recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
-        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50)
+        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
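For reference, the standard PyTorch Lightning pattern that a Trainer and DataModule pair like the one above plugs into (a generic sketch, not this project's code; model is a placeholder for a LightningModule instance):

    # Generic PyTorch Lightning 1.x usage sketch, not part of the commit.
    from pytorch_lightning import Trainer

    trainer = Trainer(gpus=1, max_epochs=50, checkpoint_callback=False)
    trainer.fit(model, datamodule=recurrentDataModule)   # train/validate with the DataModule
    trainer.test(model, datamodule=recurrentDataModule)  # run the test split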