early stopping
This commit is contained in:
parent 5cd36d27fc
commit e9a410faa4
src/data/datamodule.py:

```diff
@@ -112,24 +112,24 @@ class RecurrentDataModule(pl.LightningDataModule):
         if stage == 'fit' or stage is None:
             l_train_index, l_train_target = self.multilingualIndex.l_train()
             # Debug settings: reducing number of samples
-            l_train_index = {l: train[:5] for l, train in l_train_index.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_index = {l: train[:5] for l, train in l_train_index.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}

             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                      lPad_index=self.multilingualIndex.l_pad())

             l_val_index, l_val_target = self.multilingualIndex.l_val()
             # Debug settings: reducing number of samples
-            l_val_index = {l: train[:5] for l, train in l_val_index.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_index = {l: train[:5] for l, train in l_val_index.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}

             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
                                                 lPad_index=self.multilingualIndex.l_pad())

         if stage == 'test' or stage is None:
             l_test_index, l_test_target = self.multilingualIndex.l_test()
             # Debug settings: reducing number of samples
-            l_test_index = {l: train[:5] for l, train in l_test_index.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_index = {l: train[:5] for l, train in l_test_index.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}

             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
                                                  lPad_index=self.multilingualIndex.l_pad())
@@ -182,8 +182,8 @@ class BertDataModule(RecurrentDataModule):
         if stage == 'fit' or stage is None:
             l_train_raw, l_train_target = self.multilingualIndex.l_train_raw()
             # Debug settings: reducing number of samples
-            l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
-            l_train_target = {l: target[:5] for l, target in l_train_target.items()}
+            # l_train_raw = {l: train[:5] for l, train in l_train_raw.items()}
+            # l_train_target = {l: target[:5] for l, target in l_train_target.items()}

             l_train_index = tokenize(l_train_raw, max_len=self.max_len)
             self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
@@ -191,8 +191,8 @@ class BertDataModule(RecurrentDataModule):

             l_val_raw, l_val_target = self.multilingualIndex.l_val_raw()
             # Debug settings: reducing number of samples
-            l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
-            l_val_target = {l: target[:5] for l, target in l_val_target.items()}
+            # l_val_raw = {l: train[:5] for l, train in l_val_raw.items()}
+            # l_val_target = {l: target[:5] for l, target in l_val_target.items()}

             l_val_index = tokenize(l_val_raw, max_len=self.max_len)
             self.val_dataset = RecurrentDataset(l_val_index, l_val_target,
@@ -201,8 +201,8 @@ class BertDataModule(RecurrentDataModule):
         if stage == 'test' or stage is None:
             l_test_raw, l_test_target = self.multilingualIndex.l_test_raw()
             # Debug settings: reducing number of samples
-            l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
-            l_test_target = {l: target[:5] for l, target in l_test_target.items()}
+            # l_test_raw = {l: train[:5] for l, train in l_test_raw.items()}
+            # l_test_target = {l: target[:5] for l, target in l_test_target.items()}

             l_test_index = tokenize(l_test_raw, max_len=self.max_len)
             self.test_dataset = RecurrentDataset(l_test_index, l_test_target,
```
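In the datamodule, the per-language debug subsampling (`train[:5]`) is switched off by commenting it out rather than removing it. A minimal sketch of an alternative, assuming a hypothetical `debug` flag and `debug_samples` cap that are not in the original code, which keeps the behaviour toggleable without editing `setup()`:

```python
import pytorch_lightning as pl

class RecurrentDataModule(pl.LightningDataModule):
    def __init__(self, multilingualIndex, batchsize=64, n_jobs=-1,
                 debug=False, debug_samples=5):
        super().__init__()
        self.multilingualIndex = multilingualIndex
        self.batchsize = batchsize
        self.n_jobs = n_jobs
        self.debug = debug                  # hypothetical: enable only for smoke tests
        self.debug_samples = debug_samples  # hypothetical: per-language sample cap

    def _maybe_truncate(self, l_index, l_target):
        # Reduce every language split to a handful of samples when debugging.
        if not self.debug:
            return l_index, l_target
        l_index = {l: data[:self.debug_samples] for l, data in l_index.items()}
        l_target = {l: target[:self.debug_samples] for l, target in l_target.items()}
        return l_index, l_target

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            l_train_index, l_train_target = self.multilingualIndex.l_train()
            l_train_index, l_train_target = self._maybe_truncate(l_train_index, l_train_target)
            # RecurrentDataset is the project's dataset class from this same module.
            self.training_dataset = RecurrentDataset(l_train_index, l_train_target,
                                                     lPad_index=self.multilingualIndex.l_pad())
```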
```diff
@@ -20,6 +20,8 @@ from abc import ABC, abstractmethod

 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+

 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
@@ -27,7 +29,7 @@ from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
-# TODO: add early stop monitoring validation macroF1 + model checkpointing and loading from checkpoint
+# TODO: add model checkpointing and loading from checkpoint + training on validation after convergence is reached


 class ViewGen(ABC):
```
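The callbacks added below monitor `val-macroF1`. `EarlyStopping` can only track metrics that the LightningModule actually logs under that exact key, so `RecurrentModel` and `BertModel` are expected to call `self.log('val-macroF1', ...)` during validation. A minimal sketch of that contract, assuming hypothetical `loss_fn` / `compute_macro_f1` helpers (the models' internals are not part of this diff):

```python
import pytorch_lightning as pl

class RecurrentModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)               # hypothetical: the model's loss
        macro_f1 = self.compute_macro_f1(logits, y)  # hypothetical macro-F1 helper
        # The key must match EarlyStopping(monitor='val-macroF1') exactly;
        # on_epoch=True aggregates the value over the whole validation epoch.
        self.log('val-macroF1', macro_f1, on_epoch=True)
        return loss
```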
```diff
@@ -235,6 +237,8 @@ class RecurrentGen(ViewGen):
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')

     def _init_model(self):
         if self.stored_path:
@@ -271,7 +275,7 @@ class RecurrentGen(ViewGen):
         print('# Fitting RecurrentGen (G)...')
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
-                          checkpoint_callback=False)
+                          callbacks=[self.early_stop_callback], checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
```
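The updated TODO leaves model checkpointing and loading from checkpoint open. A minimal sketch of how `RecurrentGen.fit` could cover it with PyTorch Lightning's `ModelCheckpoint`, assuming a hypothetical `dirpath` and that `RecurrentModel` saves its hyperparameters so `load_from_checkpoint` can rebuild it; the `checkpoint_callback=False` from this commit would have to be dropped for any checkpoint to be written:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

def fit(self):
    recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size,
                                              n_jobs=self.n_jobs)
    # Keep only the best-scoring epoch according to the same monitored metric.
    checkpoint_callback = ModelCheckpoint(monitor='val-macroF1', mode='max', save_top_k=1,
                                          dirpath='../checkpoints/rnn')  # hypothetical path
    trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger,
                      max_epochs=self.nepochs,
                      callbacks=[self.early_stop_callback, checkpoint_callback])
    trainer.fit(self.model, datamodule=recurrentDataModule)
    # Restore the weights of the best validation epoch before returning;
    # EarlyStopping alone leaves the final (possibly worse) weights in memory.
    self.model = RecurrentModel.load_from_checkpoint(checkpoint_callback.best_model_path)
    return self
```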
```diff
@@ -330,6 +334,8 @@ class BertGen(ViewGen):
         self.stored_path = stored_path
         self.model = self._init_model()
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
+        self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
+                                                 patience=5, verbose=False, mode='max')

     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
@@ -348,7 +354,7 @@ class BertGen(ViewGen):
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
-                          logger=self.logger, checkpoint_callback=False)
+                          logger=self.logger, callbacks=[self.early_stop_callback], checkpoint_callback=False)
         trainer.fit(self.model, datamodule=bertDataModule)
         trainer.test(self.model, datamodule=bertDataModule)
         return self
```
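Taken together, the commit wires the same early-stopping policy into both view generators: with `monitor='val-macroF1'`, `mode='max'`, `min_delta=0.00`, and `patience=5`, Lightning stops training once validation macro-F1 has failed to improve for five consecutive validation epochs. Since `checkpoint_callback=False` is kept, the weights in memory at stop time are those of the final epoch, not the best one; restoring the best-scoring epoch is what the checkpointing TODO above still has to address.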