Implementing inference functions

andrea 2021-01-22 18:00:41 +01:00
parent 9af9347531
commit 01bd85d156
3 changed files with 3 additions and 2 deletions


@@ -16,7 +16,7 @@ def main(args):
     _DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
     EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
     data = MultilingualDataset.load(_DATASET)
-    # data.set_view(languages=['it', 'fr'])
+    data.set_view(languages=['it', 'fr'])
     lX, ly = data.training()
     lXte, lyte = data.test()
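This hunk enables the previously commented-out language filter, so the run operates on the Italian and French views of the corpus only. A minimal sketch of the assumed effect of set_view ( MultilingualDataset, _DATASET, and the method names come from the hunk above; the shape of the returned dictionaries is an assumption):

data = MultilingualDataset.load(_DATASET)
data.set_view(languages=['it', 'fr'])   # keep only the 'it' and 'fr' views
lX, ly = data.training()                # assumed: {lang: docs}, {lang: labels}
for lang in lX:                         # iterates over 'it' and 'fr' only
    print(lang, len(lX[lang]))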


@@ -126,7 +126,7 @@ class RecurrentModel(pl.LightningModule):
                 batch = lX[lang][i:i+batch_size]
                 max_pad_len = define_pad_length(batch)
                 batch = pad(batch, pad_index=l_pad[lang], max_pad_length=max_pad_len)
-                X = torch.LongTensor(batch)
+                X = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
                 _batch_size = X.shape[0]
                 X = self.embed(X, lang)
                 X = self.embedding_dropout(X, drop_range=self.drop_embedding_range, p_drop=self.drop_embedding_prop,
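The substantive change is on the torch.LongTensor line: since inference now runs outside a Lightning Trainer, input tensors must be moved to the model's device explicitly. A hedged sketch of how the surrounding batching loop plausibly fits together, written as a standalone function; the semantics of define_pad_length and pad, and the loop structure beyond the lines shown, are assumptions:

import torch

def define_pad_length(batch):
    # assumed semantics: length of the longest document in the batch
    return max(len(doc) for doc in batch)

def pad(batch, pad_index, max_pad_length):
    # assumed semantics: right-pad every document to max_pad_length
    return [doc + [pad_index] * (max_pad_length - len(doc)) for doc in batch]

def encode_batched(model, lX, l_pad, batch_size=256, use_gpu=False):
    # Hypothetical standalone counterpart of RecurrentModel.encode.
    device = 'cuda' if use_gpu else 'cpu'
    l_embeds = {}
    with torch.no_grad():  # inference only: no gradient tracking
        for lang, docs in lX.items():
            outputs = []
            for i in range(0, len(docs), batch_size):
                batch = docs[i:i + batch_size]
                max_pad_len = define_pad_length(batch)
                batch = pad(batch, pad_index=l_pad[lang], max_pad_length=max_pad_len)
                # The fix in this commit: the input tensor must live on the
                # same device as the model parameters.
                X = torch.LongTensor(batch).to(device)
                outputs.append(model.embed(X, lang).cpu())
            l_embeds[lang] = torch.cat(outputs, dim=0)
    return l_embeds

Collecting outputs back on the CPU (an assumption here) keeps GPU memory bounded when encoding large per-language corpora.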


@@ -229,6 +229,7 @@ class RecurrentGen(ViewGen):
         l_pad = self.multilingualIndex.l_pad()
         data = self.multilingualIndex.l_devel_index()
         # trainer = Trainer(gpus=self.gpus)
+        self.model.to('cuda' if self.gpus else 'cpu')
         self.model.eval()
         time_init = time()
         l_embeds = self.model.encode(data, l_pad, batch_size=256)
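The added self.model.to(...) pairs with the input-side .to(...) in the previous hunk: with the Lightning Trainer commented out, nothing else moves the model's parameters to the GPU before encode runs on GPU-resident batches. A sketch of the inference path this hunk sits in, as a standalone function (view_gen stands in for self; attribute and method names come from the hunk, while the return value and the timing printout are assumptions):

from time import time

def rnn_inference(view_gen):
    # Hypothetical standalone counterpart of the RecurrentGen inference step.
    l_pad = view_gen.multilingualIndex.l_pad()
    data = view_gen.multilingualIndex.l_devel_index()
    # The line added by this commit: move the parameters to the target device
    # so they match the batches built inside encode().
    view_gen.model.to('cuda' if view_gen.gpus else 'cpu')
    view_gen.model.eval()  # disable dropout and other train-mode behaviour
    time_init = time()
    l_embeds = view_gen.model.encode(data, l_pad, batch_size=256)
    print(f'encode() took {time() - time_init:.1f}s')  # assumed logging
    return l_embeds

Both call sites derive the device from the same truthiness test on self.gpus, so the model's parameters and the input batches cannot drift onto different devices.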