From 0fdb39532cc2aec98e3117d8fdfb3d1e02868d40 Mon Sep 17 00:00:00 2001
From: andrea
Date: Fri, 12 Feb 2021 16:15:38 +0100
Subject: [PATCH] fixed funnelling transform function. Now it averages lZparts
 according to the actual number of embedders used for a given language (e.g.,
 'da' -> -x -m -w -b -> 'da' will be averaged by 4; 'en' -> -m -b -> 'en'
 will be averaged by 2)

---
 main.py           | 13 ++++++-------
 src/funnelling.py | 22 +++++++++++++++++-----
 2 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/main.py b/main.py
index 534c432..d100433 100644
--- a/main.py
+++ b/main.py
@@ -25,13 +25,12 @@ def main(args):
     lX, ly = data.training()
     lXte, lyte = data.test()
 
-    # # TODO: debug settings
-    # print(f'\n[Running on DEBUG mode - samples per language are reduced to 50 max!]\n')
-    # lX = {k: v[:50] for k, v in lX.items()}
-    # ly = {k: v[:50] for k, v in ly.items()}
-    # lXte = {k: v[:50] for k, v in lXte.items()}
-    # lyte = {k: v[:50] for k, v in lyte.items()}
-
+    # TODO: debug settings
+    # print(f'\n[Running on DEBUG mode - samples per language are reduced to 5 max!]\n')
+    # lX = {k: v[:5] for k, v in lX.items()}
+    # ly = {k: v[:5] for k, v in ly.items()}
+    # lXte = {k: v[:5] for k, v in lXte.items()}
+    # lyte = {k: v[:5] for k, v in lyte.items()}
 
     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     if args.gru_embedder or args.bert_embedder:
diff --git a/src/funnelling.py b/src/funnelling.py
index 457ee89..814ab82 100644
--- a/src/funnelling.py
+++ b/src/funnelling.py
@@ -43,8 +43,8 @@ class DocEmbedderList:
         :param lX:
         :return: common latent space (averaged).
         """
-        langs = sorted(lX.keys())
-        lZparts = {lang: None for lang in langs}
+        self.langs = sorted(lX.keys())
+        lZparts = {lang: None for lang in self.langs}
 
         for embedder in self.embedders:
             lZ = embedder.transform(lX)
@@ -57,12 +57,24 @@ class DocEmbedderList:
         n_embedders = len(self.embedders)
         # Zero shot experiments: removing k:v if v is None (i.e, it is a lang that will be used in zero shot setting)
         lZparts = {k: v for k, v in lZparts.items() if v is not None}
-
-        return {lang: lZparts[lang]/n_embedders for lang in sorted(lZparts.keys())}  # Averaging feature spaces
+        lang_number_embedders = self.get_number_embedders_zeroshot()
+        return {lang: lZparts[lang]/lang_number_embedders[lang] for lang in sorted(lZparts.keys())}  # Averaging feature spaces
 
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)
 
+    def get_number_embedders_zeroshot(self):
+        lang_number_embedders = {lang: len(self.embedders) for lang in self.langs}
+        for lang in self.langs:
+            for embedder in self.embedders:
+                if isinstance(embedder, VanillaFunGen):
+                    if lang not in embedder.train_langs:
+                        lang_number_embedders[lang] = 2  # TODO: number of view generators is hard-coded
+                else:
+                    if lang not in embedder.embedder.train_langs:
+                        lang_number_embedders[lang] = 2  # TODO: number of view generators is hard-coded
+        return lang_number_embedders
+
 
 class FeatureSet2Posteriors:
     """
@@ -80,7 +92,7 @@ class FeatureSet2Posteriors:
         self.l2 = l2
         self.n_jobs = n_jobs
         self.prob_classifier = MetaClassifier(
-            SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=n_jobs)
+            SVC(kernel='rbf', gamma='auto', probability=True, cache_size=1000, random_state=1), n_jobs=self.n_jobs)
 
     def fit(self, lX, ly):
         lZ = self.embedder.fit_transform(lX, ly)
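
For reference, here is a minimal standalone sketch of the averaging behaviour this patch introduces. It is not the repository's API: average_views, EMBEDDER_LANGS, and the toy arrays are hypothetical stand-ins for DocEmbedderList.transform and its view generators.

    # A minimal sketch of the per-language averaging this patch introduces.
    # All names here (average_views, EMBEDDER_LANGS) are hypothetical
    # stand-ins, not the repository's actual API.
    import numpy as np

    # Hypothetical setup: each view generator covers only a subset of the
    # languages, mirroring the commit message ('da' covered by 4 embedders,
    # 'en' by 2).
    EMBEDDER_LANGS = {
        'posteriors': {'da', 'en'},
        'muse': {'da'},
        'wce': {'da'},
        'bert': {'da', 'en'},
    }

    def average_views(views):
        # Sum the latent spaces per language, then divide by the number of
        # embedders that actually produced a view for that language (the bug
        # was dividing by the total number of embedders instead).
        sums, counts = {}, {}
        for lZ in views:
            for lang, Z in lZ.items():
                sums[lang] = sums.get(lang, 0) + Z
                counts[lang] = counts.get(lang, 0) + 1
        return {lang: sums[lang] / counts[lang] for lang in sums}

    views = [{lang: np.ones(3) for lang in langs} for langs in EMBEDDER_LANGS.values()]
    lZ = average_views(views)
    assert np.allclose(lZ['da'], 1.0)  # summed over 4 views, divided by 4
    assert np.allclose(lZ['en'], 1.0)  # summed over 2 views, divided by 2

Counting contributions per language directly, as in this sketch, would also remove the hard-coded 2 flagged by the TODO in get_number_embedders_zeroshot.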