lr monitor
This commit is contained in:
parent
2c70f37823
commit
0e01d654cf
|
|
@ -21,6 +21,7 @@ from abc import ABC, abstractmethod
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.loggers import TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||||
|
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||||
|
|
||||||
from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
|
from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
|
||||||
from src.models.learners import *
|
from src.models.learners import *
|
||||||
|
|
@ -240,6 +241,7 @@ class RecurrentGen(ViewGen):
|
||||||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
||||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||||
patience=self.patience, verbose=False, mode='max')
|
patience=self.patience, verbose=False, mode='max')
|
||||||
|
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||||
|
|
||||||
def _init_model(self):
|
def _init_model(self):
|
||||||
if self.stored_path:
|
if self.stored_path:
|
||||||
|
|
@ -277,7 +279,7 @@ class RecurrentGen(ViewGen):
|
||||||
create_if_not_exist(self.logger.save_dir)
|
create_if_not_exist(self.logger.save_dir)
|
||||||
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
|
recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
|
||||||
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
|
||||||
callbacks=[self.early_stop_callback], checkpoint_callback=False)
|
callbacks=[self.early_stop_callback, self.lr_monitor], checkpoint_callback=False)
|
||||||
|
|
||||||
# vanilla_torch_model = torch.load(
|
# vanilla_torch_model = torch.load(
|
||||||
# '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
|
# '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue