# gFun/refactor/util/pl_metrics.py

import torch
from pytorch_lightning.metrics import Metric  # moved to the torchmetrics package in later releases
from util.common import is_false, is_true


def _update(pred, target, device):
    assert pred.shape == target.shape
    # prepare preds and targets for counting: is_true/is_false are expected to
    # return {0, 1} masks of the same shape as their input
    true_pred = is_true(pred, device)
    false_pred = is_false(pred, device)
    true_target = is_true(target, device)
    false_target = is_false(target, device)
    # per-class counts, summed over the batch dimension
    tp = torch.sum(true_pred * true_target, dim=0)
    tn = torch.sum(false_pred * false_target, dim=0)
    fp = torch.sum(true_pred * false_target, dim=0)
    fn = torch.sum(false_pred * true_target, dim=0)
    return tp, tn, fp, fn
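
# A minimal sketch of what _update counts, assuming is_true/is_false return
# {0, 1} float masks (these helpers live in util.common and are not shown here):
#
#   pred   = torch.tensor([[1., 0.], [1., 1.]])
#   target = torch.tensor([[1., 1.], [0., 1.]])
#   # summing over dim=0 yields per-class counts:
#   # tp = [1, 1], tn = [0, 0], fp = [1, 0], fn = [0, 1]
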
class CustomF1(Metric):
    def __init__(self, num_classes, device, average='micro'):
        """
        Custom F1 metric.
        scikit-learn provides a full set of evaluation metrics, but it treats the degenerate
        case differently: when true positives, false positives, and false negatives all
        amount to 0, every affected metric (precision, recall, and thus F1) outputs 0.
        We adhere to the common practice of outputting 1 in this case, since the classifier
        has correctly classified all examples as negatives.
        :param num_classes: number of target classes
        :param device: truthy to run on cuda, falsy to run on cpu
        :param average: 'micro' or 'macro'
        """
        super().__init__()
        self.num_classes = num_classes
        self.average = average
        self.device = 'cuda' if device else 'cpu'
        self.add_state('true_positive', default=torch.zeros(self.num_classes))
        self.add_state('true_negative', default=torch.zeros(self.num_classes))
        self.add_state('false_positive', default=torch.zeros(self.num_classes))
        self.add_state('false_negative', default=torch.zeros(self.num_classes))

    def update(self, preds, target):
        true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device)
        self.true_positive += true_positive
        self.true_negative += true_negative
        self.false_positive += false_positive
        self.false_negative += false_negative

    def compute(self):
        if self.average == 'micro':
            num = 2.0 * self.true_positive.sum()
            den = 2.0 * self.true_positive.sum() + self.false_positive.sum() + self.false_negative.sum()
            if den > 0:
                return (num / den).to(self.device)
            # degenerate case: no positives predicted and none expected
            return torch.FloatTensor([1.]).to(self.device)
        if self.average == 'macro':
            class_specific = []
            for i in range(self.num_classes):
                class_tp = self.true_positive[i]
                class_fp = self.false_positive[i]
                class_fn = self.false_negative[i]
                num = 2.0 * class_tp
                den = 2.0 * class_tp + class_fp + class_fn
                if den > 0:
                    class_specific.append(num / den)
                else:
                    class_specific.append(1.)
            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
            return average.to(self.device)
        raise ValueError(f'Unsupported average: {self.average}')
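
# Hedged usage sketch (toy values; `preds` and `target` are hypothetical binary
# indicator matrices of shape [batch, num_classes]):
#
#   f1 = CustomF1(num_classes=2, device=False, average='macro')  # cpu
#   f1.update(preds=torch.tensor([[1., 0.]]), target=torch.tensor([[1., 0.]]))
#   f1.compute()  # -> 1.0: class 0 is matched exactly; class 1 has
#                 #    tp = fp = fn = 0 and so scores 1 by the convention above
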
class CustomK(Metric):
    def __init__(self, num_classes, device, average='micro'):
        """
        K metric, as defined in https://dl.acm.org/doi/10.1145/2808194.2809449
        (K = specificity + recall - 1, with the degenerate cases handled explicitly).
        :param num_classes: number of target classes
        :param device: truthy to run on cuda, falsy to run on cpu
        :param average: 'micro' or 'macro'
        """
        super().__init__()
        self.num_classes = num_classes
        self.average = average
        self.device = 'cuda' if device else 'cpu'
        self.add_state('true_positive', default=torch.zeros(self.num_classes))
        self.add_state('true_negative', default=torch.zeros(self.num_classes))
        self.add_state('false_positive', default=torch.zeros(self.num_classes))
        self.add_state('false_negative', default=torch.zeros(self.num_classes))

    def update(self, preds, target):
        true_positive, true_negative, false_positive, false_negative = _update(preds, target, self.device)
        self.true_positive += true_positive
        self.true_negative += true_negative
        self.false_positive += false_positive
        self.false_negative += false_negative

    def compute(self):
        if self.average == 'micro':
            specificity, recall = 0., 0.
            absolute_negatives = self.true_negative.sum() + self.false_positive.sum()
            if absolute_negatives != 0:
                specificity = self.true_negative.sum() / absolute_negatives
            absolute_positives = self.true_positive.sum() + self.false_negative.sum()
            if absolute_positives != 0:
                recall = self.true_positive.sum() / absolute_positives

            if absolute_positives == 0:
                return 2. * specificity - 1
            elif absolute_negatives == 0:
                return 2. * recall - 1
            else:
                return specificity + recall - 1
        if self.average == 'macro':
            class_specific = []
            for i in range(self.num_classes):
                class_tp = self.true_positive[i]
                class_tn = self.true_negative[i]
                class_fp = self.false_positive[i]
                class_fn = self.false_negative[i]
                specificity, recall = 0., 0.
                absolute_negatives = class_tn + class_fp
                if absolute_negatives != 0:
                    specificity = class_tn / absolute_negatives
                absolute_positives = class_tp + class_fn
                if absolute_positives != 0:
                    recall = class_tp / absolute_positives

                if absolute_positives == 0:
                    class_specific.append(2. * specificity - 1)
                elif absolute_negatives == 0:
                    class_specific.append(2. * recall - 1)
                else:
                    class_specific.append(specificity + recall - 1)
            average = torch.sum(torch.Tensor(class_specific)) / self.num_classes
            return average.to(self.device)
        raise ValueError(f'Unsupported average: {self.average}')
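
# Hedged usage sketch for CustomK, mirroring the F1 example above:
#
#   k = CustomK(num_classes=2, device=False, average='micro')
#   k.update(preds=torch.tensor([[1., 0.]]), target=torch.tensor([[1., 0.]]))
#   k.compute()  # -> 1.0: specificity = recall = 1, so K = 1 + 1 - 1 = 1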