fixed typos + n_jobs across code (still missing one wrt brach 'rsc')
This commit is contained in:
parent
ec050dce7b
commit
59146f0dda
5
main.py
5
main.py
|
|
@ -54,10 +54,11 @@ def main(args):
|
||||||
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
# Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier
|
||||||
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True)
|
||||||
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'),
|
||||||
meta_parameters=get_params(optimc=args.optimc))
|
meta_parameters=get_params(optimc=args.optimc),
|
||||||
|
n_jobs=args.n_jobs)
|
||||||
|
|
||||||
# Init Funnelling Architecture
|
# Init Funnelling Architecture
|
||||||
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta)
|
gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta, n_jobs=args.n_jobs)
|
||||||
|
|
||||||
# Training ---------------------------------------
|
# Training ---------------------------------------
|
||||||
print('\n[Training Generalized Funnelling]')
|
print('\n[Training Generalized Funnelling]')
|
||||||
|
|
|
||||||
|
|
@ -361,7 +361,7 @@ class BertGen(ViewGen):
|
||||||
:param ly: dict {lang: target vectors}
|
:param ly: dict {lang: target vectors}
|
||||||
:return: self.
|
:return: self.
|
||||||
"""
|
"""
|
||||||
print('# Fitting BertGen (M)...')
|
print('# Fitting BertGen (B)...')
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
|
||||||
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue