implemented zero-shot experiment code for VanillaFunGen and WordClassGen

andrea 2021-02-04 16:50:09 +01:00
parent 495a0b6af9
commit a6be7857a3
6 changed files with 29 additions and 15 deletions

main.py

@@ -7,6 +7,9 @@ from src.util.evaluation import evaluate
from src.util.results_csv import CSVlog
from src.view_generators import *
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
def main(args):
assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \
@@ -17,7 +20,7 @@ def main(args):
print('Running generalized funnelling...')
data = MultilingualDataset.load(args.dataset)
data.set_view(languages=['da'])
# data.set_view(languages=['da', 'nl'])
data.show_dimensions()
lX, ly = data.training()
lXte, lyte = data.test()
@@ -189,8 +192,8 @@ if __name__ == '__main__':
default=25)
parser.add_argument('--patience_bert', dest='patience_bert', type=int, metavar='',
help='set early stop patience for the BertGen, default 5',
default=5)
help='set early stop patience for the BertGen, default 3',
default=3)
parser.add_argument('--batch_rnn', dest='batch_rnn', type=int, metavar='',
help='set batchsize for the RecurrentGen, default 64',

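Setting os.environ["CUDA_VISIBLE_DEVICES"] = "1" at the top of main.py only takes effect if it runs before the CUDA context is initialised (i.e. before any torch/CUDA work); once the visible devices are restricted this way, the --gpus 0 used in run.sh addresses the first visible device, which is physical GPU 1. A minimal sketch of a configurable variant (the --cuda_device option is hypothetical, not part of this commit):

    import argparse
    import os

    # Hypothetical --cuda_device flag: parse it up front, before anything touches CUDA.
    pre = argparse.ArgumentParser(add_help=False)
    pre.add_argument("--cuda_device", default="1")
    known, _ = pre.parse_known_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = known.cuda_device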
run.sh

@@ -2,15 +2,15 @@
echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
#for i in {0..10..1}
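The active lines above run the experiments on cumulative language sets, adding one training language at a time until all nine languages in the list are used; the trailing commented "#for i in {0..10..1}" hints at looping instead of spelling out each call. A sketch of the same cumulative pattern (illustrative only, it just prints the --zscl_langs values the script enumerates):

    # Cumulative prefixes of the language list used by run.sh.
    LANGS = ["da", "de", "en", "es", "fr", "it", "nl", "pt", "sv"]
    for i in range(1, len(LANGS) + 1):
        print("--zscl_langs " + " ".join(LANGS[:i]))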

View File

@@ -1,4 +1,5 @@
import time
import src.util.disable_sklearn_warnings
import numpy as np
from joblib import Parallel, delayed

src/util/disable_sklearn_warnings.py

@@ -0,0 +1,8 @@
import warnings
def warn(*args, **kwargs):
pass
warnings.warn = warn
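This new module silences every warning in the process by replacing warnings.warn with a no-op; importing it anywhere (as the other changed files do) disables warnings globally, including ones that may be worth seeing. A narrower alternative (a sketch, not what the commit does) would filter only sklearn's warnings:

    import warnings
    from sklearn.exceptions import ConvergenceWarning

    # Ignore sklearn convergence chatter and sklearn-raised FutureWarnings only.
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")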

View File

@@ -1,4 +1,5 @@
import numpy as np
import src.util.disable_sklearn_warnings
class StandardizeTransformer:

src/view_generators.py

@@ -16,6 +16,7 @@ This module contains the view generators that take care of computing the view sp
- View generator (-b): generates document embedding via mBERT model.
"""
from abc import ABC, abstractmethod
import src.util.disable_sklearn_warnings
# from time import time
from pytorch_lightning import Trainer
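The docstring above describes the view generators; per the commit title, VanillaFunGen and WordClassGen are among them, and RecurrentGen and BertGen below subclass a common ViewGen base whose fit methods return self. An illustrative sketch of that abstract-base-class pattern (the real ViewGen lives in this module; the transform method and the meaning of lX are assumptions here):

    from abc import ABC, abstractmethod

    class ViewGen(ABC):
        # Common interface of the view generators (sketch, not the repo code).

        @abstractmethod
        def fit(self, lX, ly):
            # ly: dict {lang: target vectors} (as documented in BertGen.fit below);
            # lX is assumed to be the analogous dict of documents. Returns self.
            ...

        @abstractmethod
        def transform(self, lX):
            # Assumed: maps documents to the view-specific embedding space.
            ...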
@@ -317,7 +318,7 @@ class RecurrentGen(ViewGen):
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
self.model = self._init_model()
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
@@ -446,7 +447,7 @@ class BertGen(ViewGen):
self.stored_path = stored_path
self.model = self._init_model()
self.patience = patience
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
patience=self.patience, verbose=False, mode='max')
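RecurrentGen and BertGen both log to a TensorBoardLogger (with the save_dir moved from ../tb_logs to the relative tb_logs in this commit) and stop early once validation macro-F1 stops improving. A sketch of how these pieces are typically handed to a pytorch_lightning Trainer (the epoch count and the commented fit call are illustrative, not taken from this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping
    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
    # 'val-macroF1' must be logged by the LightningModule for the callback to track it.
    early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.0,
                               patience=3, verbose=False, mode='max')
    trainer = Trainer(max_epochs=10, logger=logger, callbacks=[early_stop])
    # trainer.fit(model, datamodule=bert_data_module)  # model / datamodule built elsewhere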
@@ -469,7 +470,7 @@ class BertGen(ViewGen):
:param ly: dict {lang: target vectors}
:return: self.
"""
print('\n# Fitting BertGen (M)...')
print('\n# Fitting BertGen (B)...')
create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,