diff --git a/main.py b/main.py
index 6b5d075..ea6a329 100644
--- a/main.py
+++ b/main.py
@@ -41,7 +41,8 @@ def main(args):
         embedder_list.append(museEmbedder)
 
     if args.wce_embedder:
-        wceEmbedder = WordClassGen(n_jobs=args.n_jobs)
+        wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs)  # Todo: testing zero shot
         embedder_list.append(wceEmbedder)
 
     if args.gru_embedder:
@@ -74,7 +75,7 @@ def main(args):
     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
     time_te = time.time()
-    # Zero shot scenario -> setting first tier learners zero_shot param to False
+    # TODO: Zero shot scenario -> setting first tier learners zero_shot param to False
     gfun.set_zero_shot(val=False)
     ly_ = gfun.predict(lXte)
     l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
diff --git a/src/view_generators.py b/src/view_generators.py
index 19110c2..688f133 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -181,7 +181,7 @@ class WordClassGen(ViewGen):
     View Generator (w): generates document representation via Word-Class-Embeddings.
     Document embeddings are obtained via weighted sum of document's constituent embeddings.
     """
-    def __init__(self, n_jobs=-1):
+    def __init__(self, zero_shot=False, train_langs: list = None, n_jobs=-1):
         """
         Init WordClassGen.
         :param n_jobs: int, number of concurrent workers
@@ -191,6 +191,11 @@ class WordClassGen(ViewGen):
         self.langs = None
         self.lWce = None
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
 
     def fit(self, lX, ly):
         """
@@ -215,19 +220,34 @@ class WordClassGen(ViewGen):
         :param lX: dict {lang: indexed documents}
         :return: document projection to the common latent space.
         """
-        lX = self.vectorizer.transform(lX)
+        # Testing zero-shot experiments
+        if self.zero_shot:
+            lX = self.zero_shot_experiments(lX)
+            lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None}
+        else:
+            lX = self.vectorizer.transform(lX)
         XdotWce = Parallel(n_jobs=self.n_jobs)(
-            delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in self.langs)
-        lWce = {l: XdotWce[i] for i, l in enumerate(self.langs)}
+            delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys()))
+        lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys()))}
         lWce = _normalize(lWce, l2=True)
         return lWce
 
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)
 
+    def zero_shot_experiments(self, lX):
+        print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+        _lX = {}
+        for lang in self.langs:
+            if lang in self.train_langs:
+                _lX[lang] = lX[lang]
+            else:
+                _lX[lang] = None
+        lX = _lX
+        return lX
+
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: WordClassGen has not been configured for zero-shot experiments')
         return
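
The core of the patch is the zero-shot masking in transform: documents for languages outside train_langs are replaced with None, the per-language TF-IDF transform skips the masked entries, and the Parallel results are re-keyed over sorted(lX.keys()) rather than self.langs, since after masking the dictionary can hold fewer languages than were seen at fit time (enumerating self.langs there would mis-key or KeyError on the missing ones). Below is a minimal standalone sketch of that pattern; mask_zero_shot and project are hypothetical stand-ins for WordClassGen.zero_shot_experiments and WordClassGen.transform, and random matrices stand in for the real TF-IDF vectors and word-class embeddings (XdotM and its SIF weighting are not reproduced).

import numpy as np
from joblib import Parallel, delayed


def mask_zero_shot(lX, langs, train_langs):
    # Keep documents only for the declared training languages; every other
    # language is masked out with None, mirroring zero_shot_experiments.
    return {lang: (lX[lang] if lang in train_langs else None) for lang in langs}


def project(lX, lWce, n_jobs=2):
    # Drop masked languages, then iterate the same sorted order both when
    # submitting jobs and when re-keying the results, so the Parallel output
    # stays aligned even after masking has removed some languages.
    lX = {l: X for l, X in lX.items() if X is not None}
    order = sorted(lX.keys())
    out = Parallel(n_jobs=n_jobs)(delayed(np.dot)(lX[l], lWce[l]) for l in order)
    return {l: out[i] for i, l in enumerate(order)}


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    langs = ['en', 'it']
    lX = {l: rng.random((4, 10)) for l in langs}    # stand-in tf-idf matrices
    lWce = {l: rng.random((10, 5)) for l in langs}  # stand-in word-class embeddings
    masked = mask_zero_shot(lX, langs, train_langs=['it'])
    print({l: (None if X is None else X.shape) for l, X in masked.items()})
    print({l: M.shape for l, M in project(masked, lWce).items()})

Note the complementary switch in main.py: the flag is flipped off with gfun.set_zero_shot(val=False) before prediction, so at test time the first-tier learners transform documents of every language, even though only train_langs contributed during training.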