Comments fixed

This commit is contained in:
Lorenzo Volpi 2023-09-18 09:24:20 +02:00
parent f537ecb5e4
commit 623f6f96f3
1 changed files with 4 additions and 7 deletions

View File

@ -32,14 +32,11 @@ def atc_mc(
## score function, e.g., negative entropy or argmax confidence
val_scores = atc.get_max_conf(val_probs)
#pred_idxv1 #calib_probsv1/probsv1
val_preds = np.argmax(val_probs, axis=-1)
#pred_probs_new #probs_new
test_scores = atc.get_max_conf(test_probs)
#pred_probsv1 #labelsv1 #pred_idxv1
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
#calib_thres_balance #pred_probs_new
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
return {
"true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
@ -106,5 +103,5 @@ def doc_feat(
test_scores = np.max(test_probs, axis=-1)
val_preds = np.argmax(val_probs, axis=-1)
v1acc = np.mean(val_preds == val_labels)*100
v1acc = np.mean(val_preds == val_labels) * 100
return v1acc + doc.get_doc(val_scores, test_scores)