# Estimators of classifier performance under dataset shift, evaluated over
# quapy sampling protocols: ATC (garg22), DOC (guillory21), trust scores
# (jiang18), RCA (elsahar19) and BBSE (lipton).

from collections import defaultdict
from statistics import mean
from typing import Dict

import numpy as np
import pandas as pd
import quapy as qp
from quapy.data import LabelledCollection
from quapy.protocol import (
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)
from sklearn.base import BaseEstimator
from sklearn.model_selection import cross_validate

import elsahar19_rca.rca as rca
import garg22_ATC.ATC_helper as atc
import guillory21_doc.doc as doc
import jiang18_trustscore.trustscore as trustscore
import lipton_bbse.labelshift as bbse


def kfcv(c_model: BaseEstimator, validation: LabelledCollection) -> Dict:
    # baseline: k-fold cross-validation (sklearn's default 5 folds) of the
    # classifier on the validation set, reporting the mean macro-averaged F1
    scoring = ["f1_macro"]
    scores = cross_validate(c_model, validation.X, validation.y, scoring=scoring)
    return {"f1_score": mean(scores["test_f1_macro"])}
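

# A minimal usage sketch for kfcv; the IMDB dataset and the logistic regression
# classifier are illustrative assumptions, not part of this module.
def _example_kfcv():
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
    train, validation = dataset.training, dataset.test
    model = LogisticRegression().fit(*train.Xy)
    print(kfcv(model, validation))  # -> {"f1_score": ...}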


def avg_groupby_distribution(results):
    # group the result rows by their base prevalence (F, T) and average every
    # other column within each group
    def base_prev(s):
        return (s[("base", "F")], s[("base", "T")])

    grouped_list = defaultdict(list)
    for r in results:
        grouped_list[base_prev(r)].append(r)

    series = []
    for (fp, tp), r_list in grouped_list.items():
        r_avg = {("base", "F"): fp, ("base", "T"): tp}
        for pn in [k for k in r_list[0] if k[0] != "base"]:
            r_avg[pn] = mean(r[pn] for r in r_list)
        series.append(r_avg)

    return series
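

# A toy sketch of avg_groupby_distribution: rows sharing the same base
# prevalence collapse into one row whose remaining columns are averaged.
def _example_avg_groupby_distribution():
    rows = [
        {("base", "F"): 0.5, ("base", "T"): 0.5, ("m", "score"): 0.8},
        {("base", "F"): 0.5, ("base", "T"): 0.5, ("m", "score"): 0.6},
        {("base", "F"): 0.9, ("base", "T"): 0.1, ("m", "score"): 0.4},
    ]
    # -> one row (0.5, 0.5) with ("m", "score") == 0.7,
    #    one row (0.9, 0.1) with ("m", "score") == 0.4
    print(avg_groupby_distribution(rows))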


def atc_mc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    # ATC with the maximum-confidence score function [Garg et al., 2022]
    c_model_predict = getattr(c_model, predict_method)

    # probabilities and labels of the in-distribution (ID) validation data
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    # maximum-confidence score function
    val_scores = atc.get_max_conf(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("atc mc", "accuracy"),
    ]
    results = []
    for test in protocol():
        # probabilities of the out-of-distribution (OOD) test sample
        test_probs = c_model_predict(test.X)
        test_scores = atc.get_max_conf(test_probs)
        # get_ATC_acc returns a percentage; store the estimated error rate
        atc_error = 1.0 - (atc.get_ATC_acc(atc_thres, test_scores) / 100.0)
        [f_prev, t_prev] = test.prevalence()
        results.append(dict(zip(cols, [f_prev, t_prev, atc_error])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )
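

# A minimal usage sketch for atc_mc (atc_ne is called identically). The IMDB
# dataset, the classifier and the protocol parameters are illustrative
# assumptions; APP is quapy's artificial-prevalence sampling protocol.
def _example_atc():
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
    train, validation = dataset.training.split_stratified(train_prop=0.66)
    model = LogisticRegression().fit(*train.Xy)
    protocol = APP(dataset.test, sample_size=100, n_prevalences=11, repeats=10)
    print(atc_mc(model, validation, protocol))
    print(atc_ne(model, validation, protocol))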


def atc_ne(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    # ATC with the negative-entropy score function [Garg et al., 2022]
    c_model_predict = getattr(c_model, predict_method)

    # probabilities and labels of the in-distribution (ID) validation data
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    # negative-entropy score function
    val_scores = atc.get_entropy(val_probs)
    val_preds = np.argmax(val_probs, axis=-1)
    _, atc_thres = atc.find_ATC_threshold(val_scores, val_labels == val_preds)

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("atc ne", "accuracy"),
    ]
    results = []
    for test in protocol():
        # probabilities of the out-of-distribution (OOD) test sample
        test_probs = c_model_predict(test.X)
        test_scores = atc.get_entropy(test_probs)
        # get_ATC_acc returns a percentage; store the estimated error rate
        atc_error = 1.0 - (atc.get_ATC_acc(atc_thres, test_scores) / 100.0)
        [f_prev, t_prev] = test.prevalence()
        results.append(dict(zip(cols, [f_prev, t_prev, atc_error])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )


def trust_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    test: LabelledCollection,
    predict_method="predict",
):
    # trust scores [Jiang et al., 2018] for the classifier's predictions on a
    # single test collection, with the scorer fit on the validation data
    c_model_predict = getattr(c_model, predict_method)

    test_pred = c_model_predict(test.X)

    trust_model = trustscore.TrustScore()
    trust_model.fit(validation.X, validation.y)

    return trust_model.get_score(test.X, test_pred)
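

# A minimal usage sketch for trust_score, which (unlike the protocol-based
# functions in this module) scores one fixed test collection. The dense
# breast-cancer data and the classifier are illustrative assumptions; the
# nearest-neighbour machinery behind TrustScore expects dense features.
def _example_trust_score():
    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression

    X, y = load_breast_cancer(return_X_y=True)
    data = LabelledCollection(X, y)
    train, rest = data.split_stratified(train_prop=0.5)
    validation, test = rest.split_stratified(train_prop=0.5)
    model = LogisticRegression(max_iter=5000).fit(*train.Xy)
    scores = trust_score(model, validation, test)  # one score per test instance
    print(scores[:10])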


def doc_feat(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    # difference-of-confidences (DoC) estimate [Guillory et al., 2021]
    c_model_predict = getattr(c_model, predict_method)

    val_probs, val_labels = c_model_predict(validation.X), validation.y
    val_scores = np.max(val_probs, axis=-1)
    val_preds = np.argmax(val_probs, axis=-1)
    # classifier accuracy on the validation set, as a percentage
    v1acc = np.mean(val_preds == val_labels) * 100

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("doc feat", "score"),
    ]
    results = []
    for test in protocol():
        test_probs = c_model_predict(test.X)
        test_scores = np.max(test_probs, axis=-1)
        # predicted accuracy = validation accuracy + difference of confidences;
        # stored as an error estimate on the [0, 1] scale
        score = 1.0 - ((v1acc + doc.get_doc(val_scores, test_scores)) / 100.0)
        [f_prev, t_prev] = test.prevalence()
        results.append(dict(zip(cols, [f_prev, t_prev, score])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )
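

# doc_feat shares the calling convention of atc_mc/atc_ne, and a seeded
# protocol regenerates the same samples on every call, so the results of
# several methods line up row by row; a sketch under the same illustrative
# assumptions as the examples above.
def _example_doc_feat():
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
    train, validation = dataset.training.split_stratified(train_prop=0.66)
    model = LogisticRegression().fit(*train.Xy)
    protocol = APP(
        dataset.test, sample_size=100, n_prevalences=11, repeats=10, random_state=0
    )
    print(doc_feat(model, validation, protocol))
    print(atc_mc(model, validation, protocol))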


def rca_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    # Reverse Classification Accuracy (RCA) [Elsahar and Gallé, 2019]
    c_model_predict = getattr(c_model, predict_method)
    val_pred1 = c_model_predict(validation.X)

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("rca", "score"),
    ]
    results = []
    for test in protocol():
        [f_prev, t_prev] = test.prevalence()
        try:
            # fit a clone of the classifier on its own test predictions and
            # measure how far its validation predictions drift from the truth
            test_pred = c_model_predict(test.X)
            c_model2 = rca.clone_fit(c_model, test.X, test_pred)
            c_model2_predict = getattr(c_model2, predict_method)
            val_pred2 = c_model2_predict(validation.X)
            score = rca.get_score(val_pred1, val_pred2, validation.y)
            results.append(dict(zip(cols, [f_prev, t_prev, score])))
        except ValueError:
            # the clone could not be fit (e.g., a single-class test sample)
            results.append(dict(zip(cols, [f_prev, t_prev, float("nan")])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )
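

# A minimal usage sketch for rca_score and rca_star_score. Samples drawn at
# extreme prevalences may contain a single class, in which case refitting the
# clone raises ValueError and the row is recorded as NaN. Dataset, classifier
# and protocol parameters are illustrative assumptions.
def _example_rca():
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
    train, validation = dataset.training.split_stratified(train_prop=0.66)
    model = LogisticRegression().fit(*train.Xy)
    protocol = APP(dataset.test, sample_size=100, n_prevalences=11, repeats=10)
    print(rca_score(model, validation, protocol))
    print(rca_star_score(model, validation, protocol))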


def rca_star_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict",
):
    # RCA* variant: the reference predictions come from a clone fit on one half
    # of the validation set, and the drift is measured on the held-out half
    c_model_predict = getattr(c_model, predict_method)
    validation1, validation2 = validation.split_stratified(train_prop=0.5)
    val1_pred = c_model_predict(validation1.X)
    c_model1 = rca.clone_fit(c_model, validation1.X, val1_pred)
    c_model1_predict = getattr(c_model1, predict_method)
    val2_pred1 = c_model1_predict(validation2.X)

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("rca*", "score"),
    ]
    results = []
    for test in protocol():
        [f_prev, t_prev] = test.prevalence()
        try:
            test_pred = c_model_predict(test.X)
            c_model2 = rca.clone_fit(c_model, test.X, test_pred)
            c_model2_predict = getattr(c_model2, predict_method)
            val2_pred2 = c_model2_predict(validation2.X)
            score = rca.get_score(val2_pred1, val2_pred2, validation2.y)
            results.append(dict(zip(cols, [f_prev, t_prev, score])))
        except ValueError:
            # the clone could not be fit (e.g., a single-class test sample)
            results.append(dict(zip(cols, [f_prev, t_prev, float("nan")])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )


def bbse_score(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    predict_method="predict_proba",
):
    # Black Box Shift Estimation (BBSE) [Lipton et al., 2018]: estimate the
    # test prevalence from the label-shift ratios and score it against the
    # true prevalence of each sample
    c_model_predict = getattr(c_model, predict_method)
    val_probs, val_labels = c_model_predict(validation.X), validation.y

    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    cols = [
        ("base", "F"),
        ("base", "T"),
        ("bbse", "score"),
    ]
    results = []
    for test in protocol():
        test_probs = c_model_predict(test.X)
        # label-shift importance weights and estimated target distribution
        # (2 = number of classes)
        wt = bbse.estimate_labelshift_ratio(val_labels, val_probs, test_probs, 2)
        estim_prev = bbse.estimate_target_dist(wt, val_labels, 2)[1]
        true_prev = test.prevalence()
        [f_prev, t_prev] = true_prev
        # absolute error between the true and the estimated prevalence
        score = qp.error.ae(true_prev, estim_prev)
        results.append(dict(zip(cols, [f_prev, t_prev, score])))

    series = avg_groupby_distribution(results)
    return pd.DataFrame(
        series,
        columns=pd.MultiIndex.from_tuples(cols),
    )
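

# A minimal usage sketch for bbse_score; the reported ("bbse", "score") column
# is the absolute error between the true and the BBSE-estimated prevalence, so
# lower is better. Dataset, classifier and protocol parameters are illustrative
# assumptions.
def _example_bbse():
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True)
    train, validation = dataset.training.split_stratified(train_prop=0.66)
    model = LogisticRegression().fit(*train.Xy)
    protocol = APP(dataset.test, sample_size=100, n_prevalences=11, repeats=10)
    print(bbse_score(model, validation, protocol))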