From cb40b71a383c76dc8d9a0c8b66c0f505cd3979bc Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Mon, 7 Jun 2021 12:22:06 +0200 Subject: [PATCH] fixing two problems with parameters: hidden_size and lstm_nlayers --- quapy/classification/neural.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/quapy/classification/neural.py b/quapy/classification/neural.py index afeb649..0e003a2 100644 --- a/quapy/classification/neural.py +++ b/quapy/classification/neural.py @@ -249,7 +249,7 @@ class TextClassifierNet(torch.nn.Module, metaclass=ABCMeta): class LSTMnet(TextClassifierNet): - def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_nlayers=1, + def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_class_nlayers=1, drop_p=0.5): super().__init__() self.vocabulary_size_ = vocabulary_size @@ -258,12 +258,12 @@ class LSTMnet(TextClassifierNet): 'embedding_size': embedding_size, 'hidden_size': hidden_size, 'repr_size': repr_size, - 'lstm_nlayers': lstm_nlayers, + 'lstm_class_nlayers': lstm_class_nlayers, 'drop_p': drop_p } self.word_embedding = torch.nn.Embedding(vocabulary_size, embedding_size) - self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_nlayers, dropout=drop_p, batch_first=True) + self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_class_nlayers, dropout=drop_p, batch_first=True) self.dropout = torch.nn.Dropout(drop_p) self.dim = repr_size @@ -272,8 +272,8 @@ class LSTMnet(TextClassifierNet): def init_hidden(self, set_size): opt = self.hyperparams - var_hidden = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size']) - var_cell = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size']) + var_hidden = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size']) + var_cell = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size']) if next(self.lstm.parameters()).is_cuda: var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda() return var_hidden, var_cell