Implemented inference functions for bert (cpu and gpu)
This commit is contained in:
parent
6e0b66e13e
commit
ae0ea1e68c
|
|
@ -47,7 +47,7 @@ class BertModel(pl.LightningModule):
|
||||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||||
# y = y.type(torch.cuda.FloatTensor)
|
# y = y.type(torch.cuda.FloatTensor)
|
||||||
y = y.type(torch.FloatTensor)
|
y = y.type(torch.FloatTensor)
|
||||||
y.to('cuda' if self.gpus else 'cpu')
|
y = y.to('cuda' if self.gpus else 'cpu')
|
||||||
logits, _ = self.forward(X)
|
logits, _ = self.forward(X)
|
||||||
loss = self.loss(logits, y)
|
loss = self.loss(logits, y)
|
||||||
# Squashing logits through Sigmoid in order to get confidence score
|
# Squashing logits through Sigmoid in order to get confidence score
|
||||||
|
|
@ -116,7 +116,7 @@ class BertModel(pl.LightningModule):
|
||||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||||
# y = y.type(torch.cuda.FloatTensor)
|
# y = y.type(torch.cuda.FloatTensor)
|
||||||
y = y.type(torch.FloatTensor)
|
y = y.type(torch.FloatTensor)
|
||||||
y.to('cuda' if self.gpus else 'cpu')
|
y = y.to('cuda' if self.gpus else 'cpu')
|
||||||
logits, _ = self.forward(X)
|
logits, _ = self.forward(X)
|
||||||
loss = self.loss(logits, y)
|
loss = self.loss(logits, y)
|
||||||
predictions = torch.sigmoid(logits) > 0.5
|
predictions = torch.sigmoid(logits) > 0.5
|
||||||
|
|
@ -136,7 +136,7 @@ class BertModel(pl.LightningModule):
|
||||||
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
X = torch.cat(X).view([X[0].shape[0], len(X)])
|
||||||
# y = y.type(torch.cuda.FloatTensor)
|
# y = y.type(torch.cuda.FloatTensor)
|
||||||
y = y.type(torch.FloatTensor)
|
y = y.type(torch.FloatTensor)
|
||||||
y.to('cuda' if self.gpus else 'cpu')
|
y = y.to('cuda' if self.gpus else 'cpu')
|
||||||
logits, _ = self.forward(X)
|
logits, _ = self.forward(X)
|
||||||
loss = self.loss(logits, y)
|
loss = self.loss(logits, y)
|
||||||
# Squashing logits through Sigmoid in order to get confidence score
|
# Squashing logits through Sigmoid in order to get confidence score
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue