diff --git a/src/learning/transformers.py b/src/learning/transformers.py index 0032460..f99c23b 100644 --- a/src/learning/transformers.py +++ b/src/learning/transformers.py @@ -328,7 +328,7 @@ class RecurrentEmbedder: self.posteriorEmbedder = MetaClassifier( SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs) - def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1): + def fit(self, lX, ly, lV=None, batch_size=64, nepochs=2, val_epochs=1): print('### Gated Recurrent Unit View Generator (G)') # could be better to init model here at first .fit() call! if self.model is None: @@ -397,6 +397,16 @@ class RecurrentEmbedder: def _get_doc_embeddings(self, lX, batch_size=64): assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!' print('Generating document embeddings via GRU') + data = {} + for lang in lX.keys(): + indexed = index(data=lX[lang], + vocab=self.multilingual_index.l_index[lang].word2index, + known_words=set(self.multilingual_index.l_index[lang].word2index.keys()), + analyzer=self.multilingual_index.l_vectorizer.get_analyzer(lang), + unk_index=self.multilingual_index.l_index[lang].unk_index, + out_of_vocabulary=self.multilingual_index.l_index[lang].out_of_vocabulary) + data[lang] = indexed + lX = {} ly = {} batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs, @@ -404,9 +414,12 @@ class RecurrentEmbedder: l_devel_index = self.multilingual_index.l_devel_index() l_devel_target = self.multilingual_index.l_devel_target() + # l_devel_target = {k: v[:len(data[lang])] for k, v in l_devel_target.items()} + # for idx, (batch, post, bert_emb, target, lang) in enumerate( + # batcher_transform.batchify(l_devel_index, None, None, l_devel_target)): for idx, (batch, post, bert_emb, target, lang) in enumerate( - batcher_transform.batchify(l_devel_index, None, None, l_devel_target)): + batcher_transform.batchify(data, None, None, l_devel_target)): if lang not in lX.keys(): lX[lang] = self.model.get_embeddings(batch, lang) ly[lang] = target.cpu().detach().numpy()