From ee98c5f610da1f1bb8e8c5341b1ed001313e42cd Mon Sep 17 00:00:00 2001 From: andrea Date: Tue, 2 Feb 2021 16:09:49 +0100 Subject: [PATCH] setting up zero-shot experiments (implemented for Recurrent and Bert but not tested) --- main.py | 1 + src/data/datamodule.py | 44 +++++++++++++++++++++++++++++++++-------- src/util/common.py | 45 ++++++++++++++++++++++++++++++++++++++++++ src/view_generators.py | 14 +++++++++++-- 4 files changed, 94 insertions(+), 10 deletions(-) diff --git a/main.py b/main.py index ea6a329..5de4b8d 100644 --- a/main.py +++ b/main.py @@ -48,6 +48,7 @@ def main(args): if args.gru_embedder: rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce, batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn, + zero_shot=zero_shot, train_langs=zscl_train_langs, # Todo: testing zero shot gpus=args.gpus, n_jobs=args.n_jobs) embedder_list.append(rnnEmbedder) diff --git a/src/data/datamodule.py b/src/data/datamodule.py index 66146b3..067d47f 100644 --- a/src/data/datamodule.py +++ b/src/data/datamodule.py @@ -92,7 +92,7 @@ class RecurrentDataModule(pl.LightningDataModule): Pytorch Lightning Datamodule to be deployed with RecurrentGen. https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html """ - def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1): + def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1, zero_shot=False, zscl_langs=None): """ Init RecurrentDataModule. 
:param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents @@ -103,6 +103,11 @@ class RecurrentDataModule(pl.LightningDataModule): self.multilingualIndex = multilingualIndex self.batchsize = batchsize self.n_jobs = n_jobs + # Zero shot arguments + if zscl_langs is None: + zscl_langs = [] + self.zero_shot = zero_shot + self.train_langs = zscl_langs super().__init__() def prepare_data(self, *args, **kwargs): @@ -110,7 +115,10 @@ class RecurrentDataModule(pl.LightningDataModule): def setup(self, stage=None): if stage == 'fit' or stage is None: - l_train_index, l_train_target = self.multilingualIndex.l_train() + if self.zero_shot: + l_train_index, l_train_target = self.multilingualIndex.l_train_zero_shot(langs=self.train_langs) + else: + l_train_index, l_train_target = self.multilingualIndex.l_train() # Debug settings: reducing number of samples # l_train_index = {l: train[:5] for l, train in l_train_index.items()} # l_train_target = {l: target[:5] for l, target in l_train_target.items()} @@ -118,7 +126,10 @@ class RecurrentDataModule(pl.LightningDataModule): self.training_dataset = RecurrentDataset(l_train_index, l_train_target, lPad_index=self.multilingualIndex.l_pad()) - l_val_index, l_val_target = self.multilingualIndex.l_val() + if self.zero_shot: + l_val_index, l_val_target = self.multilingualIndex.l_val_zero_shot(langs=self.train_langs) + else: + l_val_index, l_val_target = self.multilingualIndex.l_val() # Debug settings: reducing number of samples # l_val_index = {l: train[:5] for l, train in l_val_index.items()} # l_val_target = {l: target[:5] for l, target in l_val_target.items()} @@ -126,7 +137,10 @@ class RecurrentDataModule(pl.LightningDataModule): self.val_dataset = RecurrentDataset(l_val_index, l_val_target, lPad_index=self.multilingualIndex.l_pad()) if stage == 'test' or stage is None: - l_test_index, l_test_target = self.multilingualIndex.l_test() + if self.zero_shot: + l_test_index, l_test_target = 
self.multilingualIndex.l_test_zero_shot(langs=self.train_langs) + else: + l_test_index, l_test_target = self.multilingualIndex.l_test() # Debug settings: reducing number of samples # l_test_index = {l: train[:5] for l, train in l_test_index.items()} # l_test_target = {l: target[:5] for l, target in l_test_target.items()} @@ -167,7 +181,7 @@ class BertDataModule(RecurrentDataModule): Pytorch Lightning Datamodule to be deployed with BertGen. https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html """ - def __init__(self, multilingualIndex, batchsize=64, max_len=512): + def __init__(self, multilingualIndex, batchsize=64, max_len=512, zero_shot=False, zscl_langs=None): """ Init BertDataModule. :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents @@ -177,10 +191,18 @@ class BertDataModule(RecurrentDataModule): """ super().__init__(multilingualIndex, batchsize) self.max_len = max_len + # Zero shot arguments + if zscl_langs is None: + zscl_langs = [] + self.zero_shot = zero_shot + self.train_langs = zscl_langs def setup(self, stage=None): if stage == 'fit' or stage is None: - l_train_raw, l_train_target = self.multilingualIndex.l_train_raw() + if self.zero_shot: + l_train_raw, l_train_target = self.multilingualIndex.l_train_raw_zero_shot(langs=self.train_langs) # todo: check this! + else: + l_train_raw, l_train_target = self.multilingualIndex.l_train_raw() # Debug settings: reducing number of samples # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()} # l_train_target = {l: target[:5] for l, target in l_train_target.items()} @@ -189,7 +211,10 @@ class BertDataModule(RecurrentDataModule): self.training_dataset = RecurrentDataset(l_train_index, l_train_target, lPad_index=self.multilingualIndex.l_pad()) - l_val_raw, l_val_target = self.multilingualIndex.l_val_raw() + if self.zero_shot: + l_val_raw, l_val_target = self.multilingualIndex.l_val_raw_zero_shot(langs=self.train_langs) # todo: check this! 
+ else: + l_val_raw, l_val_target = self.multilingualIndex.l_val_raw() # Debug settings: reducing number of samples # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()} # l_val_target = {l: target[:5] for l, target in l_val_target.items()} @@ -199,7 +224,10 @@ class BertDataModule(RecurrentDataModule): lPad_index=self.multilingualIndex.l_pad()) if stage == 'test' or stage is None: - l_test_raw, l_test_target = self.multilingualIndex.l_test_raw() + if self.zero_shot: + l_test_raw, l_test_target = self.multilingualIndex.l_test_raw_zero_shot(langs=self.train_langs) # todo: check this! + else: + l_test_raw, l_test_target = self.multilingualIndex.l_test_raw() # Debug settings: reducing number of samples # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()} # l_test_target = {l: target[:5] for l, target in l_test_target.items()} diff --git a/src/util/common.py b/src/util/common.py index 9f44273..bae237c 100644 --- a/src/util/common.py +++ b/src/util/common.py @@ -149,33 +149,60 @@ class MultilingualIndex: def l_train_index(self): return {l: index.train_index for l, index in self.l_index.items()} + def l_train_index_zero_shot(self, langs): + return {l: index.train_index for l, index in self.l_index.items() if l in langs} + def l_train_raw_index(self): return {l: index.train_raw for l, index in self.l_index.items()} + def l_train_raw_index_zero_shot(self, langs): + return {l: index.train_raw for l, index in self.l_index.items() if l in langs} + def l_train_target(self): return {l: index.train_target for l, index in self.l_index.items()} + def l_train_target_zero_shot(self, langs): + return {l: index.train_target for l, index in self.l_index.items() if l in langs} + def l_val_index(self): return {l: index.val_index for l, index in self.l_index.items()} + def l_val_index_zero_shot(self, langs): + return {l: index.val_index for l, index in self.l_index.items() if l in langs} + def l_val_raw_index(self): return {l: index.val_raw for l, index in 
self.l_index.items()} + def l_val_raw_index_zero_shot(self, langs): + return {l: index.val_raw for l, index in self.l_index.items() if l in langs} + def l_test_raw_index(self): return {l: index.test_raw for l, index in self.l_index.items()} + def l_test_raw_index_zero_shot(self, langs): + return {l: index.test_raw for l, index in self.l_index.items() if l in langs} + def l_devel_raw_index(self): return {l: index.devel_raw for l, index in self.l_index.items()} def l_val_target(self): return {l: index.val_target for l, index in self.l_index.items()} + def l_val_target_zero_shot(self, langs): + return {l: index.val_target for l, index in self.l_index.items() if l in langs} + def l_test_target(self): return {l: index.test_target for l, index in self.l_index.items()} def l_test_index(self): return {l: index.test_index for l, index in self.l_index.items()} + def l_test_target_zero_shot(self, langs): + return {l: index.test_target for l, index in self.l_index.items() if l in langs} + + def l_test_index_zero_shot(self, langs): + return {l: index.test_index for l, index in self.l_index.items() if l in langs} + def l_devel_index(self): return {l: index.devel_index for l, index in self.l_index.items()} @@ -191,15 +218,33 @@ class MultilingualIndex: def l_test(self): return self.l_test_index(), self.l_test_target() + def l_test_zero_shot(self, langs): + return self.l_test_index_zero_shot(langs), self.l_test_target_zero_shot(langs) + + def l_train_zero_shot(self, langs): + return self.l_train_index_zero_shot(langs), self.l_train_target_zero_shot(langs) + + def l_val_zero_shot(self, langs): + return self.l_val_index_zero_shot(langs), self.l_val_target_zero_shot(langs) + def l_train_raw(self): return self.l_train_raw_index(), self.l_train_target() + def l_train_raw_zero_shot(self, langs): + return self.l_train_raw_index_zero_shot(langs), self.l_train_target_zero_shot(langs) + def l_val_raw(self): return self.l_val_raw_index(), self.l_val_target() + def l_val_raw_zero_shot(self,
langs): + return self.l_val_raw_index_zero_shot(langs), self.l_val_target_zero_shot(langs) + def l_test_raw(self): return self.l_test_raw_index(), self.l_test_target() + def l_test_raw_zero_shot(self, langs): + return self.l_test_raw_index_zero_shot(langs), self.l_test_target_zero_shot(langs) + def l_devel_raw(self): return self.l_devel_raw_index(), self.l_devel_target() diff --git a/src/view_generators.py b/src/view_generators.py index 688f133..f1e3eb6 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -259,7 +259,7 @@ class RecurrentGen(ViewGen): the network internal state at the second feed-forward layer level. Training metrics are logged via TensorBoard. """ def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50, - gpus=0, n_jobs=-1, patience=20, stored_path=None): + gpus=0, n_jobs=-1, patience=20, stored_path=None, zero_shot=False, train_langs: list = None): """ Init RecurrentGen. :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents @@ -298,6 +298,12 @@ class RecurrentGen(ViewGen): patience=self.patience, verbose=False, mode='max') self.lr_monitor = LearningRateMonitor(logging_interval='epoch') + # Zero shot parameters + self.zero_shot = zero_shot + if train_langs is None: + train_langs = ['it'] + self.train_langs = train_langs + def _init_model(self): if self.stored_path: lpretrained = self.multilingualIndex.l_embeddings() @@ -332,7 +338,8 @@ class RecurrentGen(ViewGen): """ print('# Fitting RecurrentGen (G)...') create_if_not_exist(self.logger.save_dir) - recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs) + recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs, + zero_shot=self.zero_shot, zscl_langs=self.train_langs) # Todo: zero shot settings trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs, 
callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False) @@ -343,6 +350,9 @@ class RecurrentGen(ViewGen): # self.model.linear2 = vanilla_torch_model.linear2 # self.model.rnn = vanilla_torch_model.rnn + if self.zero_shot: # Todo: zero shot experiment setting + print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}') + trainer.fit(self.model, datamodule=recurrentDataModule) trainer.test(self.model, datamodule=recurrentDataModule) return self