implemented zero-shot experiment code for VanillaFunGen and WordClassGen
parent 495a0b6af9
commit a6be7857a3
main.py (9 changed lines)
@@ -7,6 +7,9 @@ from src.util.evaluation import evaluate
 from src.util.results_csv import CSVlog
 from src.view_generators import *
 
+import os
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 
 def main(args):
     assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \
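Note on the CUDA_VISIBLE_DEVICES addition: the variable is read when the CUDA context is first initialized, so it must be set before any tensor or model touches the GPU; placing it at module level in main.py is enough because PyTorch initializes CUDA lazily rather than at import time. A minimal sketch of the effect, assuming PyTorch as the backend (suggested by the pytorch_lightning usage elsewhere in this commit):

    import os

    # Must be set before the first CUDA call; PyTorch initializes CUDA lazily,
    # so doing this ahead of any .to('cuda') / Trainer(gpus=...) call is sufficient.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    import torch

    # Only the devices left visible are enumerated: physical GPU 1 shows up as cuda:0.
    print(torch.cuda.device_count())  # -> 1 on a multi-GPU machine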
@@ -17,7 +20,7 @@ def main(args):
     print('Running generalized funnelling...')
 
     data = MultilingualDataset.load(args.dataset)
-    data.set_view(languages=['da'])
+    # data.set_view(languages=['da', 'nl'])
     data.show_dimensions()
     lX, ly = data.training()
     lXte, lyte = data.test()
@@ -189,8 +192,8 @@ if __name__ == '__main__':
                         default=25)
 
     parser.add_argument('--patience_bert', dest='patience_bert', type=int, metavar='',
-                        help='set early stop patience for the BertGen, default 5',
-                        default=5)
+                        help='set early stop patience for the BertGen, default 3',
+                        default=3)
 
     parser.add_argument('--batch_rnn', dest='batch_rnn', type=int, metavar='',
                         help='set batchsize for the RecurrentGen, default 64',
run.sh (18 changed lines)
@@ -2,15 +2,15 @@
 
 echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
 
-python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3
-#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
+python main.py ../datasets/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 --muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
 
 
 #for i in {0..10..1}
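The nine invocations differ only in the cumulative --zscl_langs list (da, then da de, and so on up to all nine languages of the rcv1-2 dataset). A hedged sketch of generating the same command lines programmatically instead of maintaining them by hand; this helper is purely illustrative and not part of the commit:

    # Illustrative only: prints the nine zero-shot commands used in run.sh above.
    LANGS = ['da', 'de', 'en', 'es', 'fr', 'it', 'nl', 'pt', 'sv']
    BASE = ('python main.py ../datasets/CLESA/rcv2/'
            'rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle '
            '-x -m -w -b -c --nepochs_bert 10 --batch_bert 8 --gpus 0 '
            '--muse_dir ../embeddings/MUSE/ -o csv_logs/gfun/zero_shot_gfun.csv '
            '--zero_shot --zscl_langs {langs} --n_jobs 6')

    for i in range(1, len(LANGS) + 1):
        print(BASE.format(langs=' '.join(LANGS[:i])))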
@@ -1,4 +1,5 @@
 import time
+import src.util.disable_sklearn_warnings
 
 import numpy as np
 from joblib import Parallel, delayed
@@ -0,0 +1,8 @@
+import warnings
+
+
+def warn(*args, **kwargs):
+    pass
+
+
+warnings.warn = warn
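The new module (presumably src/util/disable_sklearn_warnings.py, given the imports added elsewhere in this commit) works purely by import side effect: it replaces warnings.warn with a no-op, which silences sklearn's deprecation and convergence warnings, but also every other warnings.warn call in the process. A minimal sketch of its use, assuming the interpreter is started from the repository root so the src package is importable:

    import warnings

    import src.util.disable_sklearn_warnings  # noqa: F401  -- imported for its side effect only

    # After the import, warnings.warn is a no-op for the whole process:
    warnings.warn("this DeprecationWarning is silently dropped")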
@@ -1,4 +1,5 @@
 import numpy as np
+import src.util.disable_sklearn_warnings
 
 
 class StandardizeTransformer:
@@ -16,6 +16,7 @@ This module contains the view generators that take care of computing the view sp
 - View generator (-b): generates document embedding via mBERT model.
 """
 from abc import ABC, abstractmethod
+import src.util.disable_sklearn_warnings
 # from time import time
 
 from pytorch_lightning import Trainer
@@ -317,7 +318,7 @@ class RecurrentGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         self.multilingualIndex.embedding_matrices(self.pretrained, supervised=self.wce)
         self.model = self._init_model()
-        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
         self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
@@ -446,7 +447,7 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.patience = patience
-        self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
+        self.logger = TensorBoardLogger(save_dir='tb_logs', name='bert', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
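Both loggers switch save_dir from '../tb_logs' to 'tb_logs'. pytorch_lightning's TensorBoardLogger resolves save_dir against the current working directory, so with main.py launched from the repository root (as run.sh assumes) the event files now land inside the repository instead of one level above it. A small sketch of the resulting layout, assuming TensorBoardLogger's standard save_dir/name/version_<n> convention:

    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir='tb_logs', name='rnn', default_hp_metric=False)

    # Runs are grouped as <save_dir>/<name>/version_<n>, relative to the working directory.
    print(logger.log_dir)  # e.g. 'tb_logs/rnn/version_0'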
@@ -469,7 +470,7 @@ class BertGen(ViewGen):
         :param ly: dict {lang: target vectors}
         :return: self.
         """
-        print('\n# Fitting BertGen (M)...')
+        print('\n# Fitting BertGen (B)...')
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,