implemented BertDataModule collate function

andrea 2021-02-09 09:43:19 +01:00
parent f579a1a7f2
commit 7b6938459f
3 changed files with 10 additions and 7 deletions

@@ -150,7 +150,7 @@ class RecurrentDataModule(pl.LightningDataModule):
     def train_dataloader(self):
         return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
-                          collate_fn=self.training_dataset.collate_fn)
+                          collate_fn=self.training_dataset.collate_fn, shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
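
Note: shuffle=True is added only to the training loader; the validation
loader below it stays deterministic, so metrics remain comparable across
epochs. A minimal sketch of the pattern outside this codebase (all names
and shapes hypothetical):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    train_ds = TensorDataset(torch.randn(100, 8), torch.zeros(100))
    val_ds = TensorDataset(torch.randn(20, 8), torch.zeros(20))

    # Shuffle only the training split: each epoch visits samples in a new
    # order, which avoids ordering bias during SGD.
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    # Validation order does not affect the metrics, so shuffle stays False.
    val_loader = DataLoader(val_ds, batch_size=16)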
@@ -247,7 +247,8 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert,
+                          shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
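
Note: collate_fn_bert is the function this commit's message refers to, but
its body is not among the hunks shown here. A typical BERT collate function
tokenizes and pads a batch of raw texts on the fly; a hedged sketch using
the Hugging Face tokenizer (the model name and item structure are
assumptions, not taken from this diff):

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

    def collate_fn_bert(batch):
        # Assumes each dataset item is a (text, label) pair.
        texts, labels = zip(*batch)
        # Pad/truncate the whole batch to one length and return tensors.
        enc = tokenizer(list(texts), padding=True, truncation=True,
                        max_length=512, return_tensors='pt')
        return enc['input_ids'], enc['attention_mask'], torch.tensor(labels)

As for the docstring's OSError note: that error is a known file-descriptor
exhaustion issue with multiprocessing workers; a common mitigation (not
applied in this commit) is torch.multiprocessing.set_sharing_strategy('file_system').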

@@ -141,7 +141,8 @@ class BertModel(pl.LightningModule):
             'weight_decay': weight_decay}
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
-        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
+        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # TODO set to 1.0 to debug (prev. 0.1)
+                     'interval': 'epoch'}
         return [optimizer], [scheduler]
 
     def encode(self, lX, batch_size=64):
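
Note: two things change in configure_optimizers. StepLR is wrapped in
Lightning's scheduler-dictionary format, whose 'interval' key controls
whether the scheduler steps per 'epoch' or per 'step'; and gamma is set to
1.0, which makes the decay a no-op (the lr is multiplied by 1.0 every 25
epochs) while debugging, as the TODO says. A self-contained sketch of the
same pattern (model and hyperparameters are assumptions):

    import torch
    import pytorch_lightning as pl
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import StepLR

    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 2)

        def configure_optimizers(self):
            optimizer = AdamW(self.parameters(), lr=1e-5)
            # gamma=1.0 keeps the lr constant; gamma=0.1 would divide it
            # by 10 every 25 epochs once debugging is done.
            scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),
                         'interval': 'epoch'}  # step once per epoch
            return [optimizer], [scheduler]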

@@ -366,7 +366,8 @@ class RecurrentGen(ViewGen):
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
                                                   zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False,
+                          overfit_batches=0.01)
         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
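
Note: overfit_batches=0.01 makes the Trainer repeatedly fit on just 1% of
the training batches, a standard sanity check: a correctly wired model
should drive the loss on that tiny subset close to zero, and failure to do
so points at a bug in the model, loss, or data pipeline. The same flag is
added to the BertGen trainer below. Sketch (arguments other than
overfit_batches are illustrative):

    from pytorch_lightning import Trainer

    # Debug configuration: train on only 1% of the batches each epoch.
    # Remove overfit_batches for a real training run.
    trainer = Trainer(gradient_clip_val=1e-1, max_epochs=100,
                      overfit_batches=0.01)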
@@ -474,14 +475,14 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
-                                        debug=True)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False,
+                          overfit_batches=0.01)  # todo: overfit_batches -> DEBUG setting
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self