From a6be7857a34fb853a03e6716ea92bb0d8ed51b03 Mon Sep 17 00:00:00 2001 From: andrea Date: Thu, 4 Feb 2021 16:50:09 +0100 Subject: [PATCH] implemented zero-shot experiment code for VanillaFunGen and WordClassGen --- main.py | 9 ++++++--- run.sh | 18 +++++++++--------- src/models/learners.py | 1 + src/util/disable_sklearn_warnings.py | 8 ++++++++ src/util/standardizer.py | 1 + src/view_generators.py | 7 ++++--- 6 files changed, 29 insertions(+), 15 deletions(-) create mode 100644 src/util/disable_sklearn_warnings.py diff --git a/main.py b/main.py index 2ee7175..7ee1da3 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,9 @@ from src.util.evaluation import evaluate from src.util.results_csv import CSVlog from src.view_generators import * +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "1" + def main(args): assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \ @@ -17,7 +20,7 @@ def main(args): print('Running generalized funnelling...') data = MultilingualDataset.load(args.dataset) - data.set_view(languages=['da']) + # data.set_view(languages=['da', 'nl']) data.show_dimensions() lX, ly = data.training() lXte, lyte = data.test() @@ -189,8 +192,8 @@ if __name__ == '__main__': default=25) parser.add_argument('--patience_bert', dest='patience_bert', type=int, metavar='', - help='set early stop patience for the BertGen, default 5', - default=5) + help='set early stop patience for the BertGen, default 3', + default=3) parser.add_argument('--batch_rnn', dest='batch_rnn', type=int, metavar='', help='set batchsize for the RecurrentGen, default 64', diff --git a/run.sh b/run.sh index 788c0ee..219dc13 100644 --- a/run.sh +++ b/run.sh @@ -2,15 +2,15 @@ echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv] -python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3 -#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6 +python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6 #for i in {0..10..1} diff --git a/src/models/learners.py b/src/models/learners.py index 25fc16b..fc2fa0b 100644 --- a/src/models/learners.py +++ b/src/models/learners.py @@ -1,4 +1,5 @@ import time +import src.util.disable_sklearn_warnings import numpy as np from joblib import Parallel, delayed diff --git a/src/util/disable_sklearn_warnings.py b/src/util/disable_sklearn_warnings.py new file mode 100644 index 0000000..60156c0 --- /dev/null +++ b/src/util/disable_sklearn_warnings.py @@ -0,0 +1,8 @@ +import warnings + + +def warn(*args, **kwargs): + pass + + +warnings.warn = warn diff --git a/src/util/standardizer.py b/src/util/standardizer.py index 429bccd..b5e9aa5 100644 --- a/src/util/standardizer.py +++ b/src/util/standardizer.py @@ -1,4 +1,5 @@ import numpy as np +import src.util.disable_sklearn_warnings class StandardizeTransformer: diff --git a/src/view_generators.py b/src/view_generators.py index 3a3ff5d..0804aec 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -16,6 +16,7 @@ This module contains the view generators that take care of computing the view sp - View generator (-b): generates document embedding via mBERT model. """ from abc import ABC, abstractmethod +import src.util.disable_sklearn_warnings # from time import time from pytorch_lightning import Trainer @@ -317,7 +318,7 @@ class RecurrentGen(ViewGen): self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce) self.model = self._init_model() - self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False) + self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') self.lr_monitor = LearningRateMonitor(logging_interval='epoch') @@ -446,7 +447,7 @@ class BertGen(ViewGen): self.stored_path = stored_path self.model = self._init_model() self.patience = patience - self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False) + self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') @@ -469,7 +470,7 @@ class BertGen(ViewGen): :param ly: dict {lang: target vectors} :return: self. """ - print('\n# Fitting BertGen (M)...') + print('\n# Fitting BertGen (B)...') create_if_not_exist(self.logger.save_dir) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,