QuAcc/garg22_ATC/ATC_helper.py

35 lines
791 B
Python
Raw Normal View History

2023-09-14 01:52:19 +02:00
import numpy as np
def get_entropy(probs):
    """Return the per-row *negative* Shannon entropy, ``sum_k p_k * log(p_k)``.

    Despite the name, the sign is not flipped: for probability rows the
    result is <= 0 and gets closer to 0 as a row becomes more concentrated.

    Args:
        probs: array of probabilities with at least two dimensions; the sum
            runs over axis 1 (presumably rows are samples and columns are
            classes -- confirm against callers).

    Returns:
        Array of ``probs``'s shape with axis 1 reduced away.
    """
    # The tiny offset keeps log() finite for zero probabilities, so a 0 entry
    # contributes 0 * log(1e-20) == 0 rather than 0 * -inf == nan.
    log_probs = np.log(probs + 1e-20)
    weighted = np.multiply(probs, log_probs)
    return np.sum(weighted, axis=1)
def get_max_conf(probs):
    """Return the largest entry of ``probs`` along its last axis.

    Presumably ``probs`` holds per-class probabilities (e.g. softmax
    outputs), which makes this the top-class confidence of each sample.
    """
    top = np.max(probs, axis=-1)
    return top
def find_ATC_threshold(scores, labels):
    """Fit the ATC (Average Thresholded Confidence) threshold on labelled data.

    Samples are swept in ascending score order while a threshold is moved
    past them one at a time.  At each position ``fp`` counts label-0 samples
    still above the threshold and ``fn`` counts non-zero-label samples already
    below it.  Since ``#above == #nonzero_labels - fn + fp``, ``fp == fn`` is
    exactly where the fraction of samples above the threshold equals the
    fraction of non-zero labels; the sweep picks the position closest to it.

    Args:
        scores: 1-D array-like of per-sample scores; larger values are treated
            as "more likely correct".
        labels: array-like of the same length; 0 marks an incorrect sample,
            any other value a correct one.

    Returns:
        ``(min_fp_fn, thres)``: the smallest ``|fp - fn|`` found (as a float)
        and the score of the last sample moved below the threshold when that
        minimum was first reached.  If no move improves on the starting point
        (threshold below every score) -- e.g. every label is non-zero, or the
        input is empty -- ``thres`` is ``-inf``, so ``scores >= thres`` keeps
        every sample above the threshold whatever the sign of the scores.

    Raises:
        ValueError: if ``scores`` and ``labels`` have different lengths.

    NOTE(review): ``get_ATC_acc`` uses ``scores >= thres``, which still counts
    the sample whose score *is* ``thres`` (and any ties), i.e. one sample more
    than the balanced split on this same data.  Confirm whether that offset is
    intended before changing either function.  Ties are ordered by
    ``np.argsort``'s default (non-stable) sort, so the threshold may be
    recorded part-way through a run of equal scores.
    """
    scores = np.asarray(scores)
    labels = np.asarray(labels)
    if len(scores) != len(labels):
        raise ValueError(
            f"scores and labels must have the same length, "
            f"got {len(scores)} and {len(labels)}"
        )

    sorted_idx = np.argsort(scores)
    sorted_scores = scores[sorted_idx]
    is_incorrect = labels[sorted_idx] == 0

    # Before the sweep the threshold is below every score: every label-0
    # sample is a false positive and there are no false negatives.
    fp_start = np.sum(labels == 0)
    # Counts after sorted sample i has been moved below the threshold.
    fp = fp_start - np.cumsum(is_incorrect)
    fn = np.cumsum(~is_incorrect)

    # gaps[0] is the starting state; gaps[i + 1] follows sorted sample i.
    gaps = np.abs(np.concatenate(([fp_start], fp - fn))).astype(np.float64)
    # np.argmin returns the FIRST minimum, matching a sequential sweep that
    # only accepts strict improvements (later ties never replace it).
    best = int(np.argmin(gaps))
    thres = sorted_scores[best - 1] if best > 0 else -np.inf
    return gaps[best], thres
def get_ATC_acc(thres, scores):
    """Return the percentage (0-100) of ``scores`` at or above ``thres``.

    Args:
        thres: threshold to compare against, e.g. the second element returned
            by ``find_ATC_threshold``.
        scores: NumPy array of per-sample scores (a plain list does not
            support ``>=`` against a number).

    Returns:
        ``100 * mean(scores >= thres)``.  The comparison is inclusive, so a
        score exactly equal to ``thres`` counts as above it.
    """
    at_or_above = scores >= thres
    return 100.0 * np.mean(at_or_above)