Merge branch 'master' of https://github.com/HLT-ISTI/QuaPy
This commit is contained in:
commit
f33abb5319
|
@ -249,7 +249,7 @@ class TextClassifierNet(torch.nn.Module, metaclass=ABCMeta):
|
|||
|
||||
class LSTMnet(TextClassifierNet):
|
||||
|
||||
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_nlayers=1,
|
||||
def __init__(self, vocabulary_size, n_classes, embedding_size=100, hidden_size=256, repr_size=100, lstm_class_nlayers=1,
|
||||
drop_p=0.5):
|
||||
super().__init__()
|
||||
self.vocabulary_size_ = vocabulary_size
|
||||
|
@ -258,12 +258,12 @@ class LSTMnet(TextClassifierNet):
|
|||
'embedding_size': embedding_size,
|
||||
'hidden_size': hidden_size,
|
||||
'repr_size': repr_size,
|
||||
'lstm_nlayers': lstm_nlayers,
|
||||
'lstm_class_nlayers': lstm_class_nlayers,
|
||||
'drop_p': drop_p
|
||||
}
|
||||
|
||||
self.word_embedding = torch.nn.Embedding(vocabulary_size, embedding_size)
|
||||
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_nlayers, dropout=drop_p, batch_first=True)
|
||||
self.lstm = torch.nn.LSTM(embedding_size, hidden_size, lstm_class_nlayers, dropout=drop_p, batch_first=True)
|
||||
self.dropout = torch.nn.Dropout(drop_p)
|
||||
|
||||
self.dim = repr_size
|
||||
|
@ -272,8 +272,8 @@ class LSTMnet(TextClassifierNet):
|
|||
|
||||
def init_hidden(self, set_size):
|
||||
opt = self.hyperparams
|
||||
var_hidden = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
|
||||
var_cell = torch.zeros(opt['lstm_nlayers'], set_size, opt['lstm_hidden_size'])
|
||||
var_hidden = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
|
||||
var_cell = torch.zeros(opt['lstm_class_nlayers'], set_size, opt['hidden_size'])
|
||||
if next(self.lstm.parameters()).is_cuda:
|
||||
var_hidden, var_cell = var_hidden.cuda(), var_cell.cuda()
|
||||
return var_hidden, var_cell
|
||||
|
|
Loading…
Reference in New Issue