diff --git a/main.py b/main.py
index 42623bd..b99f024 100644
--- a/main.py
+++ b/main.py
@@ -108,7 +108,7 @@ if __name__ == '__main__':
     parser.add_argument('dataset', help='Path to the dataset')
-    parser.add_argument('-o', '--output', dest='csv_dir',
+    parser.add_argument('-o', '--output', dest='csv_dir', metavar='',
                         help='Result file (default ../csv_logs/gfun/gfun_results.csv)',
                         type=str, default='../csv_logs/gfun/gfun_results.csv')
@@ -133,22 +133,22 @@ if __name__ == '__main__':
                         default=False)
     parser.add_argument('-c', '--c_optimize', dest='optimc', action='store_true',
-                        help='Optimize SVMs C hyperparameter',
+                        help='Optimize SVMs C hyperparameter at the meta-classifier level',
                         default=False)
-    parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int,
+    parser.add_argument('-j', '--n_jobs', dest='n_jobs', type=int, metavar='',
                         help='Number of parallel jobs (default is -1, all)',
                         default=-1)
-    parser.add_argument('--nepochs_rnn', dest='nepochs_rnn', type=int,
-                        help='Number of max epochs to train Recurrent embedder (i.e., -g), default 150.',
+    parser.add_argument('--nepochs_rnn', dest='nepochs_rnn', type=int, metavar='',
+                        help='Number of max epochs to train Recurrent embedder (i.e., -g), default 150',
                         default=150)
-    parser.add_argument('--nepochs_bert', dest='nepochs_bert', type=int,
+    parser.add_argument('--nepochs_bert', dest='nepochs_bert', type=int, metavar='',
                         help='Number of max epochs to train Bert model (i.e., -g), default 10',
                         default=10)
-    parser.add_argument('--muse_dir', dest='muse_dir', type=str,
+    parser.add_argument('--muse_dir', dest='muse_dir', type=str, metavar='',
                         help='Path to the MUSE polylingual word embeddings (default ../embeddings)',
                         default='../embeddings')
@@ -156,15 +156,15 @@ if __name__ == '__main__':
                         help='Deploy WCE embedding as embedding layer of the GRU View Generator',
                         default=False)
-    parser.add_argument('--gru_dir', dest='gru_dir', type=str,
+    parser.add_argument('--gru_dir', dest='gru_dir', type=str, metavar='',
                         help='Set the path to a pretrained GRU model (i.e., -g view generator)',
                         default=None)
-    parser.add_argument('--bert_dir', dest='bert_dir', type=str,
+    parser.add_argument('--bert_dir', dest='bert_dir', type=str, metavar='',
                         help='Set the path to a pretrained mBERT model (i.e., -b view generator)',
                         default=None)
-    parser.add_argument('--gpus', help='specifies how many GPUs to use per node',
+    parser.add_argument('--gpus', metavar='', help='Specifies how many GPUs to use per node',
                         default=None)
     args = parser.parse_args()
diff --git a/readme.md b/readme.md
index 06c8633..4569ba8 100644
--- a/readme.md
+++ b/readme.md
@@ -38,7 +38,7 @@ optional arguments:
 -g, --gru_embedder deploy a GRU in order to compute document embeddings
 -c, --c_optimize optimize SVMs C hyperparameter
 -j, --n_jobs number of parallel jobs (default is -1, all)
- --nepochs_rnn number of max epochs to train Recurrent embedder (i.e., -g), default 150.
+ --nepochs_rnn number of max epochs to train Recurrent embedder (i.e., -g), default 150
 --nepochs_bert number of max epochs to train Bert model (i.e., -g), default 10
 --muse_dir path to the MUSE polylingual word embeddings (default ../embeddings)
 --gru_wce deploy WCE embedding as embedding layer of the GRU View Generator
diff --git a/src/view_generators.py b/src/view_generators.py
index d014ef0..9b352f8 100644
--- a/src/view_generators.py
+++ b/src/view_generators.py
@@ -22,13 +22,13 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
-
 from src.data.datamodule import RecurrentDataModule, BertDataModule, tokenize
 from src.models.learners import *
 from src.models.pl_bert import BertModel
 from src.models.pl_gru import RecurrentModel
 from src.util.common import TfidfVectorizerMultilingual, _normalize
 from src.util.embeddings_manager import MuseLoader, XdotM, wce_matrix
+from src.util.file import create_if_not_exist
 # TODO: add model checkpointing and loading from checkpoint + training on validation after convergence is reached
@@ -203,7 +203,7 @@ class RecurrentGen(ViewGen):
     the network internal state at the second feed-forward layer level. Training metrics are logged via TensorBoard.
     """
     def __init__(self, multilingualIndex, pretrained_embeddings, wce, batch_size=512, nepochs=50,
-                 gpus=0, n_jobs=-1, stored_path=None):
+                 gpus=0, n_jobs=-1, patience=5, stored_path=None):
         """
         Init RecurrentGen.
         :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents
@@ -217,6 +217,7 @@ class RecurrentGen(ViewGen):
         :param nepochs: int, number of max epochs to train the model.
         :param gpus: int, specifies how many GPUs to use per node. If False computation will take place on cpu.
         :param n_jobs: int, number of concurrent workers (i.e., parallelizing data loading).
+        :param patience: int, number of epochs with no improvement in val-macroF1 before early stopping.
         :param stored_path: str, path to a pretrained model. If None the model will be trained from scratch.
         """
         super().__init__()
@@ -227,6 +228,7 @@ class RecurrentGen(ViewGen):
         self.n_jobs = n_jobs
         self.stored_path = stored_path
         self.nepochs = nepochs
+        self.patience = patience
         # EMBEDDINGS to be deployed
         self.pretrained = pretrained_embeddings
@@ -238,7 +240,7 @@
         self.logger = TensorBoardLogger(save_dir='../tb_logs', name='rnn', default_hp_metric=False)
         # self.logger = CSVLogger(save_dir='csv_logs', name='rnn_dev')
         self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00,
-                                                 patience=5, verbose=False, mode='max')
+                                                 patience=self.patience, verbose=False, mode='max')
     def _init_model(self):
         if self.stored_path:
@@ -273,12 +275,13 @@ class RecurrentGen(ViewGen):
         :return: self.
""" print('# Fitting RecurrentGen (G)...') + create_if_not_exist('../tb_logs') recurrentDataModule = RecurrentDataModule(self.multilingualIndex, batchsize=self.batch_size, n_jobs=self.n_jobs) trainer = Trainer(gradient_clip_val=1e-1, gpus=self.gpus, logger=self.logger, max_epochs=self.nepochs, callbacks=[self.early_stop_callback], checkpoint_callback=False) # vanilla_torch_model = torch.load( - # '/home/andreapdr/funneling_pdr/checkpoint/gru_viewgen_-jrc_doclist_1958-2005vs2006_all_top300_noparallel_processed_run0.pickle') + # '../_old_checkpoint/gru_viewgen_-rcv1-2_doclist_trByLang1000_teByLang1000_processed_run0.pickle') # self.model.linear0 = vanilla_torch_model.linear0 # self.model.linear1 = vanilla_torch_model.linear1 # self.model.linear2 = vanilla_torch_model.linear2 @@ -314,7 +317,7 @@ class BertGen(ViewGen): At inference time, the model returns the network internal state at the last original layer (i.e. 12th). Document embeddings are the state associated with the "start" token. Training metrics are logged via TensorBoard. """ - def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, stored_path=None): + def __init__(self, multilingualIndex, batch_size=128, nepochs=50, gpus=0, n_jobs=-1, patience=5, stored_path=None): """ Init Bert model :param multilingualIndex: MultilingualIndex, it is a dictionary of training and test documents @@ -322,6 +325,7 @@ class BertGen(ViewGen): :param batch_size: int, number of samples per batch. :param nepochs: int, number of max epochs to train the model. :param gpus: int, specifies how many GPUs to use per node. If False computation will take place on cpu. + :param patience: int, number of epochs with no improvements in val-macroF1 before early stopping. :param n_jobs: int, number of concurrent workers. :param stored_path: str, path to a pretrained model. If None the model will be trained from scratch. """ @@ -333,9 +337,10 @@ class BertGen(ViewGen): self.n_jobs = n_jobs self.stored_path = stored_path self.model = self._init_model() + self.patience = patience self.logger = TensorBoardLogger(save_dir='../tb_logs', name='bert', default_hp_metric=False) self.early_stop_callback = EarlyStopping(monitor='val-macroF1', min_delta=0.00, - patience=5, verbose=False, mode='max') + patience=self.patience, verbose=False, mode='max') def _init_model(self): output_size = self.multilingualIndex.get_target_dim() @@ -351,6 +356,7 @@ class BertGen(ViewGen): :return: self. """ print('# Fitting BertGen (M)...') + create_if_not_exist('../tb_logs') self.multilingualIndex.train_val_split(val_prop=0.2, max_val=2000, seed=1) bertDataModule = BertDataModule(self.multilingualIndex, batchsize=self.batch_size, max_len=512) trainer = Trainer(gradient_clip_val=1e-1, max_epochs=self.nepochs, gpus=self.gpus,