diff --git a/src/view_generators.py b/src/view_generators.py
index 27da0fc..af4ee8e 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -21,6 +21,7 @@ from abc import ABC, abstractmethod
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 
 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
@@ -240,6 +241,7 @@ class RecurrentGen(ViewGen):
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
+        self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
 
     def _init_model(self):
         if self.stored_path:
@@ -277,7 +279,7 @@ class RecurrentGen(ViewGen):
         create_if_not_exist(self.logger.save_dir)
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
 
         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
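
For context on what the new callback does, below is a minimal sketch (not part of this diff) of LearningRateMonitor attached to a Lightning Trainer. It uses a toy LightningModule with random tensors in place of the project's RecurrentGen and RecurrentDataModule, and assumes a scheduler is configured so the logged learning rate actually changes per epoch.

# Minimal sketch: LearningRateMonitor logging the LR once per epoch.
# ToyModule and the random dataset are placeholders, not project code.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-2)
        # The scheduler is what makes the logged learning rate vary over epochs.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}


if __name__ == '__main__':
    data = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    trainer = pl.Trainer(
        max_epochs=3,
        # An explicit logger mirrors the PR's TensorBoardLogger setup; the
        # monitored learning rate shows up as a scalar in TensorBoard.
        logger=pl.loggers.TensorBoardLogger(save_dir='tb_logs', name='lr_demo'),
        # logging_interval='epoch' matches the interval chosen in the diff above.
        callbacks=[LearningRateMonitor(logging_interval='epoch')],
    )
    trainer.fit(ToyModule(), DataLoader(data, batch_size=16))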