fixed pl early stop --> patience was consumed if actual_monitor == best_monitor. Set policy to greater or equal.
This commit is contained in:
parent
a1c4247e17
commit
a4f74dcf41
2
main.py
2
main.py
|
@ -71,7 +71,7 @@ def main(args):
|
|||
print('\n[Testing Generalized Funnelling]')
|
||||
time_te = time.time()
|
||||
ly_ = gfun.predict(lXte)
|
||||
l_eval = evaluate(ly_true=lyte, ly_pred=ly_)
|
||||
l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs)
|
||||
time_te = round(time.time() - time_te, 3)
|
||||
print(f'Testing completed in {time_te} seconds!')
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ This module contains the view generators that take care of computing the view sp
|
|||
from abc import ABC, abstractmethod
|
||||
# from time import time
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
|
@ -241,6 +242,10 @@ class RecurrentGen(ViewGen):
|
|||
self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
|
||||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||
patience=self.patience, verbose=False, mode='max')
|
||||
|
||||
# modifying EarlyStopping global var in order to compute >= with respect to the best score
|
||||
self.early_stop_callback.mode_dict['max'] = torch.ge
|
||||
|
||||
self.lr_monitor = LearningRateMonitor(logging_interval='epoch')
|
||||
|
||||
def _init_model(self):
|
||||
|
@ -348,6 +353,9 @@ class BertGen(ViewGen):
|
|||
self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
|
||||
patience=self.patience, verbose=False, mode='max')
|
||||
|
||||
# modifying EarlyStopping global var in order to compute >= with respect to the best score
|
||||
self.early_stop_callback.mode_dict['max'] = torch.ge
|
||||
|
||||
def _init_model(self):
|
||||
output_size = self.multilingualIndex.get_target_dim()
|
||||
return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)
|
||||
|
|
Loading…
Reference in New Issue