2023-09-13 00:11:20 +02:00
|
|
|
from statistics import mean
|
|
|
|
from typing import Dict
|
|
|
|
from sklearn.base import BaseEstimator
|
|
|
|
from sklearn.model_selection import cross_validate
|
|
|
|
from quapy.data import LabelledCollection
|
2023-09-17 21:47:34 +02:00
|
|
|
import garg22_ATC.ATC_helper as atc
|
2023-09-14 01:52:19 +02:00
|
|
|
import numpy as np
|
2023-09-17 21:47:34 +02:00
|
|
|
import jiang18_trustscore.trustscore as trustscore
|
|
|
|
import guillory21_doc.doc as doc
|
2023-09-13 00:11:20 +02:00
|
|
|
|
|
|
|
|
2023-09-14 01:52:19 +02:00
|
|
|
def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    """Cross-validate ``c_model`` on the validation collection.

    The estimator is evaluated with ``sklearn.model_selection.cross_validate``
    on ``validation.X`` / ``validation.y`` using the ``f1_macro`` scorer and
    that function's default fold configuration.

    Returns:
        A dict with the single key ``"f1_score"`` holding the mean of the
        per-fold macro-F1 test scores.
    """
    cv_results = cross_validate(
        c_model,
        validation.X,
        validation.y,
        scoring=["f1_macro"],
    )
    # With a list of scorers, results are keyed as "test_<scorer name>".
    fold_f1 = cv_results["test_f1_macro"]
    return {"f1_score": mean(fold_f1)}
|
|
|
|
|
|
|
|
|
2023-09-17 21:47:34 +02:00
|
|
|
def atc_mc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using ``atc.get_max_conf`` as the score.

    Args:
        c_model: Fitted classifier; the method named by ``predict_method`` is
            looked up on it and called on the validation and test features.
        validation: Labelled collection used to fit the ATC threshold.
        test: Labelled collection whose accuracy is estimated; its labels are
            only used to compute the reference ``"true_acc"``.
        predict_method: Name of the prediction method to call on ``c_model``.
            Its output is reduced with ``argmax`` over the last axis, so it is
            presumably a per-class score/probability matrix.

    Returns:
        A dict with ``"true_acc"`` (observed test accuracy, as a percentage)
        and ``"pred_acc"`` (the value returned by ``atc.get_ATC_acc``).
    """
    predict = getattr(c_model, predict_method)

    # Model outputs on the validation data first, then on the test data.
    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Per-sample scores from the ATC helper (max-confidence variant).
    val_conf = atc.get_max_conf(val_probs)
    test_conf = atc.get_max_conf(test_probs)

    # Boolean mask: which validation samples the model classifies correctly.
    val_correct = validation.y == np.argmax(val_probs, axis=-1)

    # Only the threshold is needed from the helper's two return values.
    _, threshold = atc.find_ATC_threshold(val_conf, val_correct)
    estimated_acc = atc.get_ATC_acc(threshold, test_conf)

    test_hits = np.argmax(test_probs, axis=-1) == test.y
    return {
        "true_acc": 100 * np.mean(test_hits),
        "pred_acc": estimated_acc,
    }
|
2023-09-14 01:52:19 +02:00
|
|
|
|
2023-09-16 01:59:49 +02:00
|
|
|
|
2023-09-17 21:47:34 +02:00
|
|
|
def atc_ne(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy with ATC using ``atc.get_entropy`` as the score.

    Identical in structure to the max-confidence variant, except that the
    per-sample score comes from ``atc.get_entropy`` (the sign convention of
    that score is defined by the helper, which is not visible here).

    Args:
        c_model: Fitted classifier; the method named by ``predict_method`` is
            looked up on it and called on the validation and test features.
        validation: Labelled collection used to fit the ATC threshold.
        test: Labelled collection whose accuracy is estimated; its labels are
            only used to compute the reference ``"true_acc"``.
        predict_method: Name of the prediction method to call on ``c_model``.
            Its output is reduced with ``argmax`` over the last axis, so it is
            presumably a per-class score/probability matrix.

    Returns:
        A dict with ``"true_acc"`` (observed test accuracy, as a percentage)
        and ``"pred_acc"`` (the value returned by ``atc.get_ATC_acc``).
    """
    predict = getattr(c_model, predict_method)

    # Model outputs on the validation data first, then on the test data.
    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Per-sample scores from the ATC helper (entropy-based variant).
    val_entropy = atc.get_entropy(val_probs)
    test_entropy = atc.get_entropy(test_probs)

    # Boolean mask: which validation samples the model classifies correctly.
    val_correct = validation.y == np.argmax(val_probs, axis=-1)

    # Only the threshold is needed from the helper's two return values.
    _, threshold = atc.find_ATC_threshold(val_entropy, val_correct)
    estimated_acc = atc.get_ATC_acc(threshold, test_entropy)

    test_hits = np.argmax(test_probs, axis=-1) == test.y
    return {
        "true_acc": 100 * np.mean(test_hits),
        "pred_acc": estimated_acc,
    }
|
|
|
|
|
2023-09-16 01:59:49 +02:00
|
|
|
|
|
|
|
def trust_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
    """Score the classifier's test predictions with a ``TrustScore`` model.

    Args:
        c_model: Fitted classifier; the method named by ``predict_method`` is
            looked up on it and called on ``test.X``.
        validation: Labelled collection the ``TrustScore`` model is fitted on.
        test: Collection whose predictions are scored.
        predict_method: Name of the prediction method to call on ``c_model``.

    Returns:
        Whatever ``TrustScore.get_score`` returns for the test features and
        the model's predictions on them.
    """
    predict = getattr(c_model, predict_method)
    predicted = predict(test.X)

    # A fresh scorer is built with default settings and fitted on validation.
    scorer = trustscore.TrustScore()
    scorer.fit(validation.X, validation.y)

    return scorer.get_score(test.X, predicted)
|
|
|
|
|
2023-09-17 21:47:34 +02:00
|
|
|
|
|
|
|
def doc_feat(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict_proba",
):
    """Estimate test accuracy as validation accuracy plus ``doc.get_doc``.

    Args:
        c_model: Fitted classifier; the method named by ``predict_method`` is
            looked up on it and called on the validation and test features.
        validation: Labelled collection providing the reference accuracy and
            reference confidences.
        test: Collection whose accuracy is estimated; its labels are not used.
        predict_method: Name of the prediction method to call on ``c_model``.
            Its output is reduced with ``max`` / ``argmax`` over the last
            axis, so it is presumably a per-class probability matrix.

    Returns:
        Validation accuracy (as a percentage) plus the value returned by
        ``doc.get_doc(validation confidences, test confidences)``.
    """
    predict = getattr(c_model, predict_method)

    val_probs = predict(validation.X)
    test_probs = predict(test.X)

    # Confidence of a sample = largest value along the class axis.
    val_conf = np.max(val_probs, axis=-1)
    test_conf = np.max(test_probs, axis=-1)

    # Accuracy on the validation set, on a 0-100 scale.
    val_hits = np.argmax(val_probs, axis=-1) == validation.y
    val_acc = np.mean(val_hits) * 100

    # NOTE(review): val_acc is a percentage; confirm doc.get_doc returns a
    # value on the same scale (and with the intended sign) before relying on
    # this sum.
    return val_acc + doc.get_doc(val_conf, test_conf)
|