fixed typos + n_jobs across code (still missing one wrt brach 'rsc')

This commit is contained in:
andrea 2021-02-04 16:52:05 +01:00
parent ec050dce7b
commit 59146f0dda
2 changed files with 4 additions and 3 deletions

View File

@ -54,10 +54,11 @@ def main(args):
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
meta_parameters=get_params(optimc=args.optimc)) meta_parameters=get_params(optimc=args.optimc),
n_jobs=args.n_jobs)
# Init Funnelling Architecture # Init Funnelling Architecture
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta, n_jobs=args.n_jobs)
# Training --------------------------------------- # Training ---------------------------------------
print('\n[Training Generalized Funnelling]') print('\n[Training Generalized Funnelling]')

View File

@ -361,7 +361,7 @@ class BertGen(ViewGen):
:param ly: dict {lang: target vectors} :param ly: dict {lang: target vectors}
:return: self. :return: self.
""" """
print('# Fitting BertGen (M)...') print('# Fitting BertGen (B)...')
create_if_not_exist(self.logger.save_dir) create_if_not_exist(self.logger.save_dir)
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)