Implemented inference functions for bert (cpu and gpu)

This commit is contained in:
andrea 2021-01-25 12:48:02 +01:00
parent 6e0b66e13e
commit ae0ea1e68c
1 changed files with 3 additions and 3 deletions

View File

@ -47,7 +47,7 @@ class BertModel(pl.LightningModule):
X = torch.cat(X).view([X[0].shape[0], len(X)]) X = torch.cat(X).view([X[0].shape[0], len(X)])
# y = y.type(torch.cuda.FloatTensor) # y = y.type(torch.cuda.FloatTensor)
y = y.type(torch.FloatTensor) y = y.type(torch.FloatTensor)
y.to('cuda' if self.gpus else 'cpu') y = y.to('cuda' if self.gpus else 'cpu')
logits, _ = self.forward(X) logits, _ = self.forward(X)
loss = self.loss(logits, y) loss = self.loss(logits, y)
# Squashing logits through Sigmoid in order to get confidence score # Squashing logits through Sigmoid in order to get confidence score
@ -116,7 +116,7 @@ class BertModel(pl.LightningModule):
X = torch.cat(X).view([X[0].shape[0], len(X)]) X = torch.cat(X).view([X[0].shape[0], len(X)])
# y = y.type(torch.cuda.FloatTensor) # y = y.type(torch.cuda.FloatTensor)
y = y.type(torch.FloatTensor) y = y.type(torch.FloatTensor)
y.to('cuda' if self.gpus else 'cpu') y = y.to('cuda' if self.gpus else 'cpu')
logits, _ = self.forward(X) logits, _ = self.forward(X)
loss = self.loss(logits, y) loss = self.loss(logits, y)
predictions = torch.sigmoid(logits) > 0.5 predictions = torch.sigmoid(logits) > 0.5
@ -136,7 +136,7 @@ class BertModel(pl.LightningModule):
X = torch.cat(X).view([X[0].shape[0], len(X)]) X = torch.cat(X).view([X[0].shape[0], len(X)])
# y = y.type(torch.cuda.FloatTensor) # y = y.type(torch.cuda.FloatTensor)
y = y.type(torch.FloatTensor) y = y.type(torch.FloatTensor)
y.to('cuda' if self.gpus else 'cpu') y = y.to('cuda' if self.gpus else 'cpu')
logits, _ = self.forward(X) logits, _ = self.forward(X)
loss = self.loss(logits, y) loss = self.loss(logits, y)
# Squashing logits through Sigmoid in order to get confidence score # Squashing logits through Sigmoid in order to get confidence score