diff --git a/.coverage b/.coverage index 7ce4189..e04d9e2 100644 Binary files a/.coverage and b/.coverage differ diff --git a/garg22_ATC/ATC_helper.py b/garg22_ATC/ATC_helper.py new file mode 100644 index 0000000..5b4dfad --- /dev/null +++ b/garg22_ATC/ATC_helper.py @@ -0,0 +1,34 @@ +import numpy as np + +def get_entropy(probs): + return np.sum( np.multiply(probs, np.log(probs + 1e-20)) , axis=1) + +def get_max_conf(probs): + return np.max(probs, axis=-1) + +def find_ATC_threshold(scores, labels): + sorted_idx = np.argsort(scores) + + sorted_scores = scores[sorted_idx] + sorted_labels = labels[sorted_idx] + + fp = np.sum(labels==0) + fn = 0.0 + + min_fp_fn = np.abs(fp - fn) + thres = 0.0 + for i in range(len(labels)): + if sorted_labels[i] == 0: + fp -= 1 + else: + fn += 1 + + if np.abs(fp - fn) < min_fp_fn: + min_fp_fn = np.abs(fp - fn) + thres = sorted_scores[i] + + return min_fp_fn, thres + + +def get_ATC_acc(thres, scores): + return np.mean(scores>=thres)*100.0 diff --git a/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc b/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc new file mode 100644 index 0000000..5b2f09d Binary files /dev/null and b/garg22_ATC/__pycache__/ATC_helper.cpython-311.pyc differ diff --git a/quacc/baseline.py b/quacc/baseline.py index e456fcb..2cc95d0 100644 --- a/quacc/baseline.py +++ b/quacc/baseline.py @@ -1,14 +1,77 @@ - from statistics import mean from typing import Dict from sklearn.base import BaseEstimator from sklearn.model_selection import cross_validate from quapy.data import LabelledCollection +from garg22_ATC.ATC_helper import ( + find_ATC_threshold, + get_ATC_acc, + get_entropy, + get_max_conf, +) +import numpy as np -def kfcv(c_model: BaseEstimator, train: LabelledCollection) -> Dict: + +def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict: scoring = ["f1_macro"] - scores = cross_validate(c_model, train.X, train.y, scoring=scoring) + scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring) + return {"f1_score": mean(scores["test_f1_macro"])} + + +def ATC_MC( + c_model: BaseEstimator, + validation: LabelledCollection, + test: LabelledCollection, + predict_method="predict_proba", +): + c_model_predict = getattr(c_model, predict_method) + + ## Load ID validation data probs and labels + val_probs, val_labels = c_model_predict(validation.X), validation.y + + ## Load OOD test data probs + test_probs = c_model_predict(test.X) + + ## score function, e.g., negative entropy or argmax confidence + val_scores = get_max_conf(val_probs) + val_preds = np.argmax(val_probs, axis=-1) + + test_scores = get_max_conf(test_probs) + + _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) + ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + return { - "f1_score": mean(scores["test_f1_macro"]) + "true_acc": 100*np.mean(np.argmax(test_probs, axis=-1) == test.y), + "pred_acc": ATC_accuracy } + +def ATC_NE( + c_model: BaseEstimator, + validation: LabelledCollection, + test: LabelledCollection, + predict_method="predict_proba", +): + c_model_predict = getattr(c_model, predict_method) + + ## Load ID validation data probs and labels + val_probs, val_labels = c_model_predict(validation.X), validation.y + + ## Load OOD test data probs + test_probs = c_model_predict(test.X) + + ## score function, e.g., negative entropy or argmax confidence + val_scores = get_entropy(val_probs) + val_preds = np.argmax(val_probs, axis=-1) + + test_scores = get_entropy(test_probs) + + _, ATC_thres = find_ATC_threshold(val_scores, val_labels == val_preds) + ATC_accuracy = get_ATC_acc(ATC_thres, test_scores) + + return { + "true_acc": 100*np.mean(np.argmax(test_probs, axis=-1) == test.y), + "pred_acc": ATC_accuracy + } + diff --git a/quacc/dataset.py b/quacc/dataset.py index 36d1485..7521e6a 100644 --- a/quacc/dataset.py +++ b/quacc/dataset.py @@ -1,7 +1,35 @@ +from typing import Tuple +import numpy as np +from quapy.data.base import LabelledCollection import quapy as qp +from sklearn.conftest import fetch_rcv1 -def get_imdb_traintest(): - return qp.datasets.fetch_reviews("imdb", tfidf=True).train_test +TRAIN_VAL_PROP = 0.5 -def get_spambase_traintest(): - return qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test \ No newline at end of file + +def get_imdb() -> Tuple[LabelledCollection]: + train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test + train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP) + return train, validation, test + + +def get_spambase(): + train, test = qp.datasets.fetch_UCIDataset("spambase", verbose=False).train_test + train, validation = train.split_stratified(train_prop=TRAIN_VAL_PROP) + return train, validation, test + + +def get_rcv1(sample_size=100): + dataset = fetch_rcv1() + + target_labels = [ + (target, dataset.target[:, ind].toarray().flatten()) + for (ind, target) in enumerate(dataset.target_names) + ] + filtered_target_labels = filter( + lambda _, labels: np.sum(labels) >= sample_size, target_labels + ) + return { + target: LabelledCollection(dataset.data, labels, classes=[0, 1]) + for (target, labels) in filtered_target_labels + } diff --git a/quacc/main.py b/quacc/main.py index ca63d46..1b1dd4b 100644 --- a/quacc/main.py +++ b/quacc/main.py @@ -9,7 +9,7 @@ from quacc.estimator import ( MulticlassAccuracyEstimator, ) -from quacc.dataset import get_imdb_traintest +from quacc.dataset import get_imdb qp.environ["SAMPLE_SIZE"] = 100 @@ -20,7 +20,7 @@ dataset_name = "imdb" def estimate_multiclass(): print(dataset_name) - train, test = get_imdb_traintest(dataset_name) + train, validation, test = get_imdb(dataset_name) model = LogisticRegression() @@ -59,7 +59,7 @@ def estimate_multiclass(): def estimate_binary(): print(dataset_name) - train, test = get_imdb_traintest(dataset_name) + train, validation, test = get_imdb(dataset_name) model = LogisticRegression() diff --git a/tests/test_baseline.py b/tests/test_baseline.py index 82b0218..7351497 100644 --- a/tests/test_baseline.py +++ b/tests/test_baseline.py @@ -1,12 +1,12 @@ from sklearn.linear_model import LogisticRegression from quacc.baseline import kfcv -from quacc.dataset import get_spambase_traintest +from quacc.dataset import get_spambase class TestBaseline: def test_kfcv(self): - train, _ = get_spambase_traintest() + train, _, _ = get_spambase() c_model = LogisticRegression() assert "f1_score" in kfcv(c_model, train) \ No newline at end of file