Implemented zero-shot experiment code for VanillaFunGen and WordClassGen

This commit is contained in:
andrea 2021-02-04 12:24:57 +01:00
parent 7affa1fab4
commit 8968570d82
3 changed files with 25 additions and 24 deletions

19
main.py
View File

@ -17,14 +17,11 @@ def main(args):
print('Running generalized funnelling...')
data = MultilingualDataset.load(args.dataset)
data.set_view(languages=['nl'])
data.set_view(languages=['da', 'nl', 'it'])
data.show_dimensions()
lX, ly = data.training()
lXte, lyte = data.test()
zero_shot = args.zero_shot
zscl_train_langs = args.zscl_langs
# Init multilingualIndex - mandatory when deploying the neural View Generators (GRU or BERT embedders)...
if args.gru_embedder or args.bert_embedder:
multilingualIndex = MultilingualIndex()
@ -34,29 +31,31 @@ def main(args):
# Init ViewGenerators and append them to embedder_list
embedder_list = []
if args.post_embedder:
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=args.n_jobs)
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
n_jobs=args.n_jobs)
embedder_list.append(posteriorEmbedder)
if args.muse_embedder:
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
zero_shot=zero_shot, train_langs=zscl_train_langs)
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
embedder_list.append(museEmbedder)
if args.wce_embedder:
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
zero_shot=zero_shot, train_langs=zscl_train_langs)
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
embedder_list.append(wceEmbedder)
if args.gru_embedder:
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
zero_shot=zero_shot, train_langs=zscl_train_langs,
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
gpus=args.gpus, n_jobs=args.n_jobs)
embedder_list.append(rnnEmbedder)
if args.bert_embedder:
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
zero_shot=zero_shot, train_langs=zscl_train_langs,
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
embedder_list.append(bertEmbedder)
@ -109,7 +108,7 @@ def main(args):
microf1=microf1,
macrok=macrok,
microk=microk,
notes=f'Train langs: {sorted(zscl_train_langs)}' if zero_shot else '')
notes=f'Train langs: {sorted(args.zscl_langs)}' if args.zero_shot else '')
print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3))
overall_time = round(time.time() - time_init, 3)

18
run.sh
View File

@ -2,15 +2,15 @@
echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3
#for i in {0..10..1}

View File

@ -77,15 +77,17 @@ class VanillaFunGen(ViewGen):
train_langs = ['it']
self.train_langs = train_langs
def fit(self, lX, lY):
def fit(self, lX, ly):
print('# Fitting VanillaFunGen (X)...')
if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.langs = sorted(self.train_langs)
lX = self.zero_shot_experiments(lX)
ly = self.zero_shot_experiments(ly)
lX = self.vectorizer.fit_transform(lX)
else:
lX = self.vectorizer.fit_transform(lX)
self.doc_projector.fit(lX, lY)
self.doc_projector.fit(lX, ly)
return self
def transform(self, lX):
@ -104,7 +106,6 @@ class VanillaFunGen(ViewGen):
return self.fit(lX, ly).transform(lX)
def zero_shot_experiments(self, lX):
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
_lX = {}
for lang in self.langs:
if lang in self.train_langs:
@ -150,6 +151,8 @@ class MuseGen(ViewGen):
:return: self.
"""
print('# Fitting MuseGen (M)...')
if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.vectorizer.fit(lX)
self.langs = sorted(lX.keys())
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
@ -181,7 +184,6 @@ class MuseGen(ViewGen):
return self.fit(lX, ly).transform(lX)
def zero_shot_experiments(self, lX):
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
_lX = {}
for lang in self.langs:
if lang in self.train_langs:
@ -226,6 +228,7 @@ class WordClassGen(ViewGen):
"""
print('# Fitting WordClassGen (W)...')
if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.langs = sorted(self.train_langs)
lX = self.zero_shot_experiments(lX)
lX = self.vectorizer.fit_transform(lX)
@ -257,7 +260,6 @@ class WordClassGen(ViewGen):
return self.fit(lX, ly).transform(lX)
def zero_shot_experiments(self, lX):
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
_lX = {}
for lang in self.langs:
if lang in self.train_langs: