lr monitor
This commit is contained in:
parent
2c70f37823
commit
0e01d654cf
|
@ -21,6 +21,7 @@ from abc import ABC, abstractmethod
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
|
||||
from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
|
||||
from src.models.learners import *
|
||||
|
@ -240,6 +241,7 @@ class RecurrentGen(ViewGen):
|
|||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||
patience=self.patience, verbose=False, mode='max')
|
||||
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
def _init_model(self):
|
||||
if self.stored_path:
|
||||
|
@ -277,7 +279,7 @@ class RecurrentGen(ViewGen):
|
|||
create_if_not_exist(self.logger.save_dir)
|
||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
|
||||
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
||||
callbacks=[self.early_stop_callback], checkpoint_callback=False)
|
||||
callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
|
||||
|
||||
# vanilla_torch_model = torch.load(
|
||||
# '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
|
||||
|
|
Loading…
Reference in New Issue