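"""Alternative baseline methods for predicting classifier accuracy and F1 on
the samples generated by a protocol. Each method takes a trained classifier,
a validation set, and a sampling protocol, and returns an EvaluationReport
with the estimated scores and their absolute error w.r.t. the ground truth;
the @alt decorator collects all such methods in the `_alts` registry."""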
from functools import wraps

import numpy as np
import quapy.functional as F
import sklearn.metrics as metrics
from quapy.method.aggregative import ACC, EMQ
from sklearn import clone

import quacc as qc
from quacc.evaluation.report import EvaluationReport

# registry of the alternative accuracy-prediction methods defined in this
# module, keyed by function name (filled in by the @alt decorator)
_alts = {}


def alt(func):
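    """Register `func` as an alternative method in the `_alts` registry,
    keyed by its function name (also exposed as the `name` attribute of
    the returned wrapper)."""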
    @wraps(func)
    def wrapper(c_model, validation, protocol):
        return func(c_model, validation, protocol)

    # expose the original function name and add the method to the registry
    wrapper.name = func.__name__
    _alts[func.__name__] = wrapper

    return wrapper


@alt
def cross(c_model, validation, protocol):
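    """Estimate the classifier's accuracy and F1 on each protocol sample by
    scaling the validation misclassification matrix (estimated as in ACC)
    with the class prevalences quantified by EMQ, which yields an estimated
    contingency table; the report records both the estimated scores and
    their absolute error w.r.t. the ground-truth scores."""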
    y_val = validation.labels
    y_hat_val = c_model.predict(validation.instances)

    # clone of the classifier, refit on the validation set, to be wrapped
    # by the quantifier
    qcls = clone(c_model)
    qcls.fit(*validation.Xy)

    # the quantifier and the misclassification-rate matrix depend only on
    # the validation set, so they are computed once, before the sample loop
    q = EMQ(qcls)
    q.fit(validation, fit_classifier=False)
    M_hat = ACC.getPteCondEstim(validation.classes_, y_val, y_hat_val)

    er = EvaluationReport(name="cross")
    for sample in protocol():
        y_hat = c_model.predict(sample.instances)
        y = sample.labels
        ground_acc = (y_hat == y).mean()
        ground_f1 = metrics.f1_score(y, y_hat, zero_division=0)

        # scale the columns of M_hat by the EMQ-estimated prevalences to
        # obtain an estimate of the contingency table of the sample
        p_hat = q.quantify(sample.instances)
        cont_table_hat = p_hat * M_hat

        acc_score = qc.error.acc(cont_table_hat)
        f1_score = qc.error.f1(cont_table_hat)

        # absolute estimation error w.r.t. the ground-truth scores
        meta_acc = abs(acc_score - ground_acc)
        meta_f1 = abs(f1_score - ground_f1)
        er.append_row(
            sample.prevalence(),
            acc=meta_acc,
            f1=meta_f1,
            acc_score=acc_score,
            f1_score=f1_score,
        )

    return er


@alt
def cross2(c_model, validation, protocol):
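    """Variant of `cross` that reconstructs the contingency table in closed
    form, solving a 4x4 linear system in (tn, fn, fp, tp) built from the
    classify-and-count positive prevalence, its ACC-corrected version, the
    normalization constraint, and the validation tpr or tnr (whichever was
    estimated on the larger share of the validation data)."""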
    classes = validation.classes_
    y_val = validation.labels
    y_hat_val = c_model.predict(validation.instances)
    M_hat = ACC.getPteCondEstim(classes, y_val, y_hat_val)
    pos_prev_val = validation.prevalence()[1]

    er = EvaluationReport(name="cross2")
    for sample in protocol():
        y_test = sample.labels
        y_hat_test = c_model.predict(sample.instances)
        ground_acc = (y_hat_test == y_test).mean()
        ground_f1 = metrics.f1_score(y_test, y_hat_test, zero_division=0)

        # classify-and-count positive prevalence, corrected as in ACC and
        # clipped to [0, 1] (the correction can overshoot the unit interval)
        pos_prev_cc = F.prevalence_from_labels(y_hat_test, classes)[1]
        tpr_hat = M_hat[1, 1]
        fpr_hat = M_hat[1, 0]
        tnr_hat = M_hat[0, 0]
        pos_prev_test_hat = (pos_prev_cc - fpr_hat) / (tpr_hat - fpr_hat)
        pos_prev_test_hat = np.clip(pos_prev_test_hat, 0, 1)

        # linear constraints on (tn, fn, fp, tp): predicted-positive rate,
        # corrected positive prevalence, normalization, and one rate equation
        if pos_prev_val > 0.5:
            # in this case, the tpr might be a more reliable estimate than tnr
            A = np.asarray(
                [[0, 0, 1, 1], [0, 1, 0, 1], [1, 1, 1, 1], [0, tpr_hat, 0, tpr_hat - 1]]
            )
        else:
            # in this case, the tnr might be a more reliable estimate than tpr
            A = np.asarray(
                [[0, 0, 1, 1], [0, 1, 0, 1], [1, 1, 1, 1], [tnr_hat - 1, 0, tnr_hat, 0]]
            )

        b = np.asarray([pos_prev_cc, pos_prev_test_hat, 1, 0])

        tn, fn, fp, tp = np.linalg.solve(A, b)
        cont_table_hat = np.array([[tn, fp], [fn, tp]])

        acc_score = qc.error.acc(cont_table_hat)
        f1_score = qc.error.f1(cont_table_hat)

        meta_acc = abs(acc_score - ground_acc)
        meta_f1 = abs(f1_score - ground_f1)
        er.append_row(
            sample.prevalence(),
            acc=meta_acc,
            f1=meta_f1,
            acc_score=acc_score,
            f1_score=f1_score,
        )

    return er
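
# minimal usage sketch (hypothetical names: `train` and `validation` would be
# QuaPy LabelledCollection objects and `protocol` a sample-generation protocol
# such as quapy.protocol.APP; none of these are defined in this module):
#
#     from sklearn.linear_model import LogisticRegression
#
#     c_model = LogisticRegression().fit(*train.Xy)
#     for name, method in _alts.items():
#         report = method(c_model, validation, protocol)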