implemented zero-shot experiment code for VanillaFunGen and WordClassGen
This commit is contained in:
parent
7affa1fab4
commit
8968570d82
19
main.py
19
main.py
|
|
@ -17,14 +17,11 @@ def main(args):
|
||||||
print('Running generalized funnelling...')
|
print('Running generalized funnelling...')
|
||||||
|
|
||||||
data = MultilingualDataset.load(args.dataset)
|
data = MultilingualDataset.load(args.dataset)
|
||||||
data.set_view(languages=['nl'])
|
data.set_view(languages=['da', 'nl', 'it'])
|
||||||
data.show_dimensions()
|
data.show_dimensions()
|
||||||
lX, ly = data.training()
|
lX, ly = data.training()
|
||||||
lXte, lyte = data.test()
|
lXte, lyte = data.test()
|
||||||
|
|
||||||
zero_shot = args.zero_shot
|
|
||||||
zscl_train_langs = args.zscl_langs
|
|
||||||
|
|
||||||
# Init multilingualIndex - mandatory when deploying Neural View Generators...
|
# Init multilingualIndex - mandatory when deploying Neural View Generators...
|
||||||
if args.gru_embedder or args.bert_embedder:
|
if args.gru_embedder or args.bert_embedder:
|
||||||
multilingualIndex = MultilingualIndex()
|
multilingualIndex = MultilingualIndex()
|
||||||
|
|
@ -34,29 +31,31 @@ def main(args):
|
||||||
# Init ViewGenerators and append them to embedder_list
|
# Init ViewGenerators and append them to embedder_list
|
||||||
embedder_list = []
|
embedder_list = []
|
||||||
if args.post_embedder:
|
if args.post_embedder:
|
||||||
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), n_jobs=args.n_jobs)
|
posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True),
|
||||||
|
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
embedder_list.append(posteriorEmbedder)
|
embedder_list.append(posteriorEmbedder)
|
||||||
|
|
||||||
if args.muse_embedder:
|
if args.muse_embedder:
|
||||||
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
|
museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
|
||||||
zero_shot=zero_shot, train_langs=zscl_train_langs)
|
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
||||||
embedder_list.append(museEmbedder)
|
embedder_list.append(museEmbedder)
|
||||||
|
|
||||||
if args.wce_embedder:
|
if args.wce_embedder:
|
||||||
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
|
wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
|
||||||
zero_shot=zero_shot, train_langs=zscl_train_langs)
|
zero_shot=args.zero_shot, train_langs=args.zscl_langs)
|
||||||
embedder_list.append(wceEmbedder)
|
embedder_list.append(wceEmbedder)
|
||||||
|
|
||||||
if args.gru_embedder:
|
if args.gru_embedder:
|
||||||
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
|
rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
|
||||||
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
|
batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
|
||||||
zero_shot=zero_shot, train_langs=zscl_train_langs,
|
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||||
gpus=args.gpus, n_jobs=args.n_jobs)
|
gpus=args.gpus, n_jobs=args.n_jobs)
|
||||||
embedder_list.append(rnnEmbedder)
|
embedder_list.append(rnnEmbedder)
|
||||||
|
|
||||||
if args.bert_embedder:
|
if args.bert_embedder:
|
||||||
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
|
||||||
zero_shot=zero_shot, train_langs=zscl_train_langs,
|
zero_shot=args.zero_shot, train_langs=args.zscl_langs,
|
||||||
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
|
||||||
embedder_list.append(bertEmbedder)
|
embedder_list.append(bertEmbedder)
|
||||||
|
|
||||||
|
|
@ -109,7 +108,7 @@ def main(args):
|
||||||
microf1=microf1,
|
microf1=microf1,
|
||||||
macrok=macrok,
|
macrok=macrok,
|
||||||
microk=microk,
|
microk=microk,
|
||||||
notes=f'Train langs: {sorted(zscl_train_langs)}' if zero_shot else '')
|
notes=f'Train langs: {sorted(args.zscl_langs)}' if args.zero_shot else '')
|
||||||
print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3))
|
print('Averages: MF1, mF1, MK, mK', np.round(np.mean(np.array(metrics), axis=0), 3))
|
||||||
|
|
||||||
overall_time = round(time.time() - time_init, 3)
|
overall_time = round(time.time() - time_init, 3)
|
||||||
|
|
|
||||||
18
run.sh
18
run.sh
|
|
@ -2,15 +2,15 @@
|
||||||
|
|
||||||
echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
|
echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
|
||||||
|
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
|
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 3
|
||||||
python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
|
#python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -x -m -w -o csv_logs/gfun/zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 3
|
||||||
|
|
||||||
|
|
||||||
#for i in {0..10..1}
|
#for i in {0..10..1}
|
||||||
|
|
|
||||||
|
|
@ -77,15 +77,17 @@ class VanillaFunGen(ViewGen):
|
||||||
train_langs = ['it']
|
train_langs = ['it']
|
||||||
self.train_langs = train_langs
|
self.train_langs = train_langs
|
||||||
|
|
||||||
def fit(self, lX, lY):
|
def fit(self, lX, ly):
|
||||||
print('# Fitting VanillaFunGen (X)...')
|
print('# Fitting VanillaFunGen (X)...')
|
||||||
if self.zero_shot:
|
if self.zero_shot:
|
||||||
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.langs = sorted(self.train_langs)
|
self.langs = sorted(self.train_langs)
|
||||||
lX = self.zero_shot_experiments(lX)
|
lX = self.zero_shot_experiments(lX)
|
||||||
|
ly = self.zero_shot_experiments(ly)
|
||||||
lX = self.vectorizer.fit_transform(lX)
|
lX = self.vectorizer.fit_transform(lX)
|
||||||
else:
|
else:
|
||||||
lX = self.vectorizer.fit_transform(lX)
|
lX = self.vectorizer.fit_transform(lX)
|
||||||
self.doc_projector.fit(lX, lY)
|
self.doc_projector.fit(lX, ly)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def transform(self, lX):
|
def transform(self, lX):
|
||||||
|
|
@ -104,7 +106,6 @@ class VanillaFunGen(ViewGen):
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
def zero_shot_experiments(self, lX):
|
def zero_shot_experiments(self, lX):
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
|
||||||
_lX = {}
|
_lX = {}
|
||||||
for lang in self.langs:
|
for lang in self.langs:
|
||||||
if lang in self.train_langs:
|
if lang in self.train_langs:
|
||||||
|
|
@ -150,6 +151,8 @@ class MuseGen(ViewGen):
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting MuseGen (M)...')
|
print('# Fitting MuseGen (M)...')
|
||||||
|
if self.zero_shot:
|
||||||
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.vectorizer.fit(lX)
|
self.vectorizer.fit(lX)
|
||||||
self.langs = sorted(lX.keys())
|
self.langs = sorted(lX.keys())
|
||||||
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
|
self.lMuse = MuseLoader(langs=self.langs, cache=self.muse_dir)
|
||||||
|
|
@ -181,7 +184,6 @@ class MuseGen(ViewGen):
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
def zero_shot_experiments(self, lX):
|
def zero_shot_experiments(self, lX):
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
|
||||||
_lX = {}
|
_lX = {}
|
||||||
for lang in self.langs:
|
for lang in self.langs:
|
||||||
if lang in self.train_langs:
|
if lang in self.train_langs:
|
||||||
|
|
@ -226,6 +228,7 @@ class WordClassGen(ViewGen):
|
||||||
"""
|
"""
|
||||||
print('# Fitting WordClassGen (W)...')
|
print('# Fitting WordClassGen (W)...')
|
||||||
if self.zero_shot:
|
if self.zero_shot:
|
||||||
|
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
||||||
self.langs = sorted(self.train_langs)
|
self.langs = sorted(self.train_langs)
|
||||||
lX = self.zero_shot_experiments(lX)
|
lX = self.zero_shot_experiments(lX)
|
||||||
lX = self.vectorizer.fit_transform(lX)
|
lX = self.vectorizer.fit_transform(lX)
|
||||||
|
|
@ -257,7 +260,6 @@ class WordClassGen(ViewGen):
|
||||||
return self.fit(lX, ly).transform(lX)
|
return self.fit(lX, ly).transform(lX)
|
||||||
|
|
||||||
def zero_shot_experiments(self, lX):
|
def zero_shot_experiments(self, lX):
|
||||||
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
|
|
||||||
_lX = {}
|
_lX = {}
|
||||||
for lang in self.langs:
|
for lang in self.langs:
|
||||||
if lang in self.train_langs:
|
if lang in self.train_langs:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue