f1 updated

This commit is contained in:
Lorenzo Volpi 2023-12-21 16:47:20 +01:00
parent e01006e663
commit 5d0ecfda39
1 changed files with 2 additions and 2 deletions

View File

@ -38,7 +38,7 @@ def get_ATC_acc(thres, scores):
return np.mean(scores >= thres)
def get_ATC_f1(thres, scores, probs):
def get_ATC_f1(thres, scores, probs, average="binary"):
preds = np.argmax(probs, axis=-1)
estim_y = np.abs(1 - (scores >= thres) ^ preds)
return f1_score(estim_y, preds)
return f1_score(estim_y, preds, average=average)