lr monitor

andrea 2021-01-29 16:57:21 +01:00
parent 2c70f37823
commit 0e01d654cf
1 changed file with 3 additions and 1 deletion


@@ -21,6 +21,7 @@ from abc import ABC, abstractmethod
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
@@ -240,6 +241,7 @@ class RecurrentGen(ViewGen):
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
                                                  patience=self.patience, verbose=False, mode='max')
+        self.lr_monitor = LearningRateMonitor(logging_interval='epoch')

     def _init_model(self):
         if self.stored_path:
@@ -277,7 +279,7 @@
         create_if_not_exist(self.logger.save_dir)
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          callbacks=[self.early_stop_callback], checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
         # vanilla_torch_model = torch.load(
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
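For context, the sketch below is a minimal, self-contained illustration (not part of this repository) of what the commit wires up: a `LearningRateMonitor(logging_interval='epoch')` callback passed to a PyTorch Lightning `Trainer` alongside a `TensorBoardLogger`, so the learning rate reported by the optimizer/scheduler returned from `configure_optimizers` is logged once per epoch. The module, data, and hyperparameters (`TinyModule`, the random tensors, `tb_logs`, `lr_monitor_demo`) are illustrative placeholders, not names from the project.

```python
# Minimal sketch: LearningRateMonitor logging the scheduler's LR to TensorBoard.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor


class TinyModule(pl.LightningModule):
    """Toy LightningModule; stands in for the project's recurrent model."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # LearningRateMonitor reads the LR from the optimizer/scheduler returned here.
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
        return [optimizer], [scheduler]


if __name__ == '__main__':
    # Random placeholder data just to make the example runnable.
    x = torch.randn(64, 8)
    y = torch.randn(64, 1)
    loader = DataLoader(TensorDataset(x, y), batch_size=16)

    lr_monitor = LearningRateMonitor(logging_interval='epoch')  # same setting as in the commit
    logger = TensorBoardLogger(save_dir='tb_logs', name='lr_monitor_demo')
    trainer = Trainer(max_epochs=3, logger=logger, callbacks=[lr_monitor])
    trainer.fit(TinyModule(), loader)
```

With this in place, TensorBoard shows a per-epoch learning-rate curve next to the training metrics, which is the point of adding the callback to the `RecurrentGen` trainer above.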