Implementing inference functions
This commit is contained in:
parent
4d3ef41a07
commit
9af9347531
|
|
@ -110,9 +110,13 @@ class RecurrentModel(pl.LightningModule):
|
|||
def encode(self, lX, l_pad, batch_size=128):
|
||||
"""
|
||||
Returns encoded data (i.e, RNN hidden state at second feed-forward layer - linear1). Dimensionality is 512.
|
||||
# TODO: does not run on gpu..
|
||||
:param lX:
|
||||
:param l_pad:
|
||||
:param batch_size:
|
||||
:return:
|
||||
"""
|
||||
with torch.no_grad():
|
||||
l_embed = {lang: [] for lang in lX.keys()}
|
||||
for lang in sorted(lX.keys()):
|
||||
for i in range(0, len(lX[lang]), batch_size):
|
||||
|
|
@ -135,7 +139,7 @@ class RecurrentModel(pl.LightningModule):
|
|||
output = self.dropout(F.relu(self.linear1(output)))
|
||||
l_embed[lang].append(output)
|
||||
for k, v in l_embed.items():
|
||||
l_embed[k] = torch.cat(v, dim=0)
|
||||
l_embed[k] = torch.cat(v, dim=0).cpu().numpy()
|
||||
return l_embed
|
||||
|
||||
def training_step(self, train_batch, batch_idx):
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class RecurrentGen(ViewGen):
|
|||
l_pad = self.multilingualIndex.l_pad()
|
||||
data = self.multilingualIndex.l_devel_index()
|
||||
# trainer = Trainer(gpus=self.gpus)
|
||||
# self.model.eval()
|
||||
self.model.eval()
|
||||
time_init = time()
|
||||
l_embeds = self.model.encode(data, l_pad, batch_size=256)
|
||||
transform_time = round(time() - time_init, 3)
|
||||
|
|
|
|||
Loading…
Reference in New Issue