Implementing inference functions
This commit is contained in:
parent
9af9347531
commit
01bd85d156
|
|
@ -16,7 +16,7 @@ def main(args):
|
|||
_DATASET = '/home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle'
|
||||
EMBEDDINGS_PATH = '/home/andreapdr/gfun/embeddings'
|
||||
data = MultilingualDataset.load(_DATASET)
|
||||
# data.set_view(languages=['it', 'fr'])
|
||||
data.set_view(languages=['it', 'fr'])
|
||||
lX, ly = data.training()
|
||||
lXte, lyte = data.test()
|
||||
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class RecurrentModel(pl.LightningModule):
|
|||
batch = lX[lang][i:i+batch_size]
|
||||
max_pad_len = define_pad_length(batch)
|
||||
batch = pad(batch, pad_index=l_pad[lang], max_pad_length=max_pad_len)
|
||||
X = torch.LongTensor(batch)
|
||||
X = torch.LongTensor(batch).to('cuda' if self.gpus else 'cpu')
|
||||
_batch_size = X.shape[0]
|
||||
X = self.embed(X, lang)
|
||||
X = self.embedding_dropout(X, drop_range=self.drop_embedding_range, p_drop=self.drop_embedding_prop,
|
||||
|
|
|
|||
|
|
@ -229,6 +229,7 @@ class RecurrentGen(ViewGen):
|
|||
l_pad = self.multilingualIndex.l_pad()
|
||||
data = self.multilingualIndex.l_devel_index()
|
||||
# trainer = Trainer(gpus=self.gpus)
|
||||
self.model.to('cuda' if self.gpus else 'cpu')
|
||||
self.model.eval()
|
||||
time_init = time()
|
||||
l_embeds = self.model.encode(data, l_pad, batch_size=256)
|
||||
|
|
|
|||
Loading…
Reference in New Issue