From 8968570d82573c994c584bf833216875771226f3 Mon Sep 17 00:00:00 2001 From: andrea Date: Thu, 4 Feb 2021 12:24:57 +0100 Subject: [PATCH] implemented zero-shot experiment code for VanillaFunGen and WordClassGen --- main.py | 19 +++++++++---------- run.sh | 18 +++++++++--------- src/view_generators.py | 12 +++++++----- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/main.py b/main.py index c689547..c3938fe 100644 --- a/main.py +++ b/main.py @@ -17,14 +17,11 @@ def main(args): print('Running generalized funnelling...') data = MultilingualDataset.load(args.dataset) - data.set_view(languages=['nl']) + data.set_view(languages=['da', 'nl', 'it']) data.show_dimensions() lX, ly = data.training() lXte, lyte = data.test() - zero_shot = args.zero_shot - zscl_train_langs = args.zscl_langs - # Init multilingualIndex - mandatory when deploying Neural View Generators... if args.gru_embedder or args.bert_embedder: multilingualIndex = MultilingualIndex() @@ -34,29 +31,31 @@ def main(args): # Init ViewGenerators and append them to embedder_list embedder_list = [] if args.post_embedder: - posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=args.n_jobs) + posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), + zero_shot=args.zero_shot, train_langs=args.zscl_langs, + n_jobs=args.n_jobs) embedder_list.append(posteriorEmbedder) if args.muse_embedder: museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, - zero_shot=zero_shot, train_langs=zscl_train_langs) + zero_shot=args.zero_shot, train_langs=args.zscl_langs) embedder_list.append(museEmbedder) if args.wce_embedder: wceEmbedder = WordClassGen(n_jobs=args.n_jobs, - zero_shot=zero_shot, train_langs=zscl_train_langs) + zero_shot=args.zero_shot, train_langs=args.zscl_langs) embedder_list.append(wceEmbedder) if args.gru_embedder: rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce, batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn, - zero_shot=zero_shot, train_langs=zscl_train_langs, + zero_shot=args.zero_shot, train_langs=args.zscl_langs, gpus=args.gpus, n_jobs=args.n_jobs) embedder_list.append(rnnEmbedder) if args.bert_embedder: bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert, - zero_shot=zero_shot, train_langs=zscl_train_langs, + zero_shot=args.zero_shot, train_langs=args.zscl_langs, patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs) embedder_list.append(bertEmbedder) @@ -109,7 +108,7 @@ def main(args): microf1=microf1, macrok=macrok, microk=microk, - notes=f'Train langs: {sorted(zscl_train_langs)}' if zero_shot else '') + notes=f'Train langs: {sorted(args.zscl_langs)}' if args.zero_shot else '') print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3)) overall_time = round(time.time() - time_init, 3) diff --git a/run.sh b/run.sh index 8470998..788c0ee 100644 --- a/run.sh +++ b/run.sh @@ -2,15 +2,15 @@ echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv] -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6 -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6 +python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3 +#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3 #for i in {0..10..1} diff --git a/src/view_generators.py b/src/view_generators.py index 9c73615..bc3916a 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -77,15 +77,17 @@ class VanillaFunGen(ViewGen): train_langs = ['it'] self.train_langs = train_langs - def fit(self, lX, lY): + def fit(self, lX, ly): print('# Fitting VanillaFunGen (X)...') if self.zero_shot: + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.langs = sorted(self.train_langs) lX = self.zero_shot_experiments(lX) + ly = self.zero_shot_experiments(ly) lX = self.vectorizer.fit_transform(lX) else: lX = self.vectorizer.fit_transform(lX) - self.doc_projector.fit(lX, lY) + self.doc_projector.fit(lX, ly) return self def transform(self, lX): @@ -104,7 +106,6 @@ class VanillaFunGen(ViewGen): return self.fit(lX, ly).transform(lX) def zero_shot_experiments(self, lX): - print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') _lX = {} for lang in self.langs: if lang in self.train_langs: @@ -150,6 +151,8 @@ class MuseGen(ViewGen): :return: self. """ print('# Fitting MuseGen (M)...') + if self.zero_shot: + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.vectorizer.fit(lX) self.langs = sorted(lX.keys()) self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir) @@ -181,7 +184,6 @@ class MuseGen(ViewGen): return self.fit(lX, ly).transform(lX) def zero_shot_experiments(self, lX): - print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') _lX = {} for lang in self.langs: if lang in self.train_langs: @@ -226,6 +228,7 @@ class WordClassGen(ViewGen): """ print('# Fitting WordClassGen (W)...') if self.zero_shot: + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') self.langs = sorted(self.train_langs) lX = self.zero_shot_experiments(lX) lX = self.vectorizer.fit_transform(lX) @@ -257,7 +260,6 @@ class WordClassGen(ViewGen): return self.fit(lX, ly).transform(lX) def zero_shot_experiments(self, lX): - print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') _lX = {} for lang in self.langs: if lang in self.train_langs: