QuAcc/baselines/atc.py

45 lines
965 B
Python
Raw Normal View History

2023-11-05 00:15:40 +01:00
import numpy as np
2023-10-19 02:34:41 +02:00
from sklearn.metrics import f1_score
2023-09-14 01:52:19 +02:00
2023-11-05 00:15:40 +01:00
def get_entropy(probs):
    """Return the row-wise sum of ``p * log(p)`` (i.e. the *negative* entropy).

    Values closer to zero correspond to more peaked rows, so the result can
    be used directly as a "higher is more confident" score.

    Args:
        probs: Array with at least two dimensions; the reduction runs over
            axis 1 (presumably the class axis -- confirm against callers).

    Returns:
        Array of ``sum(p * log(p + 1e-20))`` taken along axis 1.
    """
    # The tiny offset keeps log() finite when a probability is exactly 0.
    log_probs = np.log(probs + 1e-20)
    weighted = np.multiply(probs, log_probs)
    return np.sum(weighted, axis=1)
2023-09-14 01:52:19 +02:00
def get_max_conf(probs):
    """Return the largest value along the last axis of ``probs``.

    Args:
        probs: Array-like of scores; the last axis is reduced (presumably
            the class axis -- confirm against callers).

    Returns:
        The maximum taken over the last axis.
    """
    top_conf = np.amax(probs, axis=-1)
    return top_conf
def find_ATC_threshold(scores, labels):
    """Find the score threshold that best balances the two error counts.

    Samples are visited in ascending score order. Before any sample is
    visited, ``fp`` is the number of samples whose label equals 0 and ``fn``
    is 0. Visiting a sample with label 0 decrements ``fp``; visiting any
    other sample increments ``fn``. The threshold returned is the score of
    the first visited sample at which ``|fp - fn|`` reaches its smallest
    value, provided that value is strictly smaller than the initial gap.

    Args:
        scores: 1-D array-like of confidence scores.
        labels: 1-D array-like aligned with ``scores``; entries equal to 0
            are counted in ``fp`` (presumably 0 marks a wrong prediction and
            non-zero a correct one -- confirm against callers).

    Returns:
        Tuple ``(min_fp_fn, thres)``: the smallest ``|fp - fn|`` found and
        the matching score. ``thres`` stays at ``0.0`` when no visited sample
        improves on the initial gap (e.g. empty input, or no label equals 0).
    """
    # Accept any array-like (lists, tuples, ...), not only ndarrays: the
    # fancy indexing below requires real arrays.
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    order = np.argsort(scores)
    sorted_scores = scores[order]
    is_zero = labels[order] == 0

    initial_fp = np.sum(labels == 0)
    min_fp_fn = np.abs(initial_fp - 0.0)
    thres = 0.0
    if is_zero.size == 0:
        return min_fp_fn, thres

    # Running counts after each sample has been visited.
    fp = initial_fp - np.cumsum(is_zero)
    fn = np.cumsum(~is_zero).astype(float)
    gaps = np.abs(fp - fn)

    # argmin returns the first occurrence of the minimum, which is exactly
    # the sample a sequential scan with a strict "<" comparison would keep.
    best = int(np.argmin(gaps))
    if gaps[best] < min_fp_fn:
        min_fp_fn = gaps[best]
        thres = sorted_scores[best]
    return min_fp_fn, thres
2023-11-05 00:15:40 +01:00
def get_ATC_acc(thres, scores):
    """Return the fraction of ``scores`` that are at or above ``thres``.

    Args:
        thres: Score threshold.
        scores: Array of confidence scores.

    Returns:
        Mean of the boolean mask ``scores >= thres``.
    """
    at_or_above = scores >= thres
    return np.mean(at_or_above)
2023-10-19 02:34:41 +02:00
def get_ATC_f1(thres, scores, probs):
    """Compute an F1 score between threshold-derived labels and predictions.

    Predictions are the argmax of ``probs`` over its last axis. Where a score
    is at or above ``thres`` the prediction is kept as the estimated label;
    elsewhere the prediction's lowest bit is flipped.

    Args:
        thres: Score threshold.
        scores: Array of confidence scores aligned with the rows of ``probs``.
        probs: Array of per-class scores; argmax is taken over the last axis.

    Returns:
        ``f1_score(estim_y, preds)``, with the estimated labels passed as the
        first argument and the predictions as the second.
    """
    preds = np.argmax(probs, axis=-1)
    confident = scores >= thres
    # NOTE(review): flipping the lowest bit yields "the other class" only for
    # 0/1 predictions -- presumably this is meant for binary problems; confirm.
    estim_y = np.abs(np.where(confident, preds, preds ^ 1))
    return f1_score(estim_y, preds)