refactor

parent 34676167e8
commit 294d7c3be7
@@ -103,7 +103,6 @@ class GfunDataModule(pl.LightningDataModule):
         pass

     def setup(self, stage=None):
-        # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -111,9 +110,8 @@ class GfunDataModule(pl.LightningDataModule):
             l_val_index, l_val_target = self.multilingualIndex.l_val()
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
-        # Assign test dataset for use in dataloader(s)
         if stage == 'test' or stage is None:
-            l_test_index, l_test_target = self.multilingualIndex.l_val()
+            l_test_index, l_test_target = self.multilingualIndex.l_test()
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())

@@ -136,7 +134,6 @@ class BertDataModule(GfunDataModule):
         self.max_len = max_len

     def setup(self, stage=None):
-        # Assign train/val datasets for use in dataloaders
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             l_train_index = self.tokenize(l_train_raw, max_len=self.max_len)
@@ -146,12 +143,11 @@ class BertDataModule(GfunDataModule):
             l_val_index = self.tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
-        # Assign test dataset for use in dataloader(s)
         # TODO
         if stage == 'test' or stage is None:
-            l_val_raw, l_val_target = self.multilingualIndex.l_test_raw()
-            l_val_index = self.tokenize(l_val_raw)
-            self.test_dataset = RecurrentDataset(l_val_index, l_val_target,
+            l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
+            l_test_index = self.tokenize(l_test_raw, max_len=self.max_len)
+            self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())

     @staticmethod
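Note: the DataLoader hooks that consume these datasets are outside this diff. Below is a minimal sketch of the matching test hook, assuming the module keeps the constructor's batchsize argument and that RecurrentDataset items can be batched by a default collate (the real code may pass a custom collate_fn, as the train/val loaders presumably do):

    from torch.utils.data import DataLoader

    def test_dataloader(self):
        # Hypothetical hook: expose the dataset built in setup(stage='test').
        return DataLoader(self.test_dataset, batch_size=self.batchsize, shuffle=False)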
@@ -7,29 +7,28 @@ from util.common import MultilingualIndex

 def main(args):
     N_JOBS = 8
-    print('Running...')
+    print('Running refactored...')

     # _DATASET = '/homenfs/a.pedrotti1/datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     # EMBEDDINGS_PATH = '/homenfs/a.pedrotti1/embeddings/MUSE'

     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
-    EMBEDDINGS_PATH = '/home/andreapdr/funneling_pdr/embeddings'
+    EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it'])
+    data.set_view(languages=['it'], categories=[0,1])
     lX, ly = data.training()
     lXte, lyte = data.test()

     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     multilingualIndex = MultilingualIndex()
     # lMuse = MuseLoader(langs=sorted(lX.keys()), cache=)
     lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
-    multilingualIndex.index(lX, ly, lXte, l_pretrained_vocabulary=lMuse.vocabulary())
+    multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())

     # gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
     # gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
     # gFun = WordClassGen(n_jobs=N_JOBS)
-    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, gpus=args.gpus, n_jobs=N_JOBS,
-                        stored_path='/home/andreapdr/gfun_refactor/tb_logs/gfun_rnn_dev/version_19/checkpoints/epoch=0-step=14.ckpt')
+    gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=512, gpus=args.gpus, n_jobs=N_JOBS)
     # gFun = BertGen(multilingualIndex, gpus=args.gpus, batch_size=128, n_jobs=N_JOBS)

     gFun.fit(lX, ly)
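Note: passing lyte into multilingualIndex.index(...) is what lets the l_test()/l_test_target() accessors added later in this commit return labelled test data. A rough sketch of the round trip, using only names introduced by this commit:

    # after: multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
    l_test_index, l_test_target = multilingualIndex.l_test()
    # l_test_index[lang]  -> token-id view of that language's test documents
    # l_test_target[lang] -> the labels originally supplied as lyte[lang]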
@@ -3,25 +3,29 @@ import torch.nn as nn
 from torch.nn import functional as F


-def init_embeddings(pretrained, vocab_size, learnable_length, device='cuda'):
+def init_embeddings(pretrained, vocab_size, learnable_length):
+    """
+    Compute the embedding matrix
+    :param pretrained:
+    :param vocab_size:
+    :param learnable_length:
+    :return:
+    """
     pretrained_embeddings = None
     pretrained_length = 0
     if pretrained is not None:
         pretrained_length = pretrained.shape[1]
         assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
         pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
+        # requires_grad=False sets the embedding layer as NOT trainable
         pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-        # pretrained_embeddings.to(device)

     learnable_embeddings = None
     if learnable_length > 0:
         learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-        # learnable_embeddings.to(device)

     embedding_length = learnable_length + pretrained_length
     assert embedding_length > 0, '0-size embeddings'

     return pretrained_embeddings, learnable_embeddings, embedding_length
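Note: init_embeddings returns up to two nn.Embedding layers (a frozen pre-trained one and a trainable one) plus their combined width. A typical consumer concatenates the two lookups along the feature dimension; the following is an illustrative sketch with made-up sizes, not code from this repository:

    import torch
    from models.helpers import init_embeddings

    muse_vectors = torch.randn(10000, 300)  # e.g. a MUSE matrix: vocab_size x 300
    pretrained_emb, learnable_emb, emb_len = init_embeddings(muse_vectors, vocab_size=10000, learnable_length=50)

    ids = torch.randint(0, 10000, (8, 25))  # a batch of token ids
    parts = []
    if pretrained_emb is not None:
        parts.append(pretrained_emb(ids))   # frozen lookup (requires_grad=False)
    if learnable_emb is not None:
        parts.append(learnable_emb(ids))    # trainable lookup
    x = torch.cat(parts, dim=-1)            # shape (8, 25, emb_len) == (8, 25, 350)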
@@ -1,43 +1,17 @@
+# Lightning modules, see https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html
 import torch
 from torch import nn
-from torch.optim import Adam
 from transformers import AdamW
 import torch.nn.functional as F
 from torch.autograd import Variable
 import pytorch_lightning as pl
 from pytorch_lightning.metrics import F1, Accuracy, Metric
 from torch.optim.lr_scheduler import StepLR

-from util.evaluation import evaluate
 from typing import Any, Optional, Tuple
 from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce
+from models.helpers import init_embeddings
 import numpy as np
+from util.evaluation import evaluate


-def init_embeddings(pretrained, vocab_size, learnable_length):
-    """
-    Compute the embedding matrix
-    :param pretrained:
-    :param vocab_size:
-    :param learnable_length:
-    :return:
-    """
-    pretrained_embeddings = None
-    pretrained_length = 0
-    if pretrained is not None:
-        pretrained_length = pretrained.shape[1]
-        assert pretrained.shape[0] == vocab_size, 'pre-trained matrix does not match with the vocabulary size'
-        pretrained_embeddings = nn.Embedding(vocab_size, pretrained_length)
-        # requires_grad=False sets the embedding layer as NOT trainable
-        pretrained_embeddings.weight = nn.Parameter(pretrained, requires_grad=False)
-
-    learnable_embeddings = None
-    if learnable_length > 0:
-        learnable_embeddings = nn.Embedding(vocab_size, learnable_length)
-
-    embedding_length = learnable_length + pretrained_length
-    assert embedding_length > 0, '0-size embeddings'
-    return pretrained_embeddings, learnable_embeddings, embedding_length
-
-
 class RecurrentModel(pl.LightningModule):
@@ -97,7 +71,7 @@ class RecurrentModel(pl.LightningModule):
         self.label = nn.Linear(ff2, self.output_size)

         lPretrained = None  # TODO: setting lPretrained to None, letting it to its original value will bug first
         # validation step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
         self.save_hyperparameters()

     def forward(self, lX):
@@ -124,7 +98,6 @@ class RecurrentModel(pl.LightningModule):
         return output

     def training_step(self, train_batch, batch_idx):
-        # TODO: double check StepLR scheduler...
         lX, ly = train_batch
         logits = self.forward(lX)
         _ly = []
@@ -132,20 +105,14 @@ class RecurrentModel(pl.LightningModule):
             _ly.append(ly[lang])
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)

         # Squashing logits through Sigmoid in order to get confidence score
         predictions = torch.sigmoid(logits) > 0.5

-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: ly.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        custom = self.customMetrics(predictions, ly)
+        self.log('train-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('train-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return loss
+        self.log('custom', custom, on_step=False, on_epoch=True, prog_bar=True, logger=True)
+        return {'loss': loss}

     def validation_step(self, val_batch, batch_idx):
         lX, ly = val_batch
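Note: customMetrics is referenced here but defined outside this diff. With the pytorch_lightning.metrics API already imported in this file, such a metric is typically a Metric subclass with add_state/update/compute; the class below is a hypothetical example (accumulated micro-F1), not the repository's implementation:

    import torch
    from pytorch_lightning.metrics import Metric

    class ExampleMicroF1(Metric):
        # Hypothetical metric: accumulates TP/FP/FN over an epoch and returns micro-F1.
        def __init__(self):
            super().__init__()
            self.add_state('tp', default=torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('fp', default=torch.tensor(0.), dist_reduce_fx='sum')
            self.add_state('fn', default=torch.tensor(0.), dist_reduce_fx='sum')

        def update(self, preds, target):
            preds, target = preds.bool(), target.bool()
            self.tp += (preds & target).sum().float()
            self.fp += (preds & ~target).sum().float()
            self.fn += (~preds & target).sum().float()

        def compute(self):
            denom = 2 * self.tp + self.fp + self.fn
            return torch.tensor(1.) if denom == 0 else 2 * self.tp / denom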
@@ -156,17 +123,10 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         loss = self.loss(logits, ly)
         predictions = torch.sigmoid(logits) > 0.5
-        # microf1 = self.microf1(predictions, ly)
-        # macrof1 = self.macrof1(predictions, ly)
         accuracy = self.accuracy(predictions, ly)
-        # l_pred = {lang: predictions.detach().cpu().numpy()}
-        # l_labels = {lang: y.detach().cpu().numpy()}
-        # l_eval = evaluate(l_labels, l_pred, n_jobs=1)
-
-        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        self.log('val-loss', loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
         self.log('val-accuracy', accuracy, on_step=True, on_epoch=True, prog_bar=False, logger=True)
-        return
+        return {'loss': loss}

     def test_step(self, test_batch, batch_idx):
         lX, ly = test_batch
@@ -177,18 +137,9 @@ class RecurrentModel(pl.LightningModule):
         ly = torch.cat(_ly, dim=0)
         predictions = torch.sigmoid(logits) > 0.5
         accuracy = self.accuracy(predictions, ly)
-        custom_metric = self.customMetrics(logits, ly)  # TODO
         self.log('test-accuracy', accuracy, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        self.log('test-custom', custom_metric, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        return {'pred': predictions, 'target': ly}
-
-    def test_epoch_end(self, outputs):
-        # all_pred = torch.vstack([out['pred'] for out in outputs])  # TODO
-        # all_y = torch.vstack([out['target'] for out in outputs])  # TODO
-        # r = eval(all_y, all_pred)
-        # print(r)
-        # X = torch.cat(X).view([X[0].shape[0], len(X)])
         return
+        # return {'pred': predictions, 'target': ly}

     def embed(self, X, lang):
         input_list = []
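Note: the deleted test_epoch_end stub pointed at evaluating all predictions once per epoch. If test_step were reverted to return {'pred': predictions, 'target': ly}, a minimal aggregation sketch (illustrative only, not the repository's code) could be:

    def test_epoch_end(self, outputs):
        # 'outputs' is the list of dicts returned by test_step across all batches.
        all_pred = torch.cat([out['pred'] for out in outputs], dim=0)
        all_y = torch.cat([out['target'] for out in outputs], dim=0)
        self.log('test-accuracy-epoch', self.accuracy(all_pred, all_y))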
@@ -308,5 +259,5 @@ def _fbeta_compute(
     new_den = 2 * true_positives + new_fp + new_fn
     if new_den.sum() == 0:
         # whats is the correct return type ? TODO
-        return 1.
+        return class_reduce(new_num, new_den, weights=actual_positives, class_reduction=average)
     return class_reduce(num, denom, weights=actual_positives, class_reduction=average)
@@ -52,7 +52,7 @@ class MultilingualIndex:
         self.l_index = {}
         self.l_vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)

-    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_pretrained_vocabulary=None):
+    def index(self, l_devel_raw, l_devel_target, l_test_raw, l_test_target, l_pretrained_vocabulary=None):
         self.langs = sorted(l_devel_raw.keys())
         self.l_vectorizer.fit(l_devel_raw)
         l_vocabulary = self.l_vectorizer.vocabulary()
@@ -62,7 +62,7 @@ class MultilingualIndex:

         for lang in self.langs:
             # Init monolingual Index
-            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], lang)
+            self.l_index[lang] = Index(l_devel_raw[lang], l_devel_target[lang], l_test_raw[lang], l_test_target[lang], lang)
             # call to index() function of monolingual Index
             self.l_index[lang].index(l_pretrained_vocabulary[lang], l_analyzer[lang], l_vocabulary[lang])
@@ -163,6 +163,9 @@ class MultilingualIndex:
     def l_val_target(self):
         return {l: index.val_target for l, index in self.l_index.items()}

+    def l_test_target(self):
+        return {l: index.test_target for l, index in self.l_index.items()}
+
     def l_test_index(self):
         return {l: index.test_index for l, index in self.l_index.items()}
@@ -182,6 +185,9 @@ class MultilingualIndex:
     def l_val(self):
         return self.l_val_index(), self.l_val_target()

+    def l_test(self):
+        return self.l_test_index(), self.l_test_target()
+
     def l_train_raw(self):
         return self.l_train_raw_index(), self.l_train_target()
@@ -193,7 +199,7 @@ class MultilingualIndex:


 class Index:
-    def __init__(self, devel_raw, devel_target, test_raw, lang):
+    def __init__(self, devel_raw, devel_target, test_raw, test_target, lang):
         """
         Monolingual Index, takes care of tokenizing raw data, converting strings to ids, splitting the data into
         training and validation.
@@ -206,6 +212,7 @@ class Index:
         self.devel_raw = devel_raw
         self.devel_target = devel_target
         self.test_raw = test_raw
+        self.test_target = test_target

     def index(self, pretrained_vocabulary, analyzer, vocabulary):
         self.word2index = dict(vocabulary)
@@ -205,7 +205,7 @@ class RecurrentGen(ViewGen):
         :return:
         """
         recurrentDataModule = GfunDataModule(self.multilingualIndex, batchsize=self.batch_size)
-        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50)
+        trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=50, checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         # '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
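Note: checkpoint_callback=False disables Lightning's default per-epoch checkpointing, which the TODO in RecurrentModel suggests is related to checkpoints storing the pretrained embeddings and slowing saving. If checkpointing is wanted later, an explicit callback can be passed instead; a sketch using the standard pytorch_lightning API (the monitor key matches the 'val-loss' logged in RecurrentModel):

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Hypothetical: keep only the best epoch by validation loss instead of saving every epoch.
    checkpoint_cb = ModelCheckpoint(monitor='val-loss', save_top_k=1)
    trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger,
                      max_epochs=50, callbacks=[checkpoint_cb])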