bugfix in muse extract method
This commit is contained in:
parent
5bb5c913c0
commit
d249c4801f
|
|
@ -21,6 +21,9 @@ class PretrainedEmbeddings(ABC):
|
|||
|
||||
@classmethod
|
||||
def reindex(cls, words, word2index):
|
||||
if isinstance(words, dict):
|
||||
words = list(zip(*sorted(words.items(), key=lambda x: x[1])))[0]
|
||||
|
||||
source_idx, target_idx = [], []
|
||||
for i, word in enumerate(words):
|
||||
if word not in word2index: continue
|
||||
|
|
|
|||
|
|
@ -82,7 +82,8 @@ class MuseEmbedder:
|
|||
MUSE = Parallel(n_jobs=self.n_jobs)(
|
||||
delayed(FastTextMUSE)(self.path, lang) for lang in self.langs
|
||||
)
|
||||
self.MUSE = {l:MUSE[i].extract(lV[l]).numpy() for i,l in enumerate(self.langs)}
|
||||
lWordList = {l:self._get_wordlist_from_word2index(lV[l]) for l in self.langs}
|
||||
self.MUSE = {l:MUSE[i].extract(lWordList[l]).numpy() for i,l in enumerate(self.langs)}
|
||||
return self
|
||||
|
||||
def transform(self, lX):
|
||||
|
|
@ -95,6 +96,9 @@ class MuseEmbedder:
|
|||
def fit_transform(self, lX, ly, lV):
    """Fit this embedder, then transform the same input in one call.

    The three arguments are forwarded unchanged to ``fit``; ``lX`` is then
    passed on its own to ``transform`` of whatever ``fit`` returned.

    Args:
        lX: input forwarded to both ``fit`` and ``transform``
            (presumably keyed by language -- confirm against callers).
        ly: labels forwarded to ``fit`` only.
        lV: vocabularies forwarded to ``fit`` only.

    Returns:
        Whatever ``transform(lX)`` returns on the fitted object.
    """
    fitted = self.fit(lX, ly, lV)
    return fitted.transform(lX)
|
||||
|
||||
def _get_wordlist_from_word2index(self, word2index):
    """Return the keys of ``word2index`` ordered by their mapped value.

    Args:
        word2index: mapping from word to a sortable value (its index).

    Returns:
        A tuple of the words in ascending order of their value. Words with
        equal values keep the mapping's iteration order, because the sort
        is stable. An empty mapping yields an empty tuple.
    """
    # Sorting the keys directly (instead of unzipping sorted items and taking
    # element [0]) avoids an IndexError when the mapping is empty, while
    # producing the same tuple for every non-empty mapping.
    return tuple(sorted(word2index, key=word2index.get))
|
||||
|
||||
|
||||
class WordClassEmbedder:
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue