diff --git a/main.py b/main.py index fd73306..2ee7175 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ def main(args): print('Running generalized funnelling...') data = MultilingualDataset.load(args.dataset) - data.set_view(languages=['da', 'nl', 'it']) + data.set_view(languages=['da']) data.show_dimensions() lX, ly = data.training() lXte, lyte = data.test() @@ -32,40 +32,64 @@ def main(args): embedder_list = [] if args.post_embedder: posteriorEmbedder = VanillaFunGen(base_learner=get_learner(calibrate=True), - zero_shot=args.zero_shot, train_langs=args.zscl_langs, + zero_shot=args.zero_shot, + train_langs=args.zscl_langs, n_jobs=args.n_jobs) + embedder_list.append(posteriorEmbedder) if args.muse_embedder: - museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs, - zero_shot=args.zero_shot, train_langs=args.zscl_langs) + museEmbedder = MuseGen(muse_dir=args.muse_dir, + zero_shot=args.zero_shot, + train_langs=args.zscl_langs, + n_jobs=args.n_jobs) + embedder_list.append(museEmbedder) if args.wce_embedder: - wceEmbedder = WordClassGen(n_jobs=args.n_jobs, - zero_shot=args.zero_shot, train_langs=args.zscl_langs) + wceEmbedder = WordClassGen(zero_shot=args.zero_shot, + train_langs=args.zscl_langs, + n_jobs=args.n_jobs) + embedder_list.append(wceEmbedder) if args.gru_embedder: - rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce, - batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn, - zero_shot=args.zero_shot, train_langs=args.zscl_langs, - gpus=args.gpus, n_jobs=args.n_jobs) + rnnEmbedder = RecurrentGen(multilingualIndex, + pretrained_embeddings=lMuse, + wce=args.rnn_wce, + batch_size=args.batch_rnn, + nepochs=args.nepochs_rnn, + patience=args.patience_rnn, + zero_shot=args.zero_shot, + train_langs=args.zscl_langs, + gpus=args.gpus, + n_jobs=args.n_jobs) + embedder_list.append(rnnEmbedder) if args.bert_embedder: - bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert, - zero_shot=args.zero_shot, train_langs=args.zscl_langs, - patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs) + bertEmbedder = BertGen(multilingualIndex, + batch_size=args.batch_bert, + nepochs=args.nepochs_bert, + patience=args.patience_bert, + zero_shot=args.zero_shot, + train_langs=args.zscl_langs, + gpus=args.gpus, + n_jobs=args.n_jobs) + embedder_list.append(bertEmbedder) # Init DocEmbedderList (i.e., first-tier learners or view generators) and metaclassifier docEmbedders = DocEmbedderList(embedder_list=embedder_list, probabilistic=True) + meta = MetaClassifier(meta_learner=get_learner(calibrate=False, kernel='rbf'), - meta_parameters=get_params(optimc=args.optimc), n_jobs=args.n_jobs) + meta_parameters=get_params(optimc=args.optimc), + n_jobs=args.n_jobs) # Init Funnelling Architecture - gfun = Funnelling(first_tier=docEmbedders, meta_classifier=meta) + gfun = Funnelling(first_tier=docEmbedders, + meta_classifier=meta, + n_jobs=args.n_jobs) # Training --------------------------------------- print('\n[Training Generalized Funnelling]') diff --git a/src/funnelling.py b/src/funnelling.py index 4860a24..457ee89 100644 --- a/src/funnelling.py +++ b/src/funnelling.py @@ -116,9 +116,9 @@ class Funnelling: self.n_jobs = n_jobs def fit(self, lX, ly): - print('## Fitting first-tier learners!') + print('\n## Fitting first-tier learners!') lZ = self.first_tier.fit_transform(lX, ly) - print('## Fitting meta-learner!') + print('\n## Fitting meta-learner!') self.meta.fit(lZ, ly) def predict(self, lX): diff --git a/src/models/learners.py b/src/models/learners.py index 46737c6..25fc16b 100644 --- a/src/models/learners.py +++ b/src/models/learners.py @@ -74,7 +74,7 @@ class NaivePolylingualClassifier: _sort_if_sparse(lX[lang]) models = Parallel(n_jobs=self.n_jobs)\ - (delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters).fit)((lX[lang]), ly[lang]) for + (delayed(MonolingualClassifier(self.base_learner, parameters=self.parameters, n_jobs=self.n_jobs).fit)((lX[lang]), ly[lang]) for lang in langs) self.model = {lang: models[i] for i, lang in enumerate(langs)} diff --git a/src/view_generators.py b/src/view_generators.py index 673453e..3a3ff5d 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -69,7 +69,8 @@ class VanillaFunGen(ViewGen): self.first_tier_parameters = first_tier_parameters self.n_jobs = n_jobs self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners, - parameters=self.first_tier_parameters, n_jobs=self.n_jobs) + parameters=self.first_tier_parameters, + n_jobs=self.n_jobs) self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True) # Zero shot parameters self.zero_shot = zero_shot