diff --git a/src/view_generators.py b/src/view_generators.py index 20a8045..d014ef0 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -370,10 +370,10 @@ class BertGen(ViewGen): self.model.to('cuda' if self.gpus else 'cpu') self.model.eval() # time_init = time.time() - l_emebds = self.model.encode(data, batch_size=64) + l_embeds = self.model.encode(data, batch_size=64) # transform_time = round(time.time() - time_init, 3) # print(f'Executed! Transform took: {transform_time}') - return l_emebds + return l_embeds def fit_transform(self, lX, ly): # we can assume that we have already indexed data for transform() since we are first calling fit()