Implementing inference functions
This commit is contained in:
parent
472b64ee0e
commit
4d3ef41a07
|
|
@ -65,9 +65,8 @@ class RecurrentDataset(Dataset):
|
|||
ly_batch[current_lang].append(d[1])
|
||||
|
||||
for lang in lX_batch.keys():
|
||||
# TODO: double check padding function (too many left pad tokens?)
|
||||
lX_batch[lang] = self.pad(lX_batch[lang], pad_index=self.lPad_index[lang], max_pad_length=70)
|
||||
# max_pad_length=self.define_pad_length(lX_batch[lang]))
|
||||
lX_batch[lang] = self.pad(lX_batch[lang], pad_index=self.lPad_index[lang],
|
||||
max_pad_length=self.define_pad_length(lX_batch[lang]))
|
||||
lX_batch[lang] = torch.LongTensor(lX_batch[lang])
|
||||
ly_batch[lang] = torch.FloatTensor(ly_batch[lang])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,36 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ from util.embeddings_manager import MuseLoader
|
|||
from view_generators import RecurrentGen, BertGen
|
||||
from data.dataset_builder import MultilingualDataset
|
||||
from util.common import MultilingualIndex
|
||||
from time import time
|
||||
|
||||
|
||||
def main(args):
|
||||
|
|
@ -21,23 +22,23 @@ def main(args):
|
|||
|
||||
# Init multilingualIndex - mandatory when deploying Neural View Generators...
|
||||
multilingualIndex = MultilingualIndex()
|
||||
# lMuse = MuseLoader(langs=sorted(lX.keys()), cache=)
|
||||
lMuse = MuseLoader(langs=sorted(lX.keys()), cache=EMBEDDINGS_PATH)
|
||||
multilingualIndex.index(lX, ly, lXte, lyte, l_pretrained_vocabulary=lMuse.vocabulary())
|
||||
|
||||
# gFun = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=N_JOBS)
|
||||
# gFun = MuseGen(muse_dir='/home/andreapdr/funneling_pdr/embeddings', n_jobs=N_JOBS)
|
||||
# gFun = WordClassGen(n_jobs=N_JOBS)
|
||||
gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=128,
|
||||
gFun = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=False, batch_size=256,
|
||||
nepochs=50, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
# gFun = BertGen(multilingualIndex, batch_size=4, nepochs=10, gpus=args.gpus, n_jobs=N_JOBS)
|
||||
|
||||
gFun.fit(lX, ly)
|
||||
time_init = time()
|
||||
# gFun.fit(lX, ly)
|
||||
|
||||
# print('Projecting...')
|
||||
# y_ = gFun.transform(lX)
|
||||
|
||||
exit('Executed!')
|
||||
print('Projecting...')
|
||||
y_ = gFun.transform(lX)
|
||||
train_time = round(time() - time_init, 3)
|
||||
exit(f'Executed! Training time: {train_time}!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from transformers import AdamW
|
|||
import pytorch_lightning as pl
|
||||
from models.helpers import init_embeddings
|
||||
from util.pl_metrics import CustomF1, CustomK
|
||||
from util.common import define_pad_length, pad
|
||||
|
||||
|
||||
class RecurrentModel(pl.LightningModule):
|
||||
|
|
@ -78,17 +79,17 @@ class RecurrentModel(pl.LightningModule):
|
|||
self.linear2 = nn.Linear(ff1, ff2)
|
||||
self.label = nn.Linear(ff2, self.output_size)
|
||||
|
||||
# TODO: setting lPretrained to None, letting it to its original value will bug first validation
|
||||
# TODO: setting lPretrained to None, letting it to its original value will "bug" first validation
|
||||
# step (i.e., checkpoint will store also its ++ value, I guess, making the saving process too slow)
|
||||
lPretrained = None
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, lX):
|
||||
_tmp = []
|
||||
l_embed = []
|
||||
for lang in sorted(lX.keys()):
|
||||
doc_embedding = self.transform(lX[lang], lang)
|
||||
_tmp.append(doc_embedding)
|
||||
embed = torch.cat(_tmp, dim=0)
|
||||
l_embed.append(doc_embedding)
|
||||
embed = torch.cat(l_embed, dim=0)
|
||||
logits = self.label(embed)
|
||||
return logits
|
||||
|
||||
|
|
@ -106,6 +107,37 @@ class RecurrentModel(pl.LightningModule):
|
|||
output = self.dropout(F.relu(self.linear2(output)))
|
||||
return output
|
||||
|
||||
def encode(self, lX, l_pad, batch_size=128):
|
||||
"""
|
||||
Returns encoded data (i.e, RNN hidden state at second feed-forward layer - linear1). Dimensionality is 512.
|
||||
:param lX:
|
||||
:return:
|
||||
"""
|
||||
l_embed = {lang: [] for lang in lX.keys()}
|
||||
for lang in sorted(lX.keys()):
|
||||
for i in range(0, len(lX[lang]), batch_size):
|
||||
if i+batch_size > len(lX[lang]):
|
||||
batch = lX[lang][i:len(lX[lang])]
|
||||
else:
|
||||
batch = lX[lang][i:i+batch_size]
|
||||
max_pad_len = define_pad_length(batch)
|
||||
batch = pad(batch, pad_index=l_pad[lang], max_pad_length=max_pad_len)
|
||||
X = torch.LongTensor(batch)
|
||||
_batch_size = X.shape[0]
|
||||
X = self.embed(X, lang)
|
||||
X = self.embedding_dropout(X, drop_range=self.drop_embedding_range, p_drop=self.drop_embedding_prop,
|
||||
training=self.training)
|
||||
X = X.permute(1, 0, 2)
|
||||
h_0 = Variable(torch.zeros(self.n_layers * self.n_directions, _batch_size, self.hidden_size).to(self.device))
|
||||
output, _ = self.rnn(X, h_0)
|
||||
output = output[-1, :, :]
|
||||
output = F.relu(self.linear0(output))
|
||||
output = self.dropout(F.relu(self.linear1(output)))
|
||||
l_embed[lang].append(output)
|
||||
for k, v in l_embed.items():
|
||||
l_embed[k] = torch.cat(v, dim=0)
|
||||
return l_embed
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
lX, ly = train_batch
|
||||
logits = self.forward(lX)
|
||||
|
|
@ -140,6 +172,7 @@ class RecurrentModel(pl.LightningModule):
|
|||
def training_epoch_end(self, outputs):
|
||||
# outputs is a of n dicts of m elements, where n is equal to the number of epoch steps and m is batchsize.
|
||||
# here we save epoch level metric values and compute them specifically for each language
|
||||
# TODO: this is horrible...
|
||||
res_macroF1 = {lang: [] for lang in self.langs}
|
||||
res_microF1 = {lang: [] for lang in self.langs}
|
||||
res_macroK = {lang: [] for lang in self.langs}
|
||||
|
|
@ -197,8 +230,12 @@ class RecurrentModel(pl.LightningModule):
|
|||
predictions = torch.sigmoid(logits) > 0.5
|
||||
microF1 = self.microF1(predictions, ly)
|
||||
macroF1 = self.macroF1(predictions, ly)
|
||||
self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=False)
|
||||
self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=False)
|
||||
microK = self.microK(predictions, ly)
|
||||
macroK = self.macroK(predictions, ly)
|
||||
self.log('test-macroF1', macroF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
||||
self.log('test-microF1', microF1, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
||||
self.log('test-macroK', macroK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
|
||||
self.log('test-microK', microK, on_step=False, on_epoch=True, prog_bar=True, logger=True)
|
||||
return
|
||||
|
||||
def embed(self, X, lang):
|
||||
|
|
|
|||
|
|
@ -339,3 +339,17 @@ def is_true(tensor, device):
|
|||
|
||||
def is_false(tensor, device):
|
||||
return torch.where(tensor == 0, torch.Tensor([1]).to(device), torch.Tensor([0]).to(device))
|
||||
|
||||
|
||||
def define_pad_length(index_list):
|
||||
lengths = [len(index) for index in index_list]
|
||||
return int(np.mean(lengths) + np.std(lengths))
|
||||
|
||||
|
||||
def pad(index_list, pad_index, max_pad_length=None):
|
||||
pad_length = np.max([len(index) for index in index_list])
|
||||
if max_pad_length is not None:
|
||||
pad_length = min(pad_length, max_pad_length)
|
||||
for i, indexes in enumerate(index_list):
|
||||
index_list[i] = [pad_index] * (pad_length - len(indexes)) + indexes[:pad_length]
|
||||
return index_list
|
||||
|
|
@ -20,11 +20,10 @@ from util.embeddings_manager import MuseLoader, XdotM, wce_matrix
|
|||
from util.common import TfidfVectorizerMultilingual, _normalize
|
||||
from models.pl_gru import RecurrentModel
|
||||
from models.pl_bert import BertModel
|
||||
from models.lstm_class import RNNMultilingualClassifier
|
||||
from pytorch_lightning import Trainer
|
||||
from data.datamodule import RecurrentDataModule, BertDataModule
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
import torch
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
|
||||
from time import time
|
||||
|
||||
|
||||
class ViewGen(ABC):
|
||||
|
|
@ -172,9 +171,8 @@ class RecurrentGen(ViewGen):
|
|||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
|
||||
self.model = self._init_model()
|
||||
# hp_tuning with Tensorboard: check https://www.tensorflow.org/tensorboard/hyperparameter_tuning_with_hparams
|
||||
# however, setting it to False at the moment!
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='gfun_rnn_dev', default_hp_metric=False)
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn_dev', default_hp_metric=False)
|
||||
# self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
|
||||
|
||||
def _init_model(self):
|
||||
if self.stored_path:
|
||||
|
|
@ -201,7 +199,7 @@ class RecurrentGen(ViewGen):
|
|||
|
||||
def fit(self, lX, ly):
|
||||
"""
|
||||
lX and ly are not directly used. We rather get them from the multilingual index used in the instatiation
|
||||
lX and ly are not directly used. We rather get them from the multilingual index used in the instantiation
|
||||
of the Dataset object (RecurrentDataset) in the GfunDataModule class.
|
||||
:param lX:
|
||||
:param ly:
|
||||
|
|
@ -223,7 +221,20 @@ class RecurrentGen(ViewGen):
|
|||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
pass
|
||||
"""
|
||||
Project documents to the common latent space
|
||||
:param lX:
|
||||
:return:
|
||||
"""
|
||||
l_pad = self.multilingualIndex.l_pad()
|
||||
data = self.multilingualIndex.l_devel_index()
|
||||
# trainer = Trainer(gpus=self.gpus)
|
||||
# self.model.eval()
|
||||
time_init = time()
|
||||
l_embeds = self.model.encode(data, l_pad, batch_size=256)
|
||||
transform_time = round(time() - time_init, 3)
|
||||
print(f'Executed! Transform took: {transform_time}')
|
||||
return l_embeds
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
pass
|
||||
|
|
@ -239,26 +250,28 @@ class BertGen(ViewGen):
|
|||
self.batch_size = batch_size
|
||||
self.n_jobs = n_jobs
|
||||
self.stored_path = stored_path
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert_dev', default_hp_metric=False)
|
||||
self.model = self._init_model()
|
||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert_dev', default_hp_metric=False)
|
||||
|
||||
def _init_model(self):
|
||||
output_size = self.multilingualIndex.get_target_dim()
|
||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||
|
||||
def fit(self, lX, ly):
|
||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||
trainer = Trainer(default_root_dir='checkpoints/bert/', gradient_clip_val=1e-1, max_epochs=self.nepochs,
|
||||
gpus=self.gpus, logger=self.logger, checkpoint_callback=False)
|
||||
trainer.fit(self.model, bertDataModule)
|
||||
# trainer.test(self.model, bertDataModule)
|
||||
pass
|
||||
trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
|
||||
logger=self.logger, checkpoint_callback=False)
|
||||
trainer.fit(self.model, datamodule=bertDataModule)
|
||||
trainer.test(self.model, datamodule=bertDataModule)
|
||||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
# lX is raw text data. It has to be first indexed via multilingualIndex Vectorizer.
|
||||
pass
|
||||
|
||||
def fit_transform(self, lX, ly):
|
||||
# we can assume that we have already indexed data for transform() since we are first calling fit()
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue