Comments fixed
This commit is contained in:
parent
f537ecb5e4
commit
623f6f96f3
|
@ -32,14 +32,11 @@ def atc_mc(
|
|||
|
||||
## score function, e.g., negative entropy or argmax confidence
|
||||
val_scores = atc.get_max_conf(val_probs)
|
||||
#pred_idxv1 #calib_probsv1/probsv1
|
||||
val_preds = np.argmax(val_probs, axis=-1)
|
||||
#pred_probs_new #probs_new
|
||||
test_scores = atc.get_max_conf(test_probs)
|
||||
#pred_probsv1 #labelsv1 #pred_idxv1
|
||||
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
|
||||
#calib_thres_balance #pred_probs_new
|
||||
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
|
||||
|
||||
_, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)
|
||||
atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores)
|
||||
|
||||
return {
|
||||
"true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y),
|
||||
|
@ -106,5 +103,5 @@ def doc_feat(
|
|||
test_scores = np.max(test_probs, axis=-1)
|
||||
val_preds = np.argmax(val_probs, axis=-1)
|
||||
|
||||
v1acc = np.mean(val_preds == val_labels)*100
|
||||
v1acc = np.mean(val_preds == val_labels) * 100
|
||||
return v1acc + doc.get_doc(val_scores, test_scores)
|
||||
|
|
Loading…
Reference in New Issue