running comparison with refactor branch

This commit is contained in:
andrea 2021-01-29 14:50:34 +01:00
parent 66952820f9
commit 5405f60bd0
1 changed files with 11 additions and 6 deletions

View File

@ -328,7 +328,7 @@ class RecurrentEmbedder:
self.posteriorEmbedder = MetaClassifier( self.posteriorEmbedder = MetaClassifier(
SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs) SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=options.n_jobs)
def fit(self, lX, ly, lV=None, batch_size=64, nepochs=2, val_epochs=1): def fit(self, lX, ly, lV=None, batch_size=64, nepochs=200, val_epochs=1):
print('### Gated Recurrent Unit View Generator (G)') print('### Gated Recurrent Unit View Generator (G)')
# could be better to init model here at first .fit() call! # could be better to init model here at first .fit() call!
if self.model is None: if self.model is None:
@ -412,14 +412,19 @@ class RecurrentEmbedder:
batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs, batcher_transform = BatchGRU(batch_size, batches_per_epoch=batch_size, languages=self.langs,
lpad=self.multilingual_index.l_pad()) lpad=self.multilingual_index.l_pad())
l_devel_index = self.multilingual_index.l_devel_index() # l_devel_index = self.multilingual_index.l_devel_index()
l_devel_target = self.multilingual_index.l_devel_target()
# l_devel_target = {k: v[:len(data[lang])] for k, v in l_devel_target.items()}
l_devel_target = self.multilingual_index.l_devel_target()
l_devel_target = {k: v[:len(data[k])] for k, v in l_devel_target.items()} # todo -> debug
for batch, _, target, lang, in batchify(l_index=data,
l_post=None,
llabels=l_devel_target,
batchsize=batch_size,
lpad=self.multilingual_index.l_pad()):
# for idx, (batch, post, bert_emb, target, lang) in enumerate( # for idx, (batch, post, bert_emb, target, lang) in enumerate(
# batcher_transform.batchify(l_devel_index, None, None, l_devel_target)): # batcher_transform.batchify(l_devel_index, None, None, l_devel_target)):
for idx, (batch, post, bert_emb, target, lang) in enumerate( # for idx, (batch, post, bert_emb, target, lang) in enumerate(
batcher_transform.batchify(data, None, None, l_devel_target)): # batcher_transform.batchify(data, None, None, l_devel_target)):
if lang not in lX.keys(): if lang not in lX.keys():
lX[lang] = self.model.get_embeddings(batch, lang) lX[lang] = self.model.get_embeddings(batch, lang)
ly[lang] = target.cpu().detach().numpy() ly[lang] = target.cpu().detach().numpy()