Implemented zero-shot experiment code for VanillaFunGen and WordClassGen

This commit is contained in:
andrea 2021-02-04 12:44:36 +01:00
parent 8968570d82
commit f3fafd0f00
3 changed files with 8 additions and 8 deletions

View File

@ -62,7 +62,7 @@ def main(args):
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
meta_parameters=get_params(optimc=args.optimc)) meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs)
# Init Funnelling Architecture # Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
@ -80,7 +80,7 @@ def main(args):
if args.zero_shot: if args.zero_shot:
gfun.set_zero_shot(val=False) gfun.set_zero_shot(val=False)
ly_ = gfun.predict(lXte) ly_ = gfun.predict(lXte)
l_eval = evaluate(ly_true=lyte, ly_pred=ly_) l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs)
time_te = round(time.time() - time_te, 3) time_te = round(time.time() - time_te, 3)
print(f'Testing completed in {time_te} seconds!') print(f'Testing completed in {time_te} seconds!')

View File

@ -23,7 +23,7 @@ class DocEmbedderList:
if isinstance(embedder, VanillaFunGen): if isinstance(embedder, VanillaFunGen):
_tmp.append(embedder) _tmp.append(embedder)
else: else:
_tmp.append(FeatureSet2Posteriors(embedder)) _tmp.append(FeatureSet2Posteriors(embedder, n_jobs=embedder.n_jobs))
self.embedders = _tmp self.embedders = _tmp
def fit(self, lX, ly): def fit(self, lX, ly):

View File

@ -78,7 +78,7 @@ class VanillaFunGen(ViewGen):
self.train_langs = train_langs self.train_langs = train_langs
def fit(self, lX, ly): def fit(self, lX, ly):
print('# Fitting VanillaFunGen (X)...') print('\n# Fitting VanillaFunGen (X)...')
if self.zero_shot: if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.langs = sorted(self.train_langs) self.langs = sorted(self.train_langs)
@ -150,7 +150,7 @@ class MuseGen(ViewGen):
:param ly: dict {lang: target vectors} :param ly: dict {lang: target vectors}
:return: self. :return: self.
""" """
print('# Fitting MuseGen (M)...') print('\n# Fitting MuseGen (M)...')
if self.zero_shot: if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.vectorizer.fit(lX) self.vectorizer.fit(lX)
@ -226,7 +226,7 @@ class WordClassGen(ViewGen):
:param ly: dict {lang: target vectors} :param ly: dict {lang: target vectors}
:return: self. :return: self.
""" """
print('# Fitting WordClassGen (W)...') print('\n# Fitting WordClassGen (W)...')
if self.zero_shot: if self.zero_shot:
print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
self.langs = sorted(self.train_langs) self.langs = sorted(self.train_langs)
@ -359,7 +359,7 @@ class RecurrentGen(ViewGen):
:param ly: dict {lang: target vectors} :param ly: dict {lang: target vectors}
:return: self. :return: self.
""" """
print('# Fitting RecurrentGen (G)...') print('\n# Fitting RecurrentGen (G)...')
create_if_not_exist(self.logger.save_dir) create_if_not_exist(self.logger.save_dir)
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs, recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
zero_shot=self.zero_shot, zscl_langs=self.train_langs) zero_shot=self.zero_shot, zscl_langs=self.train_langs)
@ -468,7 +468,7 @@ class BertGen(ViewGen):
:param ly: dict {lang: target vectors} :param ly: dict {lang: target vectors}
:return: self. :return: self.
""" """
print('# Fitting BertGen (M)...') print('\n# Fitting BertGen (M)...')
create_if_not_exist(self.logger.save_dir) create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512, bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,