early stopping
parent 5cd36d27fc
commit e9a410faa4
@@ -112,24 +112,24 @@ class RecurrentDataModule(pl.LightningDataModule):
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             # Debug settings: reducing number of samples
-            l_train_index = {l: train[:5] for l, train in l_train_index.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_index = {l: train[:5] for l, train in l_train_index.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}

             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())

             l_val_index, l_val_target = self.multilingualIndex.l_val()
             # Debug settings: reducing number of samples
-            l_val_index = {l: train[:5] for l, train in l_val_index.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_index = {l: train[:5] for l, train in l_val_index.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}

             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())
         if stage == 'test' or stage is None:
             l_test_index, l_test_target = self.multilingualIndex.l_test()
             # Debug settings: reducing number of samples
-            l_test_index = {l: train[:5] for l, train in l_test_index.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_index = {l: train[:5] for l, train in l_test_index.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}

             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
@@ -182,8 +182,8 @@ class BertDataModule(RecurrentDataModule):
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             # Debug settings: reducing number of samples
-            l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}

             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -191,8 +191,8 @@ class BertDataModule(RecurrentDataModule):

             l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             # Debug settings: reducing number of samples
-            l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}

             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
@@ -201,8 +201,8 @@ class BertDataModule(RecurrentDataModule):
         if stage == 'test' or stage is None:
             l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
             # Debug settings: reducing number of samples
-            l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}

             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
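Note: the four hunks above all make the same change to the datamodules (presumably src/data/datamodule.py, given the import in the second file below): the debug lines that truncated every per-language split to its first five samples are commented out, so training now runs on the full data. A minimal sketch of the same subsampling as a flag-guarded helper instead of a comment toggle (the helper name truncate_splits and the debug_n parameter are illustrative, not part of the repository):

def truncate_splits(l_data, n=5):
    # l_data maps a language code to its samples; keep only the first n per language.
    return {lang: data[:n] for lang, data in l_data.items()}

# Hypothetical usage inside setup(), replacing the commented-in/commented-out toggle:
# if self.debug_n is not None:
#     l_train_index = truncate_splits(l_train_index, self.debug_n)
#     l_train_target = truncate_splits(l_train_target, self.debug_n)

Guarding the subsampling behind a flag avoids editing the datamodule every time one switches between debug and full runs.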
@@ -20,6 +20,8 @@ from abc import ABC, abstractmethod

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+

 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
@@ -27,7 +29,7 @@ from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
-# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
+# TODO: add model checkpointing and loading from checkpoint + training on validation after convergence is reached


 class ViewGen(ABC):
@@ -235,6 +237,8 @@ class RecurrentGen(ViewGen):
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')

     def _init_model(self):
         if self.stored_path:
@@ -271,7 +275,7 @@ class RecurrentGen(ViewGen):
         print('# Fitting RecurrentGen (G)...')
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback], checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
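For EarlyStopping(monitor='val-macroF1') to have any effect, the LightningModule must log a metric under exactly that name during validation; with the default strict=True the callback raises when the metric is missing (with strict=False it only warns). A minimal runnable sketch of that contract, assuming a pre-2.0 pytorch_lightning consistent with the checkpoint_callback=False API used in this commit; the toy module and its accuracy stand-in for macro-F1 are illustrative, not the repository's models:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.lin(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self.lin(x).argmax(dim=1)
        acc = (preds == y).float().mean()  # stand-in for a real macro-F1
        self.log('val-macroF1', acc)       # name must match the EarlyStopping monitor

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

data = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(data, batch_size=16)
early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=5, mode='max')
trainer = pl.Trainer(max_epochs=50, callbacks=[early_stop], checkpoint_callback=False)
trainer.fit(ToyModule(), loader, loader)

With mode='max' and patience=5, training stops after five consecutive validation checks (one per epoch by default) without improvement over the best val-macroF1 seen so far.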
@@ -330,6 +334,8 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')

     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
@@ -348,7 +354,7 @@ class BertGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self
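One caveat that the updated TODO already acknowledges: with checkpoint_callback=False, early stopping halts training but keeps the last weights, which by construction are up to patience validation checks past the best-scoring epoch. A hedged sketch of the usual pairing with ModelCheckpoint (standard pytorch_lightning; the dirpath is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

checkpoint = ModelCheckpoint(monitor='val-macroF1', mode='max',
                             dirpath='../checkpoints', save_top_k=1)
early_stop = EarlyStopping(monitor='val-macroF1', min_delta=0.00, patience=5, mode='max')
# trainer = Trainer(callbacks=[early_stop, checkpoint], ...)
# After fit, the best weights sit at checkpoint.best_model_path and can be
# restored with e.g. RecurrentModel.load_from_checkpoint(checkpoint.best_model_path).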