From e9a410faa44689438919376c793a0e18dc95667b Mon Sep 17 00:00:00 2001
From: andrea
Date: Tue, 26 Jan 2021 18:12:14 +0100
Subject: [PATCH] early stopping

---
 src/data/datamodule.py | 24 ++++++++++++------------
 src/view_generators.py | 12 +++++++++---
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/src/data/datamodule.py b/src/data/datamodule.py
index bf874c7..66146b3 100644
--- a/src/data/datamodule.py
+++ b/src/data/datamodule.py
@@ -112,24 +112,24 @@ class RecurrentDataModule(pl.LightningDataModule):
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             # Debug settings: reducing number of samples
-            l_train_index = {l: train[:5] for l, train in l_train_index.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_index = {l: train[:5] for l, train in l_train_index.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())
 
             l_val_index, l_val_target = self.multilingualIndex.l_val()
             # Debug settings: reducing number of samples
-            l_val_index = {l: train[:5] for l, train in l_val_index.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_index = {l: train[:5] for l, train in l_val_index.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
 
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
 
         if stage == 'test' or stage is None:
             l_test_index, l_test_target = self.multilingualIndex.l_test()
             # Debug settings: reducing number of samples
-            l_test_index = {l: train[:5] for l, train in l_test_index.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_index = {l: train[:5] for l, train in l_test_index.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
 
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
@@ -182,8 +182,8 @@ class BertDataModule(RecurrentDataModule):
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             # Debug settings: reducing number of samples
-            l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}
 
             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -191,8 +191,8 @@
 
             l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             # Debug settings: reducing number of samples
-            l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}
 
             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
@@ -201,8 +201,8 @@
         if stage == 'test' or stage is None:
             l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
             # Debug settings: reducing number of samples
-            l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}
 
             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,

diff --git a/src/view_generators.py b/src/view_generators.py
index 452714c..20a8045 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -20,6 +20,8 @@ from abc import ABC, abstractmethod
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+
 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
@@ -27,7 +29,7 @@ from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
-# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
+# TODO: add model checkpointing and loading from checkpoint + training on validation after convergence is reached
 
 class ViewGen(ABC):
@@ -235,6 +237,8 @@ class RecurrentGen(ViewGen):
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')
 
     def _init_model(self):
         if self.stored_path:
@@ -271,7 +275,7 @@
         print('# Fitting RecurrentGen (G)...')
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback], checkpoint_callback=False)
 
         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
@@ -330,6 +334,8 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')
 
     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
@@ -348,7 +354,7 @@
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
 
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self
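
Note on the change: EarlyStopping(monitor='val-macroF1', mode='max') can only fire if the
LightningModule logs a metric under that exact key during validation; with PyTorch Lightning's
default strict behaviour, training errors out if the monitored key is never logged. The patch
therefore assumes RecurrentModel and BertModel already log 'val-macroF1'. A minimal sketch of
the assumed logging call (the class and variable names below are illustrative, not taken from
the repository):

    import torch
    import pytorch_lightning as pl

    class SketchModule(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            macro_f1 = torch.tensor(0.0)  # stand-in for the real macro-F1 computation
            # EarlyStopping matches the logged key verbatim; on_epoch=True aggregates
            # the value per validation epoch, so patience=5 means training stops after
            # five consecutive epochs without an improvement greater than min_delta.
            self.log('val-macroF1', macro_f1, on_epoch=True)

Since both Trainers are built with checkpoint_callback=False, stopping early leaves the model
with its last-epoch weights rather than the best-scoring ones; restoring the best model is what
the remaining TODO on checkpointing would address.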