This commit is contained in:
andrea 2019-12-17 10:42:29 +01:00
parent a95511b4d9
commit 56ee88220b
1 changed files with 3 additions and 3 deletions

View File

@ -225,7 +225,7 @@ class StorageEmbeddings:
return return
def _add_emebeddings_supervised(self, docs, labels, reduction, max_label_space, voc): def _add_embeddings_supervised(self, docs, labels, reduction, max_label_space, voc):
for lang in docs.keys(): # compute supervised matrices S - then apply PCA for lang in docs.keys(): # compute supervised matrices S - then apply PCA
print(f'# [supervised-matrix] for {lang}') print(f'# [supervised-matrix] for {lang}')
self.lang_S[lang] = get_supervised_embeddings(docs[lang], labels[lang], self.lang_S[lang] = get_supervised_embeddings(docs[lang], labels[lang],
@ -259,7 +259,7 @@ class StorageEmbeddings:
print(f'Applying PCA(n_components={i}') print(f'Applying PCA(n_components={i}')
for lang in languages: for lang in languages:
self.lang_S[lang] = stacked_pca.transform(self.lang_S[lang]) self.lang_S[lang] = stacked_pca.transform(self.lang_S[lang])
elif max_label_space <= nC: elif max_label_space <= nC: # also equal in order to reduce it to the same initial dimension
print(f'Computing PCA on Supervised Matrix PCA(n_components:{max_label_space})') print(f'Computing PCA on Supervised Matrix PCA(n_components:{max_label_space})')
self.lang_S = run_pca(max_label_space, self.lang_S) self.lang_S = run_pca(max_label_space, self.lang_S)
@ -275,7 +275,7 @@ class StorageEmbeddings:
if config['unsupervised']: if config['unsupervised']:
self._add_embeddings_unsupervised(config['we_type'], docs, vocs, config['dim_reduction_unsupervised']) self._add_embeddings_unsupervised(config['we_type'], docs, vocs, config['dim_reduction_unsupervised'])
if config['supervised']: if config['supervised']:
self._add_emebeddings_supervised(docs, labels, config['reduction'], config['max_label_space'], vocs) self._add_embeddings_supervised(docs, labels, config['reduction'], config['max_label_space'], vocs)
return self return self
def predict(self, config, docs): def predict(self, config, docs):