From 01bd85d15659c49a56540931a8a90c3052fda46f Mon Sep 17 00:00:00 2001
From: andrea
Date: Fri, 22 Jan 2021 18:00:41 +0100
Subject: [PATCH] Implementing inference functions

---
 refactor/main.py            | 2 +-
 refactor/models/pl_gru.py   | 2 +-
 refactor/view_generators.py | 1 +
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/refactor/main.py b/refactor/main.py
index ec2dc60..610defe 100644
--- a/refactor/main.py
+++ b/refactor/main.py
@@ -16,7 +16,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it', 'fr'])
+    data.set_view(languages=['it', 'fr'])
     lX, ly = data.training()
     lXte, lyte = data.test()
 
diff --git a/refactor/models/pl_gru.py b/refactor/models/pl_gru.py
index c81f959..ed70e80 100644
--- a/refactor/models/pl_gru.py
+++ b/refactor/models/pl_gru.py
@@ -126,7 +126,7 @@ class RecurrentModel(pl.LightningModule):
                 batch = lX[lang][i:i+batch_size]
                 max_pad_len = define_pad_length(batch)
                 batch = pad(batch, pad_index=l_pad[lang], max_pad_length=max_pad_len)
-                X = torch.LongTensor(batch)
+                X = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
                 _batch_size = X.shape[0]
                 X = self.embed(X, lang)
                 X = self.embedding_dropout(X, drop_range=self.drop_embedding_range, p_drop=self.drop_embedding_prop,
diff --git a/refactor/view_generators.py b/refactor/view_generators.py
index d5f7ce8..8f1f191 100644
--- a/refactor/view_generators.py
+++ b/refactor/view_generators.py
@@ -229,6 +229,7 @@ class RecurrentGen(ViewGen):
         l_pad = self.multilingualIndex.l_pad()
         data = self.multilingualIndex.l_devel_index()
         # trainer = Trainer(gpus=self.gpus)
+        self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
         time_init = time()
         l_embeds = self.model.encode(data, l_pad, batch_size=256)