implemented BertDataModule collate function
parent f579a1a7f2
commit 7b6938459f
@@ -150,7 +150,7 @@ class RecurrentDataModule(pl.LightningDataModule):

     def train_dataloader(self):
         return DataLoader(self.training_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
-                          collate_fn=self.training_dataset.collate_fn)
+                          collate_fn=self.training_dataset.collate_fn, shuffle=True)

     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, num_workers=N_WORKERS,
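Note: the only functional change in this hunk is shuffle=True on the training DataLoader; the validation loader keeps its natural order. A minimal sketch of the pattern (class and attribute names are illustrative, not the repository's actual code):

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class ExampleDataModule(pl.LightningDataModule):
        def __init__(self, training_dataset, val_dataset, batchsize=32):
            super().__init__()
            self.training_dataset = training_dataset
            self.val_dataset = val_dataset
            self.batchsize = batchsize

        def train_dataloader(self):
            # shuffle the training split so each epoch sees a fresh batch order
            return DataLoader(self.training_dataset, batch_size=self.batchsize, shuffle=True)

        def val_dataloader(self):
            # keep validation deterministic so metrics are comparable across epochs
            return DataLoader(self.val_dataset, batch_size=self.batchsize)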
@@ -247,7 +247,8 @@ class BertDataModule(RecurrentDataModule):
         NB: Setting n_workers to > 0 will cause "OSError: [Errno 24] Too many open files"
         :return:
         """
-        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
+        return DataLoader(self.training_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert,
+                          shuffle=True)

     def val_dataloader(self):
         return DataLoader(self.val_dataset, batch_size=self.batchsize, collate_fn=self.collate_fn_bert)
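Note: the collate function named in the commit title is not visible in this hunk. As an illustration only, a typical BERT collate function tokenizes and pads a batch of raw texts to a common length; the tokenizer checkpoint and the (text, label) batch layout below are assumptions, not the repository's implementation:

    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')  # assumed checkpoint

    def collate_fn_bert(batch):
        # batch is assumed to be a list of (text, label) pairs
        texts, labels = zip(*batch)
        encoded = tokenizer(list(texts), padding=True, truncation=True,
                            max_length=512, return_tensors='pt')
        return encoded['input_ids'], encoded['attention_mask'], torch.tensor(labels)

The NB in the docstring also explains why no num_workers is passed here: a common cause of "OSError: [Errno 24] Too many open files" is PyTorch's file-descriptor sharing between loader worker processes, so this loader stays single-process.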
@@ -141,7 +141,8 @@ class BertModel(pl.LightningModule):
             'weight_decay': weight_decay}
         ]
         optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
-        scheduler = StepLR(optimizer, step_size=25, gamma=0.1)
+        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # TODO set to 1.0 to debug (prev. 0.1)
+                     'interval': 'epoch'}
         return [optimizer], [scheduler]

     def encode(self, lX, batch_size=64):
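Note: this change wraps StepLR in Lightning's scheduler-dictionary format, which makes the stepping frequency explicit via the 'interval' key. Also worth spelling out: gamma=1.0 multiplies the learning rate by 1.0 at every step, i.e. it disables the decay entirely, which is what the debugging TODO intends. A sketch of the resulting configure_optimizers (plain self.parameters() and the lr value stand in for the grouped parameters built above):

    from torch.optim import AdamW
    from torch.optim.lr_scheduler import StepLR

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=1e-5)  # lr value assumed
        scheduler = {'scheduler': StepLR(optimizer, step_size=25, gamma=1.0),  # gamma=1.0: no decay while debugging
                     'interval': 'epoch'}  # step once per epoch rather than per batch
        return [optimizer], [scheduler]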
@@ -366,7 +366,8 @@ class RecurrentGen(ViewGen):
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs,
                                                   zero_shot=self.zero_shot, zscl_langs=self.train_langs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False,
+                          overfit_batches=0.01)

         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
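Note: overfit_batches is a built-in PyTorch Lightning debugging aid: the Trainer repeatedly trains on a small, fixed subset of the data, and if the model and pipeline are wired correctly the training loss should collapse toward zero on it. A float is read as a fraction of the training batches, an int as an absolute count:

    from pytorch_lightning import Trainer

    trainer = Trainer(overfit_batches=0.01)  # train on ~1% of the training batches
    trainer = Trainer(overfit_batches=10)    # or on exactly 10 batches

Both this hunk and the BertGen one below mark the flag as a debug setting, so it presumably should not survive into a full training run.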
@@ -474,14 +475,14 @@ class BertGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512,
-                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs,
-                                        debug=True)
+                                        zero_shot=self.zero_shot, zscl_langs=self.train_langs)

         if self.zero_shot:
             print(f'# Zero-shot setting! Training langs will be set to: {sorted(self.train_langs)}')

         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False,
+                          overfit_batches=0.01)  # todo: overfit_batches -> DEBUG setting
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self