bugfix in muse extract method
This commit is contained in:
parent
5bb5c913c0
commit
d249c4801f
|
|
@ -21,6 +21,9 @@ class PretrainedEmbeddings(ABC):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def reindex(cls, words, word2index):
|
def reindex(cls, words, word2index):
|
||||||
|
if isinstance(words, dict):
|
||||||
|
words = list(zip(*sorted(words.items(), key=lambda x: x[1])))[0]
|
||||||
|
|
||||||
source_idx, target_idx = [], []
|
source_idx, target_idx = [], []
|
||||||
for i, word in enumerate(words):
|
for i, word in enumerate(words):
|
||||||
if word not in word2index: continue
|
if word not in word2index: continue
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,8 @@ class MuseEmbedder:
|
||||||
MUSE = Parallel(n_jobs=self.n_jobs)(
|
MUSE = Parallel(n_jobs=self.n_jobs)(
|
||||||
delayed(FastTextMUSE)(self.path, lang) for lang in self.langs
|
delayed(FastTextMUSE)(self.path, lang) for lang in self.langs
|
||||||
)
|
)
|
||||||
self.MUSE = {l:MUSE[i].extract(lV[l]).numpy() for i,l in enumerate(self.langs)}
|
lWordList = {l:self._get_wordlist_from_word2index(lV[l]) for l in self.langs}
|
||||||
|
self.MUSE = {l:MUSE[i].extract(lWordList[l]).numpy() for i,l in enumerate(self.langs)}
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def transform(self, lX):
|
def transform(self, lX):
|
||||||
|
|
@ -95,6 +96,9 @@ class MuseEmbedder:
|
||||||
def fit_transform(self, lX, ly, lV):
|
def fit_transform(self, lX, ly, lV):
|
||||||
return self.fit(lX, ly, lV).transform(lX)
|
return self.fit(lX, ly, lV).transform(lX)
|
||||||
|
|
||||||
|
def _get_wordlist_from_word2index(self, word2index):
|
||||||
|
return list(zip(*sorted(word2index.items(), key=lambda x: x[1])))[0]
|
||||||
|
|
||||||
|
|
||||||
class WordClassEmbedder:
|
class WordClassEmbedder:
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue