f1 updated
This commit is contained in:
parent
e01006e663
commit
5d0ecfda39
|
@ -38,7 +38,7 @@ def get_ATC_acc(thres, scores):
|
|||
return np.mean(scores >= thres)
|
||||
|
||||
|
||||
def get_ATC_f1(thres, scores, probs):
|
||||
def get_ATC_f1(thres, scores, probs, average="binary"):
|
||||
preds = np.argmax(probs, axis=-1)
|
||||
estim_y = np.abs(1 - (scores >= thres) ^ preds)
|
||||
return f1_score(estim_y, preds)
|
||||
return f1_score(estim_y, preds, average=average)
|
||||
|
|
Loading…
Reference in New Issue