running comparison with refactor branch
parent 091101b39d
commit 66952820f9
@@ -328,7 +328,7 @@ class RecurrentEmbedder:
         self.posteriorEmbedder = MetaClassifier(
             SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs)
 
-    def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
+    def fit(self, lX, ly, lV=None, batch_size=64, nepochs=2, val_epochs=1):
         print('### Gated Recurrent Unit View Generator (G)')
         # could be better to init model here at first .fit() call!
         if self.model is None:
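Note: the only functional change in this hunk is the default nepochs dropping from 200 to 2, presumably a quick setting for the branch comparison named in the commit message. The context comment ("could be better to init model here at first .fit() call!") describes a lazy-initialization pattern; below is a minimal, runnable sketch of that idea. All names here (LazyEmbedder, the dict standing in for a real network) are hypothetical, not this repository's classes.

# Sketch of lazy init at the first fit() call, as the comment above suggests.
# LazyEmbedder and its dict "model" are illustrative stand-ins only.
class LazyEmbedder:
    def __init__(self):
        self.model = None  # deferred: built on the first fit() call

    def fit(self, lX, ly, nepochs=200):
        if self.model is None:
            # input dimensionality is only known once data arrives
            dim = len(next(iter(lX.values()))[0])
            self.model = {'dim': dim, 'epochs_run': 0}  # stand-in for a real network
        self.model['epochs_run'] += nepochs
        return self

embedder = LazyEmbedder()
embedder.fit({'en': [[0.1, 0.2, 0.3]]}, {'en': [1]}, nepochs=2)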
@@ -397,6 +397,16 @@ class RecurrentEmbedder:
     def _get_doc_embeddings(self, lX, batch_size=64):
         assert self.is_trained, 'Model is not trained, cannot call transform before fitting the model!'
         print('Generating document embeddings via GRU')
+        data = {}
+        for lang in lX.keys():
+            indexed = index(data=lX[lang],
+                            vocab=self.multilingual_index.l_index[lang].word2index,
+                            known_words=set(self.multilingual_index.l_index[lang].word2index.keys()),
+                            analyzer=self.multilingual_index.l_vectorizer.get_analyzer(lang),
+                            unk_index=self.multilingual_index.l_index[lang].unk_index,
+                            out_of_vocabulary=self.multilingual_index.l_index[lang].out_of_vocabulary)
+            data[lang] = indexed
+
         lX = {}
         ly = {}
         batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
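The added block indexes each language's raw documents against that language's vocabulary before batching. A toy, self-contained sketch of the pattern follows; this index() is an assumption reconstructed from the call signature visible in the diff, not the repository's actual helper.

# Toy sketch of the per-language indexing step added above; an assumption
# about the interface, not the repository's index() implementation.
def index(data, vocab, known_words, analyzer, unk_index, out_of_vocabulary):
    indexed = []
    for doc in data:
        ids = []
        for tok in analyzer(doc):
            if tok in known_words:
                ids.append(vocab[tok])                # known token -> its index
            else:
                out_of_vocabulary.setdefault(tok, unk_index)
                ids.append(unk_index)                 # unknown token -> UNK
        indexed.append(ids)
    return indexed

# usage: mirrors the loop in the diff, one entry per language
docs = {'en': ['hello world', 'hello unseen']}
vocab = {'hello': 0, 'world': 1}
oov = {}
data = {lang: index(data=docs[lang], vocab=vocab, known_words=set(vocab),
                    analyzer=str.split, unk_index=len(vocab), out_of_vocabulary=oov)
        for lang in docs}
print(data)  # {'en': [[0, 1], [0, 2]]}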
@@ -404,9 +414,12 @@ class RecurrentEmbedder:
 
         l_devel_index = self.multilingual_index.l_devel_index()
         l_devel_target = self.multilingual_index.l_devel_target()
+        # l_devel_target = {k: v[:len(data[lang])] for k, v in l_devel_target.items()}
 
+        # for idx, (batch, post, bert_emb, target, lang) in enumerate(
+        #         batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
         for idx, (batch, post, bert_emb, target, lang) in enumerate(
-                batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
+                batcher_transform.batchify(data, None, None, l_devel_target)):
             if lang not in lX.keys():
                 lX[lang] = self.model.get_embeddings(batch, lang)
             ly[lang] = target.cpu().detach().numpy()
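The behavioral change in this hunk is that batchify() now consumes the freshly indexed data instead of l_devel_index, so embeddings are produced for the documents passed to _get_doc_embeddings rather than always for the development split. A minimal sketch of a generator with the same yield shape, (batch, post, bert_emb, target, lang); BatchGRU's internals are not shown in this diff, so this interface is an assumption.

# Hedged sketch of a batcher with the yield signature the loop above
# consumes; an assumption about BatchGRU, not its real implementation.
def batchify(l_index, l_post, l_bert, l_target, batch_size=2):
    for lang, docs in l_index.items():
        for start in range(0, len(docs), batch_size):
            batch = docs[start:start + batch_size]
            target = l_target[lang][start:start + batch_size]
            yield batch, None, None, target, lang  # post/bert_emb unused here

# usage: iterating the freshly indexed `data`, as the changed line now does
data = {'en': [[0, 1], [0, 2], [1, 1]]}
targets = {'en': [1, 0, 1]}
for idx, (batch, post, bert_emb, target, lang) in enumerate(
        batchify(data, None, None, targets)):
    print(idx, lang, batch, target)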