# QuAcc/quacc/evaluation/method.py


from functools import wraps
import numpy as np
import sklearn.metrics as metrics
from quapy.data import LabelledCollection
from quapy.protocol import AbstractStochasticSeededProtocol
from sklearn.base import BaseEstimator
import quacc.error as error
from quacc.evaluation.report import EvaluationReport
from ..estimator import (
    AccuracyEstimator,
    BinaryQuantifierAccuracyEstimator,
    MulticlassAccuracyEstimator,
)

_methods = {}


def method(func):
    """Decorator that registers an evaluation method in the ``_methods`` registry."""

    @wraps(func)
    def wrapper(c_model, validation, protocol):
        return func(c_model, validation, protocol)

    _methods[func.__name__] = wrapper
    return wrapper
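

# The registry lets callers dispatch methods by name, e.g. from a CLI or a
# config file. A hypothetical lookup (illustrative, not part of this module's
# public API):
#
#     report = _methods["bin_sld"](c_model, validation, protocol)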


def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Run ``estimator`` over every sample drawn by ``protocol``.

    Returns five parallel lists with one entry per sample: the sample's base
    prevalence, the true prevalence of the extended sample, the estimated
    prevalence, the classifier's posterior probabilities, and the true labels.
    """
    base_prevs, true_prevs, estim_prevs, pred_probas, labels = [], [], [], [], []
    for sample in protocol():
        # extend the sample with the classifier's predictions
        e_sample, pred_proba = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.X, ext=True)
        base_prevs.append(sample.prevalence())
        true_prevs.append(e_sample.prevalence())
        estim_prevs.append(estim_prev)
        pred_probas.append(pred_proba)
        labels.append(sample.y)
    return base_prevs, true_prevs, estim_prevs, pred_probas, labels
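

# Illustrative consumption pattern for ``estimate`` (``est`` and ``prot`` are
# assumed to be a fitted AccuracyEstimator and a configured protocol):
#
#     base_prevs, _, estim_prevs, _, _ = estimate(est, prot)
#     for base_prev, estim_prev in zip(base_prevs, estim_prevs):
#         ...  # one entry per sample drawn by the protocol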


def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
) -> EvaluationReport:
    """Build an :class:`EvaluationReport` comparing estimated and true scores."""
    base_prevs, true_prevs, estim_prevs, pred_probas, labels = estimate(
        estimator, protocol
    )
    report = EvaluationReport(name=method)
    for base_prev, true_prev, estim_prev, pred_proba, label in zip(
        base_prevs, true_prevs, estim_prevs, pred_probas, labels
    ):
        pred = np.argmax(pred_proba, axis=-1)
        acc_score = error.acc(estim_prev)
        f1_score = error.f1(estim_prev)
        report.append_row(
            base_prev,
            acc_score=acc_score,
            # absolute estimation error w.r.t. the true accuracy on the sample
            acc=abs(metrics.accuracy_score(label, pred) - acc_score),
            f1_score=f1_score,
            # absolute estimation error w.r.t. the F1 from the true prevalence
            f1=abs(error.f1(true_prev) - f1_score),
        )
    report.fit_score = estimator.fit_score
    return report
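

# Note: each appended row pairs the sample's base prevalence with the estimated
# scores and their absolute errors, so downstream tables and plots can be grouped
# by prevalence (the column names follow the keyword arguments above).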


def evaluate(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
    q_model: str,
    **kwargs,
):
    """Fit the requested accuracy estimator and evaluate it over ``protocol``."""
    estimator: AccuracyEstimator = {
        "bin": BinaryQuantifierAccuracyEstimator,
        "mul": MulticlassAccuracyEstimator,
    }[method](c_model, q_model=q_model.upper(), **kwargs)
    estimator.fit(validation)
    # build a descriptive report name such as "bin_sld_bcts" or "mul_sld_gs"
    _method = f"{method}_{q_model}"
    if "recalib" in kwargs:
        _method += f"_{kwargs['recalib']}"
    if kwargs.get("gs", False):
        _method += "_gs"
    return evaluation_report(estimator, protocol, _method)
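

# Hedged example of calling ``evaluate`` directly, equivalent to the registered
# shortcuts below (``c_model``, ``validation`` and ``protocol`` are assumed to be
# already built; the keyword names follow this module):
#
#     report = evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts")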


@method
def bin_sld(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "sld")


@method
def mul_sld(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "sld")


@method
def bin_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "sld", recalib="bcts")


@method
def mul_sld_bcts(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "sld", recalib="bcts")


@method
def bin_sld_gs(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "sld", gs=True)


@method
def mul_sld_gs(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "sld", gs=True)


@method
def bin_cc(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "bin", "cc")


@method
def mul_cc(c_model, validation, protocol) -> EvaluationReport:
    return evaluate(c_model, validation, protocol, "mul", "cc")
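

if __name__ == "__main__":
    # Minimal end-to-end sketch. The dataset, split ratio and protocol settings
    # below are illustrative assumptions, not defaults of this module; they rely
    # only on quapy's fetch_reviews/APP and sklearn's LogisticRegression.
    import quapy as qp
    from quapy.protocol import APP
    from sklearn.linear_model import LogisticRegression

    dataset = qp.datasets.fetch_reviews("imdb", tfidf=True, min_df=5)
    train, validation = dataset.training.split_stratified(train_prop=0.6)

    c_model = LogisticRegression().fit(*train.Xy)  # classifier under evaluation
    protocol = APP(dataset.test, sample_size=100, n_prevalences=11, repeats=1)

    report = bin_sld(c_model, validation, protocol)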