From 7b6938459f8c3a5a8c1c43177ade8bbcbd8bdd93 Mon Sep 17 00:00:00 2001
From: andrea
Date: Tue, 9 Feb 2021 09:43:19 +0100
Subject: [PATCH] implemented BertDataModule collate function

---
 src/data/datamodule.py | 5 +++--
 src/models/pl_bert.py  | 3 ++-
 src/view_generators.py | 9 +++++----
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/data/datamodule.py b/src/data/datamodule.py
index 767f349..268bf29 100644
--- a/src/data/datamodule.py
+++ b/src/data/datamodule.py
@@ -150,7 +150,7 @@ class RecurrentDataModule(pl.LightningDataModule):
 
     def train_dataloader(self):
         return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
-                          collate_fn=self.training_dataset.collate_fn)
+                          collate_fn=self.training_dataset.collate_fn, shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
@@ -247,7 +247,8 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert,
+                          shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
diff --git a/src/models/pl_bert.py b/src/models/pl_bert.py
index dba9c8e..0a38e9f 100644
--- a/src/models/pl_bert.py
+++ b/src/models/pl_bert.py
@@ -141,7 +141,8 @@ class BertModel(pl.LightningModule):
              'weight_decay': weight_decay}
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
-        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
+        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # TODO set to 1.0 to debug (prev. 0.1)
+                     'interval': 'epoch'}
         return [optimizer], [scheduler]
 
     def encode(self, lX, batch_size=64):
diff --git a/src/view_generators.py b/src/view_generators.py
index cd992ba..663e89d 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -366,7 +366,8 @@ class RecurrentGen(ViewGen):
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
                                                   zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False,
+                          overfit_batches=0.01)
 
         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
@@ -474,14 +475,14 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
-                                        debug=True)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         if self.zero_shot:
             print(f'# Zero-shot setting! '
                   f'Training langs will be set to: {sorted(self.train_langs)}')
 
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False,
+                          overfit_batches=0.01)  # todo: overfit_batches -> DEBUG setting
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self
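
Note: the commit subject refers to BertDataModule's collate function, but its body lies outside the hunk context above; the diff only wires `self.collate_fn_bert` into the train/val DataLoaders. For orientation, a minimal sketch of what such a collate function typically does. The HuggingFace tokenizer, the mBERT checkpoint, and the (text, label) item format are assumptions for illustration, not details taken from this patch.

    import torch
    from transformers import AutoTokenizer

    # Hypothetical checkpoint: the patch does not show which tokenizer
    # BertDataModule holds; mBERT is a plausible multilingual choice.
    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

    def collate_fn_bert(batch, max_len=512):
        # batch: list of (text, label) pairs as yielded by the dataset (assumed format)
        texts, labels = zip(*batch)
        encoded = tokenizer(list(texts),
                            padding=True,        # pad to the longest sequence in the batch
                            truncation=True,     # enforce BERT's length limit
                            max_length=max_len,  # matches max_len=512 passed to BertDataModule
                            return_tensors='pt')
        return encoded['input_ids'], encoded['attention_mask'], torch.tensor(labels)

    # usage: input_ids, attention_mask, y = collate_fn_bert([('a doc', 0), ('another doc', 1)])

Also note that only the training loaders gain shuffle=True; validation and test batches keep a deterministic order.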
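
On the pl_bert.py hunk: wrapping the scheduler in a dict uses PyTorch Lightning's scheduler-dict convention, where 'interval' controls when the scheduler steps ('epoch' = once per epoch, 'step' = after every optimizer step). A self-contained sketch of the same pattern, using torch.optim and the pre-debug gamma=0.1 (the linear layer is a stand-in; the patch's own AdamW import is not visible in the diff):

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import StepLR

    model = torch.nn.Linear(768, 2)                 # stand-in module
    optimizer = AdamW(model.parameters(), lr=1e-5)
    scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=0.1),
                 'interval': 'epoch'}               # step once per training epoch

    # inside a LightningModule, this pair is what configure_optimizers() returns:
    # return [optimizer], [scheduler]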