2023-10-27 12:37:18 +02:00
|
|
|
from functools import wraps
|
|
|
|
|
2023-10-23 03:14:35 +02:00
|
|
|
import numpy as np
|
|
|
|
import sklearn.metrics as metrics
|
2023-10-19 02:36:53 +02:00
|
|
|
from quapy.data import LabelledCollection
|
2023-10-27 12:37:18 +02:00
|
|
|
from quapy.protocol import AbstractStochasticSeededProtocol
|
2023-10-19 02:36:53 +02:00
|
|
|
from sklearn.base import BaseEstimator
|
|
|
|
|
|
|
|
import quacc.error as error
|
|
|
|
from quacc.evaluation.report import EvaluationReport
|
|
|
|
|
|
|
|
from ..estimator import (
|
|
|
|
AccuracyEstimator,
|
|
|
|
BinaryQuantifierAccuracyEstimator,
|
|
|
|
MulticlassAccuracyEstimator,
|
|
|
|
)
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
# Registry of evaluation methods: maps a decorated function's ``__name__``
# to the wrapper produced by ``method``.
_methods = {}


def method(func):
    """Register ``func`` in ``_methods`` and return the registered wrapper.

    The wrapper accepts exactly ``(c_model, validation, protocol)`` and
    forwards them unchanged to ``func``. ``functools.wraps`` copies the
    metadata of ``func`` onto the wrapper, and the wrapper is stored under
    ``func.__name__``.
    """

    @wraps(func)
    def registered(c_model, validation, protocol):
        return func(c_model, validation, protocol)

    registry_key = func.__name__
    _methods[registry_key] = registered
    return registered
|
|
|
|
|
2023-10-19 02:36:53 +02:00
|
|
|
|
|
|
|
def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Run ``estimator`` over every sample generated by ``protocol``.

    For each sample yielded by ``protocol()`` this calls
    ``estimator.extend(sample)`` (which returns a pair: an extended sample
    and predicted probabilities) and then
    ``estimator.estimate(extended.X, ext=True)``.

    Returns:
        A tuple of five parallel lists, one entry per sample:
        ``(base_prevs, true_prevs, estim_prevs, pred_probas, labels)`` where
        ``base_prevs`` holds ``sample.prevalence()``, ``true_prevs`` holds the
        extended sample's ``prevalence()``, ``estim_prevs`` holds the
        estimator's output, ``pred_probas`` holds the second value returned
        by ``extend`` and ``labels`` holds ``sample.y``.
    """
    # One list per returned column, kept in the order documented above.
    columns = ([], [], [], [], [])

    for drawn in protocol():
        extended, proba = estimator.extend(drawn)
        estimated = estimator.estimate(extended.X, ext=True)
        row = (
            drawn.prevalence(),
            extended.prevalence(),
            estimated,
            proba,
            drawn.y,
        )
        for column, value in zip(columns, row):
            column.append(value)

    return columns
|
2023-10-19 02:36:53 +02:00
|
|
|
|
|
|
|
|
|
|
|
def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
) -> EvaluationReport:
    """Build an ``EvaluationReport`` named ``method`` for ``estimator``.

    The per-sample data comes from ``estimate(estimator, protocol)``. For
    each sample one row is appended, keyed by the sample's base prevalence,
    with these columns:

    * ``acc_score``: ``error.acc`` of the estimated prevalence.
    * ``acc``: absolute difference between the sklearn accuracy of the
      argmax predictions against the sample labels and ``acc_score``.
    * ``f1_score``: ``error.f1`` of the estimated prevalence.
    * ``f1``: absolute difference between ``error.f1`` of the extended
      sample's prevalence and ``f1_score``.

    After the rows are added, ``estimator.fit_score`` is copied onto the
    report as ``fit_score``.
    """
    columns = estimate(estimator, protocol)
    report = EvaluationReport(name=method)

    for sample_prev, ext_prev, est_prev, proba, y_true in zip(*columns):
        # Hard predictions are the highest-probability index on the last axis.
        y_pred = np.argmax(proba, axis=-1)

        estimated_acc = error.acc(est_prev)
        estimated_f1 = error.f1(est_prev)

        observed_acc = metrics.accuracy_score(y_true, y_pred)
        observed_f1 = error.f1(ext_prev)

        report.append_row(
            sample_prev,
            acc_score=estimated_acc,
            acc=abs(observed_acc - estimated_acc),
            f1_score=estimated_f1,
            f1=abs(observed_f1 - estimated_f1),
        )

    report.fit_score = estimator.fit_score
    return report
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
    q_model: str,
    **kwargs,
):
    """Fit an accuracy estimator on ``validation`` and report on ``protocol``.

    Args:
        c_model: classifier handed to the estimator constructor.
        validation: collection passed to ``estimator.fit``.
        protocol: sample generator forwarded to ``evaluation_report``.
        method: ``"bin"`` selects ``BinaryQuantifierAccuracyEstimator`` and
            ``"mul"`` selects ``MulticlassAccuracyEstimator``.
        q_model: quantifier name; it is upper-cased for the estimator and
            used verbatim in the report name.
        **kwargs: forwarded unchanged to the estimator constructor. The keys
            ``recalib`` and ``gs`` also affect the report name.

    Returns:
        The ``EvaluationReport`` produced by ``evaluation_report``. Its name
        is ``"{method}_{q_model}"``, followed by ``"_{recalib}"`` when
        ``recalib`` is in ``kwargs`` and by ``"_gs"`` when ``gs`` equals
        ``True``.

    Raises:
        KeyError: if ``method`` is neither ``"bin"`` nor ``"mul"``.
    """
    estimator_classes = {
        "bin": BinaryQuantifierAccuracyEstimator,
        "mul": MulticlassAccuracyEstimator,
    }
    estimator_cls = estimator_classes[method]
    estimator: AccuracyEstimator = estimator_cls(
        c_model, q_model=q_model.upper(), **kwargs
    )
    estimator.fit(validation)

    # Optional suffixes recording the configuration in the report name.
    suffix = ""
    if "recalib" in kwargs:
        suffix += f"_{kwargs['recalib']}"
    if ("gs", True) in kwargs.items():
        suffix += "_gs"
    report_name = f"{method}_{q_model}{suffix}"

    return evaluation_report(estimator, protocol, report_name)
|
2023-10-19 02:36:53 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with the ``"bin"`` estimator and the ``"sld"`` quantifier."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="sld",
    )
|
2023-10-20 23:36:05 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with the ``"mul"`` estimator and the ``"sld"`` quantifier."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="sld",
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with ``"bin"`` / ``"sld"`` and ``recalib="bcts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="sld",
        recalib="bcts",
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with ``"mul"`` / ``"sld"`` and ``recalib="bcts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="sld",
        recalib="bcts",
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with ``"bin"`` / ``"sld"`` and ``gs=True``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="sld",
        gs=True,
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with ``"mul"`` / ``"sld"`` and ``gs=True``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="sld",
        gs=True,
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def bin_cc(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with the ``"bin"`` estimator and the ``"cc"`` quantifier."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="cc",
    )
|
2023-10-23 03:14:35 +02:00
|
|
|
|
|
|
|
|
2023-10-27 12:37:18 +02:00
|
|
|
@method
def mul_cc(c_model, validation, protocol) -> EvaluationReport:
    """Evaluate with the ``"mul"`` estimator and the ``"cc"`` quantifier."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="cc",
    )
|