2023-10-23 03:14:35 +02:00
|
|
|
import numpy as np
|
|
|
|
import sklearn.metrics as metrics
|
2023-10-19 02:36:53 +02:00
|
|
|
from quapy.data import LabelledCollection
|
|
|
|
from quapy.protocol import (
|
|
|
|
AbstractStochasticSeededProtocol,
|
|
|
|
OnLabelledCollectionProtocol,
|
|
|
|
)
|
|
|
|
from sklearn.base import BaseEstimator
|
|
|
|
|
|
|
|
import quacc.error as error
|
|
|
|
from quacc.evaluation.report import EvaluationReport
|
|
|
|
|
|
|
|
from ..estimator import (
|
|
|
|
AccuracyEstimator,
|
|
|
|
BinaryQuantifierAccuracyEstimator,
|
|
|
|
MulticlassAccuracyEstimator,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Run ``estimator`` on every sample generated by ``protocol``.

    Side effect: ``protocol.collator`` is overwritten in place so that each
    iteration of ``protocol()`` yields a LabelledCollection.

    Returns:
        A tuple of five parallel lists, one entry per sample, in this order:
        base_prevs  -- ``sample.prevalence()`` of the raw sample,
        true_prevs  -- ``prevalence()`` of the sample returned by
                       ``estimator.extend(sample)``,
        estim_prevs -- ``estimator.estimate(extended.X, ext=True)``,
        pred_probas -- second value returned by ``estimator.extend(sample)``,
        labels      -- ``sample.y``.
    """
    # Make every protocol iteration hand back a LabelledCollection.
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")

    columns = ([], [], [], [], [])
    for sample in protocol():
        extended, proba = estimator.extend(sample)
        estimated = estimator.estimate(extended.X, ext=True)
        row = (
            sample.prevalence(),
            extended.prevalence(),
            estimated,
            proba,
            sample.y,
        )
        for column, value in zip(columns, row):
            column.append(value)

    return columns
|
2023-10-19 02:36:53 +02:00
|
|
|
|
|
|
|
|
|
|
|
def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
) -> EvaluationReport:
    """Build an EvaluationReport named ``method`` for ``estimator`` on ``protocol``.

    For each sample produced by ``estimate`` one row is appended, containing:
      * ``acc_score`` / ``f1_score``: values from ``error.acc`` / ``error.f1``
        applied to the estimated prevalence;
      * ``acc``: absolute difference between sklearn's accuracy of the argmax
        predictions against the sample labels and ``acc_score``;
      * ``f1``: absolute difference between ``error.f1`` of the true (extended)
        prevalence and ``f1_score``.
    """
    results = estimate(estimator, protocol)
    report = EvaluationReport(name=method)

    for base_prev, true_prev, estim_prev, pred_proba, label in zip(*results):
        # Hard class predictions from the per-class probabilities (last axis).
        predicted = np.argmax(pred_proba, axis=-1)
        estimated_acc = error.acc(estim_prev)
        estimated_f1 = error.f1(estim_prev)
        observed_acc = metrics.accuracy_score(label, predicted)
        observed_f1 = error.f1(true_prev)
        report.append_row(
            base_prev,
            acc_score=estimated_acc,
            acc=abs(observed_acc - estimated_acc),
            f1_score=estimated_f1,
            f1=abs(observed_f1 - estimated_f1),
        )

    return report
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
    method: str,
    q_model: str,
    **kwargs,
) -> EvaluationReport:
    """Fit an accuracy estimator on ``validation`` and evaluate it on ``protocol``.

    Args:
        c_model: classifier passed as the first argument to the estimator.
        validation: collection passed to ``estimator.fit``.
        protocol: sampling protocol used by ``evaluation_report``.
        method: ``"bin"`` selects BinaryQuantifierAccuracyEstimator,
            ``"mul"`` selects MulticlassAccuracyEstimator.
        q_model: forwarded to the estimator constructor as ``q_model``.
        **kwargs: extra keyword arguments forwarded to the estimator
            constructor; their values are also appended to the report name.

    Returns:
        The EvaluationReport produced by ``evaluation_report``, named
        ``"{method}_{q_model}"`` followed by ``"_{value}"`` for each kwarg value
        in the order given.

    Raises:
        KeyError: if ``method`` is not one of the supported keys.
    """
    estimator_classes = {
        "bin": BinaryQuantifierAccuracyEstimator,
        "mul": MulticlassAccuracyEstimator,
    }
    try:
        estimator_cls = estimator_classes[method]
    except KeyError:
        # Keep KeyError for compatibility, but say what was expected.
        raise KeyError(
            f"unknown method {method!r}; expected one of {sorted(estimator_classes)}"
        ) from None

    estimator: AccuracyEstimator = estimator_cls(c_model, q_model=q_model, **kwargs)
    estimator.fit(validation)

    # Report name: method, quantifier, then every extra option value, "_"-joined.
    report_name = "_".join(str(part) for part in (method, q_model, *kwargs.values()))
    return evaluation_report(estimator, protocol, report_name)
|
2023-10-19 02:36:53 +02:00
|
|
|
|
|
|
|
|
2023-10-20 23:36:05 +02:00
|
|
|
def evaluate_bin_sld(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"`` and quantifier ``"SLD"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="SLD",
    )
|
2023-10-20 23:36:05 +02:00
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_sld(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"`` and quantifier ``"SLD"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="SLD",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_bin_sld_nbvs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"``, quantifier ``"SLD"`` and ``recalib="nbvs"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="SLD",
        recalib="nbvs",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_sld_nbvs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"``, quantifier ``"SLD"`` and ``recalib="nbvs"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="SLD",
        recalib="nbvs",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_bin_sld_bcts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"``, quantifier ``"SLD"`` and ``recalib="bcts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="SLD",
        recalib="bcts",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_sld_bcts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"``, quantifier ``"SLD"`` and ``recalib="bcts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="SLD",
        recalib="bcts",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_bin_sld_ts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"``, quantifier ``"SLD"`` and ``recalib="ts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="SLD",
        recalib="ts",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_sld_ts(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"``, quantifier ``"SLD"`` and ``recalib="ts"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="SLD",
        recalib="ts",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_bin_sld_vs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"``, quantifier ``"SLD"`` and ``recalib="vs"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="SLD",
        recalib="vs",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_sld_vs(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"``, quantifier ``"SLD"`` and ``recalib="vs"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="SLD",
        recalib="vs",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_bin_cc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"bin"`` and quantifier ``"CC"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="bin",
        q_model="CC",
    )
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_mul_cc(
    c_model: BaseEstimator,
    validation: LabelledCollection,
    protocol: AbstractStochasticSeededProtocol,
) -> EvaluationReport:
    """Run ``evaluate`` with method ``"mul"`` and quantifier ``"CC"``."""
    return evaluate(
        c_model,
        validation,
        protocol,
        method="mul",
        q_model="CC",
    )
|