QuAcc/quacc/baseline.py

166 lines
5.0 KiB
Python
Raw Normal View History

2023-09-13 00:11:20 +02:00
from statistics import mean
2023-09-22 01:40:36 +02:00
from typing import Dict
import numpy as np
import quapy as qp
from quapy.data import LabelledCollection
2023-09-13 00:11:20 +02:00
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate
2023-09-22 01:40:36 +02:00
import elsahar19_rca.rca as rca
2023-09-17 21:47:34 +02:00
import garg22_ATC.ATC_helper as atc
import guillory21_doc.doc as doc
2023-09-22 01:40:36 +02:00
import jiang18_trustscore.trustscore as trustscore
import lipton_bbse.labelshift as bbse
2023-09-13 00:11:20 +02:00
2023-09-14 01:52:19 +02:00
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    """Cross-validate ``c_model`` on the validation set and report macro F1.

    ``cross_validate`` is run on ``validation.X`` / ``validation.y`` with
    ``f1_macro`` as the only scoring metric and no explicit ``cv`` argument.

    Returns:
        A dict with the single key ``"f1_score"`` holding the mean of the
        per-fold ``test_f1_macro`` values.
    """
    cv_results = cross_validate(
        c_model,
        validation.X,
        validation.y,
        scoring=["f1_macro"],
    )
    fold_f1 = cv_results["test_f1_macro"]
    return {"f1_score": mean(fold_f1)}
2023-09-17 21:47:34 +02:00
def atc_mc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC, scoring samples via ``atc.get_max_conf``.

    The classifier method named by ``predict_method`` is applied to the
    (in-distribution) validation instances and to the (out-of-distribution)
    test instances. A threshold is obtained from the validation scores and
    the correctness of the validation predictions, then applied to the test
    scores.

    Returns:
        A dict with two keys:
        ``"true_acc"`` - 100 times the fraction of test samples whose argmax
        output equals ``test.y``;
        ``"pred_acc"`` - the value returned by ``atc.get_ATC_acc``
        (presumably on the same 0-100 scale; confirm against the helper).
    """
    predict = getattr(c_model, predict_method)

    # Validation outputs first, then test outputs.
    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Per-sample score; this variant uses the helper's max-confidence function.
    val_scores = atc.get_max_conf(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    test_scores = atc.get_max_conf(test_probs)

    # A validation prediction counts as correct when it matches its label.
    val_correct = validation.y == val_preds
    _, threshold = atc.find_ATC_threshold(val_scores, val_correct)
    estimated_acc = atc.get_ATC_acc(threshold, test_scores)

    test_preds = np.argmax(test_probs, axis=-1)
    return {
        "true_acc": 100 * np.mean(test_preds == test.y),
        "pred_acc": estimated_acc,
    }
2023-09-14 01:52:19 +02:00
2023-09-16 01:59:49 +02:00
2023-09-17 21:47:34 +02:00
def atc_ne(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC, scoring samples via ``atc.get_entropy``.

    The classifier method named by ``predict_method`` is applied to the
    (in-distribution) validation instances and to the (out-of-distribution)
    test instances. A threshold is obtained from the validation scores and
    the correctness of the validation predictions, then applied to the test
    scores.

    Returns:
        A dict with two keys:
        ``"true_acc"`` - 100 times the fraction of test samples whose argmax
        output equals ``test.y``;
        ``"pred_acc"`` - the value returned by ``atc.get_ATC_acc``
        (presumably on the same 0-100 scale; confirm against the helper).
    """
    predict = getattr(c_model, predict_method)

    # Validation outputs first, then test outputs.
    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Per-sample score; this variant uses the helper's entropy-based function.
    val_scores = atc.get_entropy(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    test_scores = atc.get_entropy(test_probs)

    # A validation prediction counts as correct when it matches its label.
    val_correct = validation.y == val_preds
    _, threshold = atc.find_ATC_threshold(val_scores, val_correct)
    estimated_acc = atc.get_ATC_acc(threshold, test_scores)

    test_preds = np.argmax(test_probs, axis=-1)
    return {
        "true_acc": 100 * np.mean(test_preds == test.y),
        "pred_acc": estimated_acc,
    }
2023-09-16 01:59:49 +02:00
def trust_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
    """Score the classifier's test predictions with a ``TrustScore`` model.

    The classifier method named by ``predict_method`` produces predictions
    for ``test.X``. A fresh ``trustscore.TrustScore`` is fitted on the
    validation instances and labels, and its ``get_score`` result for the
    test instances and their predictions is returned unchanged.
    """
    predicted = getattr(c_model, predict_method)(test.X)

    scorer = trustscore.TrustScore()
    scorer.fit(validation.X, validation.y)
    return scorer.get_score(test.X, predicted)
2023-09-17 21:47:34 +02:00
def doc_feat(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Combine validation accuracy with ``doc.get_doc`` of the confidences.

    The classifier method named by ``predict_method`` is applied to the
    validation and test instances. The per-sample maximum output on each set
    is passed to ``doc.get_doc``, and that value is added to the validation
    accuracy expressed as a percentage.

    NOTE(review): the validation accuracy is on a 0-100 scale; confirm that
    ``doc.get_doc`` returns a value on a compatible scale.
    """
    predict = getattr(c_model, predict_method)
    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Highest output per sample on each split.
    val_conf = np.max(val_probs, axis=-1)
    test_conf = np.max(test_probs, axis=-1)

    # Validation accuracy as a percentage.
    val_hits = np.argmax(val_probs, axis=-1) == validation.y
    val_acc_pct = np.mean(val_hits) * 100

    return val_acc_pct + doc.get_doc(val_conf, test_conf)
2023-09-18 18:19:13 +02:00
def rca_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
    """Compute the RCA score of ``c_model`` against a test-fitted model.

    A second model is obtained from ``rca.clone_fit`` using the test
    instances and the predictions ``c_model`` makes on them. Both models then
    predict the validation instances, and the two prediction sets are passed,
    together with the validation labels, to ``rca.get_score``.
    """
    predict = getattr(c_model, predict_method)

    # Model fitted on the test set labelled with the classifier's own predictions.
    pseudo_labels = predict(test.X)
    surrogate = rca.clone_fit(test.X, pseudo_labels)
    surrogate_predict = getattr(surrogate, predict_method)

    original_val_pred = predict(validation.X)
    surrogate_val_pred = surrogate_predict(validation.X)
    return rca.get_score(original_val_pred, surrogate_val_pred, validation.y)
def rca_star_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
    """Compute the RCA* score using two models fitted on pseudo-labels.

    The validation set is split in half (stratified). One model is obtained
    from ``rca.clone_fit`` on the first half labelled with ``c_model``'s
    predictions, another on the test set labelled the same way. Both models
    predict the second validation half, and the two prediction sets are
    passed, with that half's labels, to ``rca.get_score``.
    """
    predict = getattr(c_model, predict_method)
    val_fit, val_eval = validation.split_stratified(train_prop=0.5)

    # Pseudo-labels produced by the original classifier.
    test_pseudo = predict(test.X)
    val_pseudo = predict(val_fit.X)

    val_model = rca.clone_fit(val_fit.X, val_pseudo)
    test_model = rca.clone_fit(test.X, test_pseudo)
    val_model_predict = getattr(val_model, predict_method)
    test_model_predict = getattr(test_model, predict_method)

    # Both fitted models are compared on the held-out validation half.
    eval_pred_from_val = val_model_predict(val_eval.X)
    eval_pred_from_test = test_model_predict(val_eval.X)
    return rca.get_score(eval_pred_from_val, eval_pred_from_test, val_eval.y)
2023-09-22 01:40:36 +02:00
def bbse_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
    *,
    n_classes: int = 2,
):
    """Return the absolute error of the BBSE-estimated test prevalence.

    The classifier method named by ``predict_method`` is applied to the
    validation and test instances. ``bbse.estimate_labelshift_ratio`` turns
    those outputs and the validation labels into a weight vector, from which
    ``bbse.estimate_target_dist`` derives an estimated prevalence. That
    estimate is compared with ``test.prevalence()`` via ``qp.error.ae``.

    Args:
        c_model: Classifier exposing the method named by ``predict_method``.
        validation: Labelled validation data (``X`` and ``y`` are read).
        test: Labelled test data (``X`` and ``prevalence()`` are read).
        predict_method: Name of the classifier method used to get outputs.
        n_classes: Value forwarded as the last argument of both ``bbse``
            helpers (presumably the number of classes; confirm against the
            helper module). Defaults to 2, the binary setting.

    Returns:
        The result of ``qp.error.ae(true_prev, estim_prev)``.
    """
    predict = getattr(c_model, predict_method)
    val_probs, val_labels = predict(validation.X), validation.y
    test_probs = predict(test.X)

    # The same class count must be given to both helpers so they agree.
    weights = bbse.estimate_labelshift_ratio(
        val_labels, val_probs, test_probs, n_classes
    )
    estim_prev = bbse.estimate_target_dist(weights, val_labels, n_classes)

    true_prev = test.prevalence()
    return qp.error.ae(true_prev, estim_prev)