From f537ecb5e4cc3accdeee5f4d116b14c5c34233c7 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Sun, 17 Sep 2023 21:47:34 +0200 Subject: [PATCH] guillory21 imported as baseline --- .coverage | Bin 53248 -> 53248 bytes .../__pycache__/doc.cpython-311.pyc | Bin 0 -> 437 bytes guillory21_doc/doc.py | 4 ++ quacc/baseline.py | 60 +++++++++++------- 4 files changed, 42 insertions(+), 22 deletions(-) create mode 100644 guillory21_doc/__pycache__/doc.cpython-311.pyc create mode 100644 guillory21_doc/doc.py diff --git a/.coverage b/.coverage index c9d78c745be22ea3acbe8ab6ae1c7c49ce1dcae9..c3df57d9673c22a120c1172d0b9d7342ff9a96c3 100644 GIT binary patch delta 112 zcmZozz}&Ead4e<}=R_H2M$U~1tLs^x^55n^yjh@OF29cmD+?p17VEdu2FV-@Ob!VG z3=Rwo4q_}oQ9YI_h6dwgcLoNA4-5?a3>6G=9Lzv*A?AH7jzUa8mKxJP^&dd#8U}_2 M{s#WdKl|Ao08gSCZ~y=R delta 112 zcmZozz}&Ead4e<}*F+g-My`zstLs^x@!#P;vRR;EF29c`D+?p14(qqm2FV-@Ob!pk z85}P5UuP9#0SfD}R53J^r(eJSe)Z@3`Tq}BGRSc-1I2}y_p$g3G67jCOmzn5*%=sq PFfcI0>u&zp&+Y&K>3$-u diff --git a/guillory21_doc/__pycache__/doc.cpython-311.pyc b/guillory21_doc/__pycache__/doc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a98676ff1b02d78f029ef0b05b54314e53db98de GIT binary patch literal 437 zcmZ3^%ge<81gW9ysfj@PF^B^LOi;#W5g=naLkdF*V-7ag;>?)z(#)Kk{Gv)D!}yf^ Dict: @@ -19,7 +16,7 @@ def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict: return {"f1_score": mean(scores["test_f1_macro"])} -def ATC_MC( +def atc_mc( c_model: BaseEstimator, validation: LabelledCollection, test: LabelledCollection, @@ -34,21 +31,23 @@ def ATC_MC( test_probs = c_model_predict(test.X) ## score function, e.g., negative entropy or argmax confidence - val_scores = get_max_conf(val_probs) + val_scores = atc.get_max_conf(val_probs) + #pred_idxv1 #calib_probsv1/probsv1 val_preds = np.argmax(val_probs, axis=-1) - - test_scores = get_max_conf(test_probs) - - _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) - ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + #pred_probs_new #probs_new + test_scores = atc.get_max_conf(test_probs) + #pred_probsv1 #labelsv1 #pred_idxv1 + _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds) + #calib_thres_balance #pred_probs_new + atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) return { "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y), - "pred_acc": ATC_accuracy, + "pred_acc": atc_accuracy, } -def ATC_NE( +def atc_ne( c_model: BaseEstimator, validation: LabelledCollection, test: LabelledCollection, @@ -63,17 +62,17 @@ def ATC_NE( test_probs = c_model_predict(test.X) ## score function, e.g., negative entropy or argmax confidence - val_scores = get_entropy(val_probs) + val_scores = atc.get_entropy(val_probs) val_preds = np.argmax(val_probs, axis=-1) - test_scores = get_entropy(test_probs) + test_scores = atc.get_entropy(test_probs) - _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) - ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds) + atc_accuracy = atc.get_ATC_acc(atc_thres, test_scores) return { "true_acc": 100 * np.mean(np.argmax(test_probs, axis=-1) == test.y), - "pred_acc": ATC_accuracy, + "pred_acc": atc_accuracy, } @@ -87,8 +86,25 @@ def trust_score( test_pred = c_model_predict(test.X) - trust_model = TrustScore() + trust_model = trustscore.TrustScore() trust_model.fit(validation.X, validation.y) return trust_model.get_score(test.X, test_pred) + +def doc_feat( + c_model: BaseEstimator, + validation: LabelledCollection, + test: LabelledCollection, + predict_method="predict_proba", +): + c_model_predict = getattr(c_model, predict_method) + + val_probs, val_labels = c_model_predict(validation.X), validation.y + test_probs = c_model_predict(test.X) + val_scores = np.max(val_probs, axis=-1) + test_scores = np.max(test_probs, axis=-1) + val_preds = np.argmax(val_probs, axis=-1) + + v1acc = np.mean(val_preds == val_labels)*100 + return v1acc + doc.get_doc(val_scores, test_scores)