diff --git a/main.py b/main.py index da2748d..25d7a07 100644 --- a/main.py +++ b/main.py @@ -71,7 +71,7 @@ def main(args): print('\n[Testing Generalized Funnelling]') time_te = time.time() ly_ = gfun.predict(lXte) - l_eval = evaluate(ly_true=lyte, ly_pred=ly_) + l_eval = evaluate(ly_true=lyte, ly_pred=ly_, n_jobs=args.n_jobs) time_te = round(time.time() - time_te, 3) print(f'Testing completed in {time_te} seconds!') diff --git a/src/view_generators.py b/src/view_generators.py index e972ce7..a934181 100644 --- a/src/view_generators.py +++ b/src/view_generators.py @@ -18,6 +18,7 @@ This module contains the view generators that take care of computing the view sp from abc import ABC, abstractmethod # from time import time +import torch from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks.early_stopping import EarlyStopping @@ -241,6 +242,10 @@ class RecurrentGen(ViewGen): self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') + + # modifying EarlyStopping global var in order to compute >= with respect to the best score + self.early_stop_callback.mode_dict['max'] = torch.ge + self.lr_monitor = LearningRateMonitor(logging_interval='epoch') def _init_model(self): @@ -348,6 +353,9 @@ class BertGen(ViewGen): self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=self.patience, verbose=False, mode='max') + # modifying EarlyStopping global var in order to compute >= with respect to the best score + self.early_stop_callback.mode_dict['max'] = torch.ge + def _init_model(self): output_size = self.multilingualIndex.get_target_dim() return BertModel(output_size=output_size, stored_path=self.stored_path, gpus=self.gpus)