implemented BertDataModule collate function
commit 7b6938459f
parent f579a1a7f2
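The collate function named in the commit title is not visible in the hunks below; purely as an illustration, a BERT collate function of this kind typically tokenizes and pads a whole batch in one call. The tokenizer choice and names here are assumptions, not the repository's actual code:

```python
# Illustrative sketch only -- not the code added by this commit.
# Assumes a Hugging Face tokenizer; model name and max_len are placeholders.
import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
max_len = 512

def collate_fn_bert(batch):
    # batch is a list of (text, label) pairs produced by the dataset
    texts, labels = zip(*batch)
    encoded = tokenizer(list(texts), padding=True, truncation=True,
                        max_length=max_len, return_tensors='pt')
    return encoded['input_ids'], encoded['attention_mask'], torch.as_tensor(labels)
```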
@@ -150,7 +150,7 @@ class RecurrentDataModule(pl.LightningDataModule):
 
     def train_dataloader(self):
         return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
-                          collate_fn=self.training_dataset.collate_fn)
+                          collate_fn=self.training_dataset.collate_fn, shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
@@ -247,7 +247,8 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert,
+                          shuffle=True)
 
     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
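Aside on the num_workers note in this hunk: the "Too many open files" error with worker processes is often worked around by switching PyTorch's file-descriptor sharing strategy. This is not part of the commit, only a commonly cited fix:

```python
# Possible workaround for "OSError: [Errno 24] Too many open files" when num_workers > 0.
# Not part of this commit; shown as a hedged aside.
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')
```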
@@ -141,7 +141,8 @@ class BertModel(pl.LightningModule):
              'weight_decay': weight_decay}
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
-        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
+        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # TODO set to 1.0 to debug (prev. 0.1)
+                     'interval': 'epoch'}
         return [optimizer], [scheduler]
 
     def encode(self, lX, batch_size=64):
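The scheduler is now returned in PyTorch Lightning's lr-scheduler dict format, which makes the stepping interval explicit ('epoch' here). A minimal self-contained configure_optimizers following the same pattern; the class, parameter grouping, and values are illustrative, not the repository's code:

```python
import torch
import pytorch_lightning as pl
from torch.optim import AdamW                 # the repo may use transformers' AdamW instead
from torch.optim.lr_scheduler import StepLR

class TinyModule(pl.LightningModule):
    """Illustrative module showing only the optimizer/scheduler wiring."""
    def __init__(self, lr=1e-5):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)
        self.lr = lr

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr)
        # gamma=1.0 keeps the learning rate constant (the debug setting in the diff);
        # 'interval': 'epoch' makes Lightning call scheduler.step() once per epoch.
        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),
                     'interval': 'epoch'}
        return [optimizer], [scheduler]
```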
@@ -366,7 +366,8 @@ class RecurrentGen(ViewGen):
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
                                                   zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False,
+                          overfit_batches=0.01)
 
         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
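Both this hunk and the BertGen hunk below add overfit_batches=0.01, Lightning's debugging flag that restricts training to roughly 1% of the batches so one can verify the model is able to overfit a small subset; the TODO notes mark it as temporary. A hedged usage sketch (values illustrative):

```python
import pytorch_lightning as pl

# Debug-only configuration mirroring the change above: train repeatedly on ~1% of the
# training batches to sanity-check the model and optimization before a full run.
trainer = pl.Trainer(max_epochs=5, overfit_batches=0.01)
```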
@@ -474,14 +475,14 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
-                                        debug=True)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)
 
         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')
 
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False,
+                          overfit_batches=0.01)  # todo: overfit_batches -> DEBUG setting
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self