From ae0ea1e68c56fe4e37466b90905c493277a0e926 Mon Sep 17 00:00:00 2001 From: andrea Date: Mon, 25 Jan 2021 12:48:02 +0100 Subject: [PATCH] Implemented inference functions for bert (cpu and gpu) --- refactor/models/pl_bert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/refactor/models/pl_bert.py b/refactor/models/pl_bert.py index 7503a47..965690c 100644 --- a/refactor/models/pl_bert.py +++ b/refactor/models/pl_bert.py @@ -47,7 +47,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) # Squashing logits through Sigmoid in order to get confidence score @@ -116,7 +116,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) predictions = torch.sigmoid(logits) > 0.5 @@ -136,7 +136,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) # Squashing logits through Sigmoid in order to get confidence score