35 lines
989 B
Python
35 lines
989 B
Python
from typing import Callable, Union
|
|
|
|
import numpy as np
|
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
|
|
|
import quacc as qc
|
|
|
|
from ..method.base import BaseAccuracyEstimator
|
|
|
|
|
|
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Score an accuracy estimator over the samples generated by a protocol.

    Every sample yielded by ``protocol()`` is passed through
    ``estimator.extend``. The output of ``estimator.estimate`` on the extended
    sample is paired with that extended sample's own ``prevalence()``, and the
    two resulting arrays are handed to ``error_metric``.

    While iterating, ``protocol.collator`` is temporarily replaced by the
    ``"labelled_collection"`` collator. The original collator is put back
    before returning, including when an exception interrupts the iteration.

    Args:
        estimator: Object providing ``extend(sample)`` and
            ``estimate(X, ext=True)``.
        protocol: Callable sample generator exposing a ``collator`` attribute.
        error_metric: Either a callable invoked as
            ``error_metric(true_prevs, estim_prevs)``, or a metric name that is
            resolved through ``qc.error.from_name``.

    Returns:
        Whatever ``error_metric`` returns for the collected true and estimated
        prevalences.
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    estim_prevs, true_prevs = [], []

    # Swap the collator for the duration of the loop only; the finally clause
    # guarantees the caller's protocol is not left modified if a sample fails.
    collator_bck_ = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prev = estimator.estimate(e_sample.X, ext=True)
            estim_prevs.append(estim_prev)
            # The reference value is taken from the extended sample, not the raw one.
            true_prevs.append(e_sample.prevalence())
    finally:
        protocol.collator = collator_bck_

    true_prevs = np.array(true_prevs)
    estim_prevs = np.array(estim_prevs)

    return error_metric(true_prevs, estim_prevs)
|