This commit is contained in:
Alejandro Moreo Fernandez 2021-06-21 11:13:38 +02:00
commit f33abb5319
1 changed files with 5 additions and 5 deletions

View File

@ -249,7 +249,7 @@ class TextClassifierNet(torch.nn.Module, metaclass=ABCMeta):
class LSTMnet(TextClassifierNet):
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_nlayers=1,
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_class_nlayers=1,
drop_p=0.5):
super().__init__()
self.vocabulary_size_ = vocabulary_size
@ -258,12 +258,12 @@ class LSTMnet(TextClassifierNet):
'embedding_size': embedding_size,
'hidden_size': hidden_size,
'repr_size': repr_size,
'lstm_nlayers': lstm_nlayers,
'lstm_class_nlayers': lstm_class_nlayers,
'drop_p': drop_p
}
self.word_embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_nlayers, dropout=drop_p, batch_first=True)
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_class_nlayers, dropout=drop_p, batch_first=True)
self.dropout = torch.nn.Dropout(drop_p)
self.dim = repr_size
@ -272,8 +272,8 @@ class LSTMnet(TextClassifierNet):
def init_hidden(self, set_size):
opt = self.hyperparams
var_hidden = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
var_cell = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
var_hidden = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
var_cell = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
if next(self.lstm.parameters()).is_cuda:
var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
return var_hidden, var_cell