diff --git a/main.py b/main.py index 0650310..cff4887 100644 --- a/main.py +++ b/main.py @@ -41,7 +41,7 @@ def main(args): embedder_list.append(wceEmbedder) if args.gru_embedder: - rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.gru_wce, + rnnEmbedder = RecurrentGen(multilingualIndex, pretrained_embeddings=lMuse, wce=args.rnn_wce, batch_size=args.batch_rnn, nepochs=args.nepochs_rnn, patience=args.patience_rnn, gpus=args.gpus, n_jobs=args.n_jobs) embedder_list.append(rnnEmbedder) @@ -132,7 +132,7 @@ if __name__ == '__main__': default=False) parser.add_argument('-g', '--gru_embedder', dest='gru_embedder', action='store_true', - help='deploy a GRU in order to compute document embeddings', + help='deploy a GRU in order to compute document embeddings (a.k.a., RecurrentGen)', default=False) parser.add_argument('-c', '--c_optimize', dest='optimc', action='store_true', @@ -171,8 +171,8 @@ if __name__ == '__main__': help='Path to the MUSE polylingual word embeddings (default embeddings/)', default='embeddings/') - parser.add_argument('--gru_wce', dest='gru_wce', action='store_true', - help='Deploy WCE embedding as embedding layer of the GRU View Generator', + parser.add_argument('--rnn_wce', dest='rnn_wce', action='store_true', + help='Deploy WCE embedding as embedding layer of the RecurrentGen', default=False) parser.add_argument('--rnn_dir', dest='rnn_dir', type=str, metavar='',