From 623f6f96f332f2ad32db90dd43316880d8f5420c Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Mon, 18 Sep 2023 09:24:20 +0200 Subject: [PATCH] Comments fixed --- quacc/baseline.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/quacc/baseline.py b/quacc/baseline.py index c508dce..c4aadee 100644 --- a/quacc/baseline.py +++ b/quacc/baseline.py @@ -32,14 +32,11 @@ def atc_mc( ## score function, e.g., negative entropy or argmax confidence val_scores = atc.get_max_conf(val_probs) - #pred_idxv1 #calib_probsv1/probsv1 val_preds = np.argmax(val_probs, axis=-1) - #pred_probs_new #probs_new test_scores = atc.get_max_conf(test_probs) - #pred_probsv1 #labelsv1 #pred_idxv1 - _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds) - #calib_thres_balance #pred_probs_new - atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) + + _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds) + atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) return { "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y), @@ -106,5 +103,5 @@ def doc_feat( test_scores = np.max(test_probs, axis=-1) val_preds = np.argmax(val_probs, axis=-1) - v1acc = np.mean(val_preds == val_labels)*100 + v1acc = np.mean(val_preds == val_labels) * 100 return v1acc + doc.get_doc(val_scores, test_scores)