From 7f493da0f812bcbda20561c0bfc7713f59cbefed Mon Sep 17 00:00:00 2001 From: andrea Date: Tue, 2 Feb 2021 12:57:27 +0100 Subject: [PATCH] setting up zero-shot experiments (done and tested for MuseGen) --- main.py | 24 +++++++++--------------- src/funnelling.py | 13 +++++++++++-- src/view_generators.py | 40 +++++++++++++++++++++++++++++++++------- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index 4e4fbc5..6b5d075 100644 --- a/main.py +++ b/main.py @@ -15,23 +15,14 @@ def main(args): print('Running generalized funnelling...') data = MultilingualDataset.load(args.dataset) - data.set_view(languages=['it', 'da']) + data.set_view(languages=['it', 'da', 'nl']) data.show_dimensions() lX, ly = data.training() - - # Testing zero shot experiments - # zero_shot_setting = True - # if zero_shot_setting: - # # _lX = {} - # _ly = {} - # train_langs = ['it'] - # for train_lang in train_langs: - # # _lX[train_lang] = lX[train_lang] - # _ly[train_lang] = ly[train_lang] - # ly = _ly - lXte, lyte = data.test() + zero_shot = True + zscl_train_langs = ['it'] # Todo: testing zero shot + # Init multilingualIndex - mandatory when deploying Neural View Generators... if args.gru_embedder or args.bert_embedder: multilingualIndex = MultilingualIndex() @@ -45,7 +36,8 @@ def main(args): embedder_list.append(posteriorEmbedder) if args.muse_embedder: - museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, zero_shot=True) + museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, + zero_shot=zero_shot, train_langs=zscl_train_langs) # Todo: testing zero shot embedder_list.append(museEmbedder) if args.wce_embedder: @@ -82,6 +74,8 @@ def main(args): # Testing ---------------------------------------- print('\n[Testing Generalized Funnelling]') time_te = time.time() + # Zero shot scenario -> setting first tier learners zero_shot param to False + gfun.set_zero_shot(val=False) ly_ = gfun.predict(lXte) l_eval = evaluate(ly_true=lyte, ly_pred=ly_) time_te = round(time.time() - time_te, 3) @@ -111,7 +105,7 @@ def main(args): microf1=microf1, macrok=macrok, microk=microk, - notes=f'Train langs: {sorted(lX.keys())}') + notes=f'Train langs: {sorted(zscl_train_langs)}' if zero_shot else '') print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3)) overall_time = round(time.time() - time_init, 3) diff --git a/src/funnelling.py b/src/funnelling.py index ba2be1b..116d67b 100644 --- a/src/funnelling.py +++ b/src/funnelling.py @@ -48,14 +48,17 @@ class DocEmbedderList: for embedder in self.embedders: lZ = embedder.transform(lX) - for lang in langs: + for lang in sorted(lZ.keys()): Z = lZ[lang] if lZparts[lang] is None: lZparts[lang] = Z else: lZparts[lang] += Z n_embedders = len(self.embedders) - return {lang: lZparts[lang]/n_embedders for lang in langs} # Averaging feature spaces + # Zero shot experiments: removing k:v if v is None (i.e, it is a lang that will be used in zero shot setting) + lZparts = {k: v for k, v in lZparts.items() if v is not None} + + return {lang: lZparts[lang]/n_embedders for lang in sorted(lZparts.keys())} # Averaging feature spaces def fit_transform(self, lX, ly): return self.fit(lX, ly).transform(lX) @@ -122,3 +125,9 @@ class Funnelling: lZ = self.first_tier.transform(lX) ly = self.meta.predict(lZ) return ly + + def set_zero_shot(self, val: bool): + for embedder in self.first_tier.embedders: + embedder.embedder.set_zero_shot(val) + return + diff --git a/src/view_generators.py b/src/view_generators.py index f8bf289..19110c2 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -93,13 +93,18 @@ class VanillaFunGen(ViewGen): def fit_transform(self, lX, ly): return self.fit(lX, ly).transform(lX) + def set_zero_shot(self, val: bool): + self.zero_shot = val + print('# TODO: PosteriorsGen has not been configured for zero-shot experiments') + return + class MuseGen(ViewGen): """ View Generator (m): generates document representation via MUSE embeddings (Fasttext multilingual word embeddings). Document embeddings are obtained via weighted sum of document's constituent embeddings. """ - def __init__(self, muse_dir='../embeddings', zero_shot=False, n_jobs=-1): + def __init__(self, muse_dir='../embeddings', zero_shot=False, train_langs: list = None, n_jobs=-1): """ Init the MuseGen. :param muse_dir: string, path to folder containing muse embeddings @@ -111,7 +116,11 @@ class MuseGen(ViewGen): self.langs = None self.lMuse = None self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) + # Zero shot parameters self.zero_shot = zero_shot + if train_langs is None: + train_langs = ['it'] + self.train_langs = train_langs def fit(self, lX, ly): """ @@ -138,6 +147,7 @@ class MuseGen(ViewGen): """ # Testing zero-shot experiments if self.zero_shot: + lX = self.zero_shot_experiments(lX) lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None} else: lX = self.vectorizer.transform(lX) @@ -148,22 +158,23 @@ class MuseGen(ViewGen): return lZ def fit_transform(self, lX, ly): - print('## NB: Calling fit_transform!') - if self.zero_shot: - return self.fit(lX, ly).transform(self.zero_shot_experiments(lX)) return self.fit(lX, ly).transform(lX) - def zero_shot_experiments(self, lX, train_langs: list = ['it']): - print(f'# Zero-shot setting! Training langs will be set to: {sorted(train_langs)}') + def zero_shot_experiments(self, lX): + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') _lX = {} for lang in self.langs: - if lang in train_langs: + if lang in self.train_langs: _lX[lang] = lX[lang] else: _lX[lang] = None lX = _lX return lX + def set_zero_shot(self, val: bool): + self.zero_shot = val + return + class WordClassGen(ViewGen): """ @@ -214,6 +225,11 @@ class WordClassGen(ViewGen): def fit_transform(self, lX, ly): return self.fit(lX, ly).transform(lX) + def set_zero_shot(self, val: bool): + self.zero_shot = val + print('# TODO: WordClassGen has not been configured for zero-shot experiments') + return + class RecurrentGen(ViewGen): """ @@ -335,6 +351,11 @@ class RecurrentGen(ViewGen): def fit_transform(self, lX, ly): return self.fit(lX, ly).transform(lX) + def set_zero_shot(self, val: bool): + self.zero_shot = val + print('# TODO: RecurrentGen has not been configured for zero-shot experiments') + return + class BertGen(ViewGen): """ @@ -405,3 +426,8 @@ class BertGen(ViewGen): def fit_transform(self, lX, ly): # we can assume that we have already indexed data for transform() since we are first calling fit() return self.fit(lX, ly).transform(lX) + + def set_zero_shot(self, val: bool): + self.zero_shot = val + print('# TODO: BertGen has not been configured for zero-shot experiments') + return