diff --git a/refactor/models/pl_bert.py b/refactor/models/pl_bert.py index 7503a47..965690c 100644 --- a/refactor/models/pl_bert.py +++ b/refactor/models/pl_bert.py @@ -47,7 +47,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) # Squashing logits through Sigmoid in order to get confidence score @@ -116,7 +116,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) predictions = torch.sigmoid(logits) > 0.5 @@ -136,7 +136,7 @@ class BertModel(pl.LightningModule): X = torch.cat(X).view([X[0].shape[0], len(X)]) # y = y.type(torch.cuda.FloatTensor) y = y.type(torch.FloatTensor) - y.to('cuda' if self.gpus else 'cpu') + y = y.to('cuda' if self.gpus else 'cpu') logits, _ = self.forward(X) loss = self.loss(logits, y) # Squashing logits through Sigmoid in order to get confidence score