bugfix in muse extract method

This commit is contained in:
Alejandro Moreo Fernandez 2020-01-20 14:53:14 +01:00
parent 5bb5c913c0
commit d249c4801f
2 changed files with 8 additions and 1 deletions

View File

@ -21,6 +21,9 @@ class PretrainedEmbeddings(ABC):
@classmethod @classmethod
def reindex(cls, words, word2index): def reindex(cls, words, word2index):
if isinstance(words, dict):
words = list(zip(*sorted(words.items(), key=lambda x: x[1])))[0]
source_idx, target_idx = [], [] source_idx, target_idx = [], []
for i, word in enumerate(words): for i, word in enumerate(words):
if word not in word2index: continue if word not in word2index: continue

View File

@ -82,7 +82,8 @@ class MuseEmbedder:
MUSE = Parallel(n_jobs=self.n_jobs)( MUSE = Parallel(n_jobs=self.n_jobs)(
delayed(FastTextMUSE)(self.path, lang) for lang in self.langs delayed(FastTextMUSE)(self.path, lang) for lang in self.langs
) )
self.MUSE = {l:MUSE[i].extract(lV[l]).numpy() for i,l in enumerate(self.langs)} lWordList = {l:self._get_wordlist_from_word2index(lV[l]) for l in self.langs}
self.MUSE = {l:MUSE[i].extract(lWordList[l]).numpy() for i,l in enumerate(self.langs)}
return self return self
def transform(self, lX): def transform(self, lX):
@ -95,6 +96,9 @@ class MuseEmbedder:
def fit_transform(self, lX, ly, lV): def fit_transform(self, lX, ly, lV):
return self.fit(lX, ly, lV).transform(lX) return self.fit(lX, ly, lV).transform(lX)
def _get_wordlist_from_word2index(self, word2index):
return list(zip(*sorted(word2index.items(), key=lambda x: x[1])))[0]
class WordClassEmbedder: class WordClassEmbedder: