From 37392c654521a847a01af9d71b4502203662f845 Mon Sep 17 00:00:00 2001 From: Lorenzo Volpi Date: Thu, 14 Sep 2023 01:52:19 +0200 Subject: [PATCH] ATC baseline added, rcv1 dataset added --- .coverage | Bin 53248 -> 53248 bytes garg22_ATC/ATC_helper.py | 34 +++++++++ .../__pycache__/ATC_helper.cpython-311.pyc | Bin 0 -> 1824 bytes quacc/baseline.py | 71 +++++++++++++++++- quacc/dataset.py | 36 ++++++++- quacc/main.py | 6 +- tests/test_baseline.py | 4 +- 7 files changed, 138 insertions(+), 13 deletions(-) create mode 100644 garg22_ATC/ATC_helper.py create mode 100644 garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc diff --git a/.coverage b/.coverage index 7ce4189d0f7865664c1e56671a5955c6966ec45e..e04d9e2fa668df2b66fdf5e3c682ff32f7a9abf9 100644 GIT binary patch delta 103 zcmZozz}&Eac>`Mm&t?YxpZu@+AM&5%-@I8+U?IPo2rCOCrxxqC(+0^L3``CQ0t^le z3=U!}Kv6xGDuxE*WOoJzh7Sx3{0tQgavaP+aUte?ERH-(K$a9!oxnkM28No=Kl|Ao E0D{OD`2YX_ delta 97 zcmZozz}&Eac>`Mm&rSyZpZp*AU+`bx-?>>(U^%~oC@Tvirw;44(+0^L3``CW#2Fke z_FrceV*v{5u~acMl&4?6|9=thres)*100.0 diff --git a/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc b/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b2f09d08c4dafe16fc86de164b96ec6081d0d56 GIT binary patch literal 1824 zcmb6aO>Y}j@V))YI&0TSYG_GNVq=ov0~G`j(gr9P2UG;Pv>-JyGLCocnAp4Ktxc0? zjeOwXNvbLN!CEO&B$PvLks=|~3(VWtZk#HIj=Zz;W@qNj zxBu|@1`)vb|1?sUScLwjhb}N3;&=*(`$$GI&Y>kNGryoEMrHxCG6$HGdBD7kZ=#5B zNR<{-xT8y79(ceuI{q2ZeN;U4?S~%S$L-#XjDBfq#RUU=*#PPx=s)V&cnErrdb%?6 znC<6)f81gw9NqN(y7hGakm~JVhfA(y^$6!MYQbUkQr;2srCc#v$gMkUPRoDM0QP7^`n!NCXiprh3HO+$WG?(Z^V{s7lYdR_yxj`V+2Og? z&{ccr>XVQq##&;`7GoBU9fK0;%QFWIH_umqI4{pMbB~!9^1?12{s%yKP>nET0A3E8 zPp%OOsFTQvk21Y90{DufokhL#kiSH)CMc^@B{j zTto(Y+M}{qyJZX*Vw!PBh0wHRx&ac}UlHM?HwK`?b&V8>2uz1fF6$0Yl4?eQnNifp zpd;uh7!TbMa>->Sr#pjGT9Ff3dCl>4=#EJD*pXz?1&2wijwhd06X`+%NWQp2ZQ-m= zw^gtQCoMUW&Z;sUwuCFTqUB@~202V5;(U zyGI>(q}r=>!Ni-MmS@!Vj8>v;kGB#%7#e8`bz(*vcN+etzjCAP^;Z`5d?WYXtN+xP zZ_c-T6Si;S2#Nfpi>kOST{eepY0Ug_R|;BEurCfSZYxbwij<*(jo2SVE+JMd<8;op5JP**Hz%;LsgoT=XM z&q2WF&J;G#DxJPFoej$rxKMLuu~Y8bpLDkXmsgT1973Q)X!Gi{D}uq3+`9{nraqZU z5&Azt#q=>ikaP&~gZixb{^mz5DP&8b?a7ujZE@4C_YuM2)lwezAYC@@GIZC8TRPX& z4sr$Ei=K%e<~6yLQ$8jU5T~nHUjlT*VvO5p#OnTzIE)wY5juqhjG;U|B)o!OKSDk9 N4#scMuV;~a_y0o_djJ3c literal 0 HcmV?d00001 diff --git a/quacc/baseline.py b/quacc/baseline.py index e456fcb..2cc95d0 100644 --- a/quacc/baseline.py +++ b/quacc/baseline.py @@ -1,14 +1,77 @@ - from statistics import mean from typing import Dict from sklearn.base import BaseEstimator from sklearn.model_selection import cross_validate from quapy.data import LabelledCollection +from garg22_ATC.ATC_helper import ( + find_ATC_threshold, + get_ATC_acc, + get_entropy, + get_max_conf, +) +import numpy as np -def kfcv(c_model: BaseEstimator, train: LabelledCollection) -> Dict: + +def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict: scoring = ["f1_macro"] - scores = cross_validate(c_model, train.X, train.y, scoring=scoring) + scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring) + return {"f1_score": mean(scores["test_f1_macro"])} + + +def ATC_MC( + c_model: BaseEstimator, + validation: LabelledCollection, + test: LabelledCollection, + predict_method="predict_proba", +): + c_model_predict = getattr(c_model, predict_method) + + ## Load ID validation data probs and labels + val_probs, val_labels = c_model_predict(validation.X), validation.y + + ## Load OOD test data probs + test_probs = c_model_predict(test.X) + + ## score function, e.g., negative entropy or argmax confidence + val_scores = get_max_conf(val_probs) + val_preds = np.argmax(val_probs, axis=-1) + + test_scores = get_max_conf(test_probs) + + _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) + ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + return { - "f1_score": mean(scores["test_f1_macro"]) + "true_acc": 100*np.mean(np.argmax(test_probs, axis=-1) == test.y), + "pred_acc": ATC_accuracy } + +def ATC_NE( + c_model: BaseEstimator, + validation: LabelledCollection, + test: LabelledCollection, + predict_method="predict_proba", +): + c_model_predict = getattr(c_model, predict_method) + + ## Load ID validation data probs and labels + val_probs, val_labels = c_model_predict(validation.X), validation.y + + ## Load OOD test data probs + test_probs = c_model_predict(test.X) + + ## score function, e.g., negative entropy or argmax confidence + val_scores = get_entropy(val_probs) + val_preds = np.argmax(val_probs, axis=-1) + + test_scores = get_entropy(test_probs) + + _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) + ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + + return { + "true_acc": 100*np.mean(np.argmax(test_probs, axis=-1) == test.y), + "pred_acc": ATC_accuracy + } + diff --git a/quacc/dataset.py b/quacc/dataset.py index 36d1485..7521e6a 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -1,7 +1,35 @@ +from typing import Tuple +import numpy as np +from quapy.data.base import LabelledCollection import quapy as qp +from sklearn.conftest import fetch_rcv1 -def get_imdb_traintest(): - return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test +TRAIN_VAL_PROP = 0.5 -def get_spambase_traintest(): - return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test \ No newline at end of file + +def get_imdb() -> Tuple[LabelledCollection]: + train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test + train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP) + return train, validation, test + + +def get_spambase(): + train, test = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test + train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP) + return train, validation, test + + +def get_rcv1(sample_size=100): + dataset = fetch_rcv1() + + target_labels = [ + (target, dataset.target[:, ind].toarray().flatten()) + for (ind, target) in enumerate(dataset.target_names) + ] + filtered_target_labels = filter( + lambda _, labels: np.sum(labels) >= sample_size, target_labels + ) + return { + target: LabelledCollection(dataset.data, labels, classes=[0, 1]) + for (target, labels) in filtered_target_labels + } diff --git a/quacc/main.py b/quacc/main.py index ca63d46..1b1dd4b 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -9,7 +9,7 @@ from quacc.estimator import ( MulticlassAccuracyEstimator, ) -from quacc.dataset import get_imdb_traintest +from quacc.dataset import get_imdb qp.environ["SAMPLE_SIZE"] = 100 @@ -20,7 +20,7 @@ dataset_name = "imdb" def estimate_multiclass(): print(dataset_name) - train, test = get_imdb_traintest(dataset_name) + train, validation, test = get_imdb(dataset_name) model = LogisticRegression() @@ -59,7 +59,7 @@ def estimate_multiclass(): def estimate_binary(): print(dataset_name) - train, test = get_imdb_traintest(dataset_name) + train, validation, test = get_imdb(dataset_name) model = LogisticRegression() diff --git a/tests/test_baseline.py b/tests/test_baseline.py index 82b0218..7351497 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -1,12 +1,12 @@ from sklearn.linear_model import LogisticRegression from quacc.baseline import kfcv -from quacc.dataset import get_spambase_traintest +from quacc.dataset import get_spambase class TestBaseline: def test_kfcv(self): - train, _ = get_spambase_traintest() + train, _, _ = get_spambase() c_model = LogisticRegression() assert "f1_score" in kfcv(c_model, train) \ No newline at end of file