early stopping + typos
parent 79cdaa0beb
commit bb84422d24

main.py | 20 lines changed
main.py
@@ -108,7 +108,7 @@ if __name__ == '__main__':
     parser.add_argument('dataset', help='Path to the dataset')

-    parser.add_argument('-o', '--output', dest='csv_dir',
+    parser.add_argument('-o', '--output', dest='csv_dir', metavar='',
                         help='Result file (default ../csv_logs/gfun/gfun_results.csv)', type=str,
                         default='../csv_logs/gfun/gfun_results.csv')

@@ -133,22 +133,22 @@ if __name__ == '__main__':
                         default=False)

     parser.add_argument('-c', '--c_optimize', dest='optimc', action='store_true',
-                        help='Optimize SVMs C hyperparameter',
+                        help='Optimize SVMs C hyperparameter at metaclassifier level',
                         default=False)

-    parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int,
+    parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int, metavar='',
                         help='Number of parallel jobs (default is -1, all)',
                         default=-1)

-    parser.add_argument('--nepochs_rnn', dest='nepochs_rnn', type=int,
-                        help='Number of max epochs to train Recurrent embedder (i.e., -g), default 150.',
+    parser.add_argument('--nepochs_rnn', dest='nepochs_rnn', type=int, metavar='',
+                        help='Number of max epochs to train Recurrent embedder (i.e., -g), default 150',
                         default=150)

-    parser.add_argument('--nepochs_bert', dest='nepochs_bert', type=int,
+    parser.add_argument('--nepochs_bert', dest='nepochs_bert', type=int, metavar='',
                         help='Number of max epochs to train Bert model (i.e., -g), default 10',
                         default=10)

-    parser.add_argument('--muse_dir', dest='muse_dir', type=str,
+    parser.add_argument('--muse_dir', dest='muse_dir', type=str, metavar='',
                         help='Path to the MUSE polylingual word embeddings (default ../embeddings)',
                         default='../embeddings')

@@ -156,15 +156,15 @@ if __name__ == '__main__':
                         help='Deploy WCE embedding as embedding layer of the GRU View Generator',
                         default=False)

-    parser.add_argument('--gru_dir', dest='gru_dir', type=str,
+    parser.add_argument('--gru_dir', dest='gru_dir', type=str, metavar='',
                         help='Set the path to a pretrained GRU model (i.e., -g view generator)',
                         default=None)

-    parser.add_argument('--bert_dir', dest='bert_dir', type=str,
+    parser.add_argument('--bert_dir', dest='bert_dir', type=str, metavar='',
                         help='Set the path to a pretrained mBERT model (i.e., -b view generator)',
                         default=None)

-    parser.add_argument('--gpus', help='specifies how many GPUs to use per node',
+    parser.add_argument('--gpus', metavar='', help='specifies how many GPUs to use per node',
                         default=None)

     args = parser.parse_args()
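Side note on the metavar='' additions: by default argparse prints the uppercased dest after every flag in --help (e.g. "-o CSV_DIR, --output CSV_DIR"), and an empty metavar suppresses that placeholder, which is what tidies the usage screen. A minimal standalone sketch (not code from this repo):

    import argparse

    parser = argparse.ArgumentParser()
    # Default rendering in --help: "-o CSV_DIR, --output CSV_DIR"
    parser.add_argument('-o', '--output', dest='csv_dir')
    # With metavar='': rendered as just "-j , --n_jobs"
    parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int, metavar='')
    parser.print_help()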
@@ -38,7 +38,7 @@ optional arguments:
   -g, --gru_embedder   deploy a GRU in order to compute document embeddings
   -c, --c_optimize     optimize SVMs C hyperparameter
   -j, --n_jobs         number of parallel jobs (default is -1, all)
-  --nepochs_rnn        number of max epochs to train Recurrent embedder (i.e., -g), default 150.
+  --nepochs_rnn        number of max epochs to train Recurrent embedder (i.e., -g), default 150
   --nepochs_bert       number of max epochs to train Bert model (i.e., -g), default 10
   --muse_dir           path to the MUSE polylingual word embeddings (default ../embeddings)
   --gru_wce            deploy WCE embedding as embedding layer of the GRU View Generator
@@ -22,13 +22,13 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping

 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
 from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
 from src.util.file import create_if_not_exist
+# TODO: add model checkpointing and loading from checkpoint + training on validation after convergence is reached
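On the TODO above: PyTorch Lightning 1.x (the API generation this file targets, given the checkpoint_callback=False flag used below) ships a ModelCheckpoint callback that would cover the checkpointing half. A hedged sketch, not the repo's code; the wiring and the load_from_checkpoint call are assumptions:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # Keep the single best model according to the metric early stopping already monitors.
    checkpoint_cb = ModelCheckpoint(monitor='val-macroF1', mode='max', save_top_k=1)
    # trainer = Trainer(..., callbacks=[early_stop_callback, checkpoint_cb])
    # After fitting, reload the best weights:
    # model = RecurrentModel.load_from_checkpoint(checkpoint_cb.best_model_path)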
@@ -203,7 +203,7 @@ class RecurrentGen(ViewGen):
     the network internal state at the second feed-forward layer level. Training metrics are logged via TensorBoard.
     """
     def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
-                 gpus=0, n_jobs=-1, stored_path=None):
+                 gpus=0, n_jobs=-1, patience=5, stored_path=None):
         """
         Init RecurrentGen.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents

@@ -217,6 +217,7 @@ class RecurrentGen(ViewGen):
         :param nepochs: int, number of max epochs to train the model.
         :param gpus: int, specifies how many GPUs to use per node. If False computation will take place on cpu.
         :param n_jobs: int, number of concurrent workers (i.e., parallelizing data loading).
+        :param patience: int, number of epochs with no improvements in val-macroF1 before early stopping.
         :param stored_path: str, path to a pretrained model. If None the model will be trained from scratch.
         """
         super().__init__()

@@ -227,6 +228,7 @@ class RecurrentGen(ViewGen):
         self.n_jobs = n_jobs
         self.stored_path = stored_path
         self.nepochs = nepochs
+        self.patience = patience

         # EMBEDDINGS to be deployed
         self.pretrained = pretrained_embeddings

@@ -238,7 +240,7 @@ class RecurrentGen(ViewGen):
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
-                                                 patience=5, verbose=False, mode='max')
+                                                 patience=self.patience, verbose=False, mode='max')

     def _init_model(self):
         if self.stored_path:
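For the EarlyStopping(monitor='val-macroF1', ...) construction above to have any effect, the LightningModule must log a metric under that exact key during validation; otherwise the callback has nothing to watch. A sketch of what that logging plausibly looks like (the real validation code of RecurrentModel is not part of this diff, and compute_macro_f1 is a hypothetical helper):

    from pytorch_lightning import LightningModule

    class RecurrentModelSketch(LightningModule):  # illustrative stand-in, not the repo's class
        def validation_epoch_end(self, outputs):  # PL 1.x validation hook
            macro_f1 = compute_macro_f1(outputs)  # hypothetical aggregation helper
            self.log('val-macroF1', macro_f1, prog_bar=True)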
@@ -273,12 +275,13 @@ class RecurrentGen(ViewGen):
         :return: self.
         """
         print('# Fitting RecurrentGen (G)...')
         create_if_not_exist('../tb_logs')
         recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs)
         trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs,
+                          callbacks=[self.early_stop_callback], checkpoint_callback=False)

         # vanilla_torch_model = torch.load(
         #     '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle')
         #     '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle')
         # self.model.linear0 = vanilla_torch_model.linear0
         # self.model.linear1 = vanilla_torch_model.linear1
         # self.model.linear2 = vanilla_torch_model.linear2
@@ -314,7 +317,7 @@ class BertGen(ViewGen):
     At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document
     embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard.
     """
-    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None):
+    def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None):
         """
         Init Bert model
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents

@@ -322,6 +325,7 @@ class BertGen(ViewGen):
         :param batch_size: int, number of samples per batch.
         :param nepochs: int, number of max epochs to train the model.
         :param gpus: int, specifies how many GPUs to use per node. If False computation will take place on cpu.
+        :param patience: int, number of epochs with no improvements in val-macroF1 before early stopping.
         :param n_jobs: int, number of concurrent workers.
         :param stored_path: str, path to a pretrained model. If None the model will be trained from scratch.
         """

@@ -333,9 +337,10 @@ class BertGen(ViewGen):
         self.n_jobs = n_jobs
         self.stored_path = stored_path
         self.model = self._init_model()
+        self.patience = patience
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False)
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
-                                                 patience=5, verbose=False, mode='max')
+                                                 patience=self.patience, verbose=False, mode='max')

     def _init_model(self):
         output_size = self.multilingualIndex.get_target_dim()
@@ -351,6 +356,7 @@ class BertGen(ViewGen):
         :return: self.
         """
         print('# Fitting BertGen (M)...')
         create_if_not_exist('../tb_logs')
         self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1)
         bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512)
         trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,
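Taken together, patience is now a constructor knob on both view generators instead of a hard-coded 5. A usage sketch with illustrative values (muse_embeddings is a placeholder, and the exact fit signature is not shown in this diff):

    # More tolerant early stopping: wait 10 epochs without val-macroF1 improvement.
    gru_gen = RecurrentGen(multilingualIndex, pretrained_embeddings=muse_embeddings,
                           wce=False, batch_size=512, nepochs=150, gpus=1, patience=10)
    # gru_gen.fit(...)  # training stops early once val-macroF1 plateaus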