From f3fafd0f00736f75bbd78c7efdd48d7d8b1a40d8 Mon Sep 17 00:00:00 2001 From: andrea Date: Thu, 4 Feb 2021 12:44:36 +0100 Subject: [PATCH] implemented zero-shot experiment code for VanillaFunGen and WordClassGen --- main.py | 4 ++-- src/funnelling.py | 2 +- src/view_generators.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index c3938fe..fd73306 100644 --- a/main.py +++ b/main.py @@ -62,7 +62,7 @@ def main(args): # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), - meta_parameters=get_params(optimc=args.optimc)) + meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs) # Init Funnelling Architecture gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) @@ -80,7 +80,7 @@ def main(args): if args.zero_shot: gfun.set_zero_shot(val=False) ly_ = gfun.predict(lXte) - l_eval = evaluate(ly_true=lyte, ly_pred=ly_) + l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs) time_te = round(time.time() - time_te, 3) print(f'Testing completed in {time_te} seconds!') diff --git a/src/funnelling.py b/src/funnelling.py index c8d3fc6..4860a24 100644 --- a/src/funnelling.py +++ b/src/funnelling.py @@ -23,7 +23,7 @@ class DocEmbedderList: if isinstance(embedder, VanillaFunGen): _tmp.append(embedder) else: - _tmp.append(FeatureSet2Posteriors(embedder)) + _tmp.append(FeatureSet2Posteriors(embedder, n_jobs=embedder.n_jobs)) self.embedders = _tmp def fit(self, lX, ly): diff --git a/src/view_generators.py b/src/view_generators.py index bc3916a..673453e 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -78,7 +78,7 @@ class VanillaFunGen(ViewGen): self.train_langs = train_langs def fit(self, lX, ly): - print('# Fitting VanillaFunGen (X)...') + print('\n# Fitting VanillaFunGen (X)...') if self.zero_shot: print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.langs = sorted(self.train_langs) @@ -150,7 +150,7 @@ class MuseGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('# Fitting MuseGen (M)...') + print('\n# Fitting MuseGen (M)...') if self.zero_shot: print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.vectorizer.fit(lX) @@ -226,7 +226,7 @@ class WordClassGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('# Fitting WordClassGen (W)...') + print('\n# Fitting WordClassGen (W)...') if self.zero_shot: print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.langs = sorted(self.train_langs) @@ -359,7 +359,7 @@ class RecurrentGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('# Fitting RecurrentGen (G)...') + print('\n# Fitting RecurrentGen (G)...') create_if_not_exist(self.logger.save_dir) recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs, zero_shot=self.zero_shot, zscl_langs=self.train_langs) @@ -468,7 +468,7 @@ class BertGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('# Fitting BertGen (M)...') + print('\n# Fitting BertGen (M)...') create_if_not_exist(self.logger.save_dir) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,