From d249c4801f8cba9173f9f8e16ed4fe87cbfa168f Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Mon, 20 Jan 2020 14:53:14 +0100 Subject: [PATCH] bugfix in muse extract method --- src/embeddings/embeddings.py | 3 +++ src/learning/transformers.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/embeddings/embeddings.py b/src/embeddings/embeddings.py index 49ea7a0..a12c206 100644 --- a/src/embeddings/embeddings.py +++ b/src/embeddings/embeddings.py @@ -21,6 +21,9 @@ class PretrainedEmbeddings(ABC): @classmethod def reindex(cls, words, word2index): + if isinstance(words, dict): + words = list(zip(*sorted(words.items(), key=lambda x: x[1])))[0] + source_idx, target_idx = [], [] for i, word in enumerate(words): if word not in word2index: continue diff --git a/src/learning/transformers.py b/src/learning/transformers.py index 72f19f0..e5b0da4 100644 --- a/src/learning/transformers.py +++ b/src/learning/transformers.py @@ -82,7 +82,8 @@ class MuseEmbedder: MUSE = Parallel(n_jobs=self.n_jobs)( delayed(FastTextMUSE)(self.path, lang) for lang in self.langs ) - self.MUSE = {l:MUSE[i].extract(lV[l]).numpy() for i,l in enumerate(self.langs)} + lWordList = {l:self._get_wordlist_from_word2index(lV[l]) for l in self.langs} + self.MUSE = {l:MUSE[i].extract(lWordList[l]).numpy() for i,l in enumerate(self.langs)} return self def transform(self, lX): @@ -95,6 +96,9 @@ class MuseEmbedder: def fit_transform(self, lX, ly, lV): return self.fit(lX, ly, lV).transform(lX) + def _get_wordlist_from_word2index(self, word2index): + return list(zip(*sorted(word2index.items(), key=lambda x: x[1])))[0] + class WordClassEmbedder: