implemented zero-shot experiment code for VanillaFunGen and WordClassGen

parent 495a0b6af9
commit a6be7857a3

main.py
@@ -7,6 +7,9 @@ from src.util.evaluation import evaluate
 from src.util.results_csv import CSVlog
 from src.view_generators import *
+
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 
 
 def main(args):
     assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \
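The new os.environ assignment pins the process to GPU 1; it only works because it runs before anything initializes CUDA. A minimal sketch of the same pattern, assuming torch is available (the print is purely illustrative):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must be set before CUDA is first initialized

import torch  # imported afterwards, so only physical GPU 1 is visible to it
print(torch.cuda.device_count())  # reports at most one device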
@@ -17,7 +20,7 @@ def main(args):
     print('Running generalized funnelling...')
 
     data = MultilingualDataset.load(args.dataset)
-    data.set_view(languages=['da'])
+    # data.set_view(languages=['da', 'nl'])
     data.show_dimensions()
     lX, ly = data.training()
     lXte, lyte = data.test()
@@ -189,8 +192,8 @@ if __name__ == '__main__':
                         default=25)
 
     parser.add_argument('--patience_bert', dest='patience_bert', type=int, metavar='',
-                        help='set early stop patience for the BertGen, default 5',
-                        default=5)
+                        help='set early stop patience for the BertGen, default 3',
+                        default=3)
 
     parser.add_argument('--batch_rnn', dest='batch_rnn', type=int, metavar='',
                         help='set batchsize for the RecurrentGen, default 64',

run.sh
@@ -2,15 +2,15 @@
 
 echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
 
-python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
 
 
 #for i in {0..10..1}
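The --zero_shot and --zscl_langs flags used above are new in this commit. As a hedged sketch, they might be declared in main.py roughly like this; only the flag names come from run.sh, the dest/help/default values below are assumptions:

import argparse

parser = argparse.ArgumentParser()
# flag names taken from run.sh; everything else here is an assumption
parser.add_argument('--zero_shot', dest='zero_shot', action='store_true',
                    help='run the zero-shot cross-lingual experiments')
parser.add_argument('--zscl_langs', dest='zscl_langs', type=str, nargs='+', default=['da'],
                    help='space-separated list of languages used in the zero-shot configuration')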

@@ -1,4 +1,5 @@
 import time
+import src.util.disable_sklearn_warnings
 import numpy as np
 from joblib import Parallel, delayed

src/util/disable_sklearn_warnings.py (new file)
@@ -0,0 +1,8 @@
+import warnings
+
+
+def warn(*args, **kwargs):
+    pass
+
+
+warnings.warn = warn
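Importing this module anywhere replaces warnings.warn with a no-op for the whole process, which is how the other files silence scikit-learn's warnings. A minimal sketch of the effect (the module path is taken from the imports added above):

import src.util.disable_sklearn_warnings  # noqa: F401 -- imported only for its side effect

import warnings
warnings.warn("any warning issued after the import is silently swallowed")  # prints nothing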

@@ -1,4 +1,5 @@
 import numpy as np
+import src.util.disable_sklearn_warnings
 
 
 class StandardizeTransformer:
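This hunk only adds the warnings import to the module defining StandardizeTransformer. For context, a hypothetical sketch of what a transformer with this name typically does; this is an assumption, not the repository's implementation:

import numpy as np


class StandardizeTransformerSketch:
    """Hypothetical z-scoring transformer: centers each column and scales it to unit variance."""

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        self.std_ = X.std(axis=0) + 1e-8  # guard against zero-variance columns
        return self

    def transform(self, X):
        return (X - self.mean_) / self.std_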

src/view_generators.py
@@ -16,6 +16,7 @@ This module contains the view generators that take care of computing the view sp
 - View generator (-b): generates document embedding via mBERT model.
 """
 from abc import ABC, abstractmethod
+import src.util.disable_sklearn_warnings
 # from time import time
 
 from pytorch_lightning import Trainer

@@ -317,7 +318,7 @@ class RecurrentGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
         self.model = self._init_model()
-        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
         self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
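The logger and callbacks configured here are standard PyTorch Lightning components. A hedged sketch of how such pieces are usually wired into a Trainer; the max_epochs value is a placeholder, not this repository's actual call:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=5, verbose=False, mode='max')
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# placeholder settings; training stops early once val-macroF1 stops improving
trainer = Trainer(max_epochs=10, logger=logger, callbacks=[early_stop, lr_monitor])
# trainer.fit(model, datamodule=data_module) would then launch training with these hooks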

@@ -446,7 +447,7 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.patience = patience
-        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')

@@ -469,7 +470,7 @@ class BertGen(ViewGen):
         :param ly: dict {lang: target vectors}
         :return: self.
         """
-        print('\n# Fitting BertGen (M)...')
+        print('\n# Fitting BertGen (B)...')
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,