diff --git a/main.py b/main.py
index a521fef..c689547 100644
--- a/main.py
+++ b/main.py
@@ -11,17 +11,19 @@ from src.view_generators import *
 def main(args):
     assert args.post_embedder or args.muse_embedder or args.wce_embedder or args.gru_embedder or args.bert_embedder, \
         'empty set of document embeddings is not allowed!'
+    assert not (args.zero_shot and not args.zscl_langs), \
+        '--zscl_langs cannot be empty when setting --zero_shot to True'
 
     print('Running generalized funnelling...')
 
     data = MultilingualDataset.load(args.dataset)
-    data.set_view(languages=['it', 'da', 'nl'])
+    data.set_view(languages=['nl'])
     data.show_dimensions()
     lX, ly = data.training()
     lXte, lyte = data.test()
 
-    zero_shot = True
-    zscl_train_langs = ['it']   # Todo: testing zero shot
+    zero_shot = args.zero_shot
+    zscl_train_langs = args.zscl_langs
 
     # Init multilingualIndex - mandatory when deploying Neural View Generators...
     if args.gru_embedder or args.bert_embedder:
@@ -37,24 +39,24 @@ def main(args):
 
     if args.muse_embedder:
         museEmbedder = MuseGen(muse_dir=args.muse_dir, n_jobs=args.n_jobs,
-                               zero_shot=zero_shot, train_langs=zscl_train_langs)   # Todo: testing zero shot
+                               zero_shot=zero_shot, train_langs=zscl_train_langs)
         embedder_list.append(museEmbedder)
 
     if args.wce_embedder:
         wceEmbedder = WordClassGen(n_jobs=args.n_jobs,
-                                   zero_shot=zero_shot, train_langs=zscl_train_langs)   # Todo: testing zero shot
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs)
         embedder_list.append(wceEmbedder)
 
     if args.gru_embedder:
         rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce,
                                    batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn,
-                                   zero_shot=zero_shot, train_langs=zscl_train_langs,   # Todo: testing zero shot
+                                   zero_shot=zero_shot, train_langs=zscl_train_langs,
                                    gpus=args.gpus, n_jobs=args.n_jobs)
         embedder_list.append(rnnEmbedder)
 
     if args.bert_embedder:
         bertEmbedder = BertGen(multilingualIndex, batch_size=args.batch_bert, nepochs=args.nepochs_bert,
-                               zero_shot=zero_shot, train_langs=zscl_train_langs,   # Todo: testing zero shot
+                               zero_shot=zero_shot, train_langs=zscl_train_langs,
                                patience=args.patience_bert, gpus=args.gpus, n_jobs=args.n_jobs)
         embedder_list.append(bertEmbedder)
 
@@ -76,8 +78,8 @@ def main(args):
     # Testing ----------------------------------------
     print('\n[Testing Generalized Funnelling]')
     time_te = time.time()
-    # TODO: Zero shot scenario -> setting first tier learners zero_shot param to False
-    gfun.set_zero_shot(val=False)
+    if args.zero_shot:
+        gfun.set_zero_shot(val=False)
     ly_ = gfun.predict(lXte)
     l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
     time_te = round(time.time() - time_te, 3)
@@ -85,7 +87,7 @@ def main(args):
 
     # Logging ---------------------------------------
     print('\n[Results]')
-    results = CSVlog(args.csv_dir)
+    results = CSVlog(f'csv_logs/gfun/{args.csv_dir}')
     metrics = []
     for lang in lXte.keys():
         macrof1, microf1, macrok, microk = l_eval[lang]
@@ -120,8 +122,8 @@ if __name__ == '__main__':
     parser.add_argument('dataset', help='Path to the dataset')
 
     parser.add_argument('-o', '--output', dest='csv_dir', metavar='',
-                        help='Result file (default csv_logs/gfun/gfun_results.csv)', type=str,
-                        default='csv_logs/gfun/gfun_results.csv')
+                        help='Result file, saved in csv_logs/gfun/ (default: gfun_results.csv)', type=str,
+                        default='gfun_results.csv')
 
     parser.add_argument('-x', '--post_embedder', dest='post_embedder', action='store_true',
                         help='deploy posterior probabilities embedder to compute document embeddings',
@@ -194,5 +196,12 @@ if __name__ == '__main__':
     parser.add_argument('--gpus', metavar='',
                         help='specifies how many GPUs to use per node', default=None)
 
+    parser.add_argument('--zero_shot', dest='zero_shot', action='store_true',
+                        help='run zero-shot experiments',
+                        default=False)
+
+    parser.add_argument('--zscl_langs', dest='zscl_langs', metavar='', nargs='*',
+                        help='set the languages to be used in training in zero-shot experiments')
+
     args = parser.parse_args()
     main(args)
diff --git a/run.sh b/run.sh
index 09ce599..8470998 100644
--- a/run.sh
+++ b/run.sh
@@ -2,7 +2,16 @@
 
 echo Running Zero-shot experiments [output at csv_logs/gfun/zero_shot_gfun.csv]
 
-python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -o csv_logs/gfun/zero_shot_gfun.csv --gpus 0
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt --n_jobs 6
+python main.py /home/moreo/CLESA/rcv2/rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle -m -w -o zero_shot_gfun.csv --zero_shot --zscl_langs da de en es fr it nl pt sv --n_jobs 6
+
 
 #for i in {0..10..1}
 #do
diff --git a/src/funnelling.py b/src/funnelling.py
index 116d67b..c8d3fc6 100644
--- a/src/funnelling.py
+++ b/src/funnelling.py
@@ -128,6 +128,9 @@ class Funnelling:
 
     def set_zero_shot(self, val: bool):
         for embedder in self.first_tier.embedders:
-            embedder.embedder.set_zero_shot(val)
+            if isinstance(embedder, VanillaFunGen):
+                embedder.set_zero_shot(val)
+            else:
+                embedder.embedder.set_zero_shot(val)
         return
 
diff --git a/src/view_generators.py b/src/view_generators.py
index 05c5263..9c73615 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -56,7 +56,7 @@ class VanillaFunGen(ViewGen):
     View Generator (x): original funnelling architecture proposed by Moreo, Esuli and
     Sebastiani in DOI: https://doi.org/10.1145/3326065
     """
-    def __init__(self, base_learner, first_tier_parameters=None, n_jobs=-1):
+    def __init__(self, base_learner, first_tier_parameters=None, zero_shot=False, train_langs: list = None, n_jobs=-1):
         """
         Init Posterior Probabilities embedder (i.e., VanillaFunGen)
         :param base_learner: naive monolingual learners to be deployed as first-tier learners. Should be able to
@@ -71,10 +71,20 @@ class VanillaFunGen(ViewGen):
         self.doc_projector = NaivePolylingualClassifier(base_learner=self.learners,
                                                         parameters=self.first_tier_parameters, n_jobs=self.n_jobs)
         self.vectorizer = TfidfVectorizerMultilingual(sublinear_tf=True, use_idf=True)
+        # Zero shot parameters
+        self.zero_shot = zero_shot
+        if train_langs is None:
+            train_langs = ['it']
+        self.train_langs = train_langs
 
     def fit(self, lX, lY):
         print('# Fitting VanillaFunGen (X)...')
-        lX = self.vectorizer.fit_transform(lX)
+        if self.zero_shot:
+            self.langs = sorted(self.train_langs)
+            lX = self.zero_shot_experiments(lX)
+            lX = self.vectorizer.fit_transform(lX)
+        else:
+            lX = self.vectorizer.fit_transform(lX)
         self.doc_projector.fit(lX, lY)
         return self
 
@@ -93,9 +103,19 @@ class VanillaFunGen(ViewGen):
 
     def fit_transform(self, lX, ly):
         return self.fit(lX, ly).transform(lX)
 
+    def zero_shot_experiments(self, lX):
+        print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
+        _lX = {}
+        for lang in self.langs:
+            if lang in self.train_langs:
+                _lX[lang] = lX[lang]
+            else:
+                _lX[lang] = None
+        lX = _lX
+        return lX
+
     def set_zero_shot(self, val: bool):
         self.zero_shot = val
-        print('# TODO: PosteriorsGen has not been configured for zero-shot experiments')
         return
 
@@ -205,10 +225,16 @@ class WordClassGen(ViewGen):
         :return: self.
         """
         print('# Fitting WordClassGen (W)...')
-        lX = self.vectorizer.fit_transform(lX)
-        self.langs = sorted(lX.keys())
+        if self.zero_shot:
+            self.langs = sorted(self.train_langs)
+            lX = self.zero_shot_experiments(lX)
+            lX = self.vectorizer.fit_transform(lX)
+        else:
+            lX = self.vectorizer.fit_transform(lX)
+            self.langs = sorted(lX.keys())
+
         wce = Parallel(n_jobs=self.n_jobs)(
-                                           delayed(wce_matrix)(lX[lang], ly[lang]) for lang in self.langs)
+            delayed(wce_matrix)(lX[lang], ly[lang]) for lang in self.langs)
         self.lWce = {l: wce[i] for i, l in enumerate(self.langs)}
         # TODO: featureweight.fit()
         return self
""" - # Testing zero-shot experiments - if self.zero_shot: - lX = self.zero_shot_experiments(lX) - lX = {l: self.vectorizer.vectorizer[l].transform(lX[l]) for l in self.langs if lX[l] is not None} - else: - lX = self.vectorizer.transform(lX) + lX = self.vectorizer.transform(lX) XdotWce = Parallel(n_jobs=self.n_jobs)( - delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys())) - lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys()))} + delayed(XdotM)(lX[lang], self.lWce[lang], sif=True) for lang in sorted(lX.keys()) if lang in self.lWce.keys()) + lWce = {l: XdotWce[i] for i, l in enumerate(sorted(lX.keys())) if l in self.lWce.keys()} lWce = _normalize(lWce, l2=True) return lWce @@ -339,7 +360,7 @@ class RecurrentGen(ViewGen): print('# Fitting RecurrentGen (G)...') create_if_not_exist(self.logger.save_dir) recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs, - zero_shot=self.zero_shot, zscl_langs=self.train_langs) # Todo: zero shot settings + zero_shot=self.zero_shot, zscl_langs=self.train_langs) trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs, callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False) @@ -350,7 +371,7 @@ class RecurrentGen(ViewGen): # self.model.linear2 = vanilla_torch_model.linear2 # self.model.rnn = vanilla_torch_model.rnn - if self.zero_shot: # Todo: zero shot experiment setting + if self.zero_shot: print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') trainer.fit(self.model, datamodule=recurrentDataModule) @@ -451,7 +472,7 @@ class BertGen(ViewGen): bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512, zero_shot=self.zero_shot, zscl_langs=self.train_langs) - if self.zero_shot: # Todo: zero shot experiment setting + if self.zero_shot: print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,