QuAcc/quacc/evaluation/__init__.py

35 lines
1023 B
Python
Raw Normal View History

2023-11-08 17:26:44 +01:00
from typing import Callable, Union
import numpy as np
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
import quacc as qc
from ..method.base import BaseAccuracyEstimator
def evaluate(
    estimator: BaseAccuracyEstimator,
    protocol: AbstractProtocol,
    error_metric: Union[Callable, str],
) -> float:
    """Evaluate an accuracy estimator over the samples generated by a protocol.

    For every sample yielded by ``protocol()``, the sample is extended with
    ``estimator.extend``, the estimator predicts a prevalence on the extended
    features (``estimator.estimate(..., ext=True)``), and the true prevalence
    of the extended sample (``e_sample.prevalence()``) is recorded. The error
    metric is then computed between the collected true and estimated
    prevalences.

    Args:
        estimator: the accuracy estimator under evaluation.
        protocol: sampling protocol. Its ``collator`` attribute is temporarily
            replaced during evaluation and is always restored afterwards, even
            if evaluation raises.
        error_metric: either a callable invoked as
            ``error_metric(true_prevs, estim_prevs)``, or the name of a metric
            resolved through ``qc.error.from_name``.

    Returns:
        Whatever ``error_metric(true_prevs, estim_prevs)`` returns, where both
        arguments are numpy arrays built from the per-sample prevalences (one
        entry per sample, in protocol order).
    """
    if isinstance(error_metric, str):
        error_metric = qc.error.from_name(error_metric)

    # Swap in the "labelled_collection" collator so each yielded sample can be
    # handed to ``estimator.extend`` (presumably as a LabelledCollection —
    # confirm against quapy). The caller's collator is restored in ``finally``
    # so an exception mid-evaluation does not leave the protocol mutated.
    original_collator = protocol.collator
    protocol.collator = OnLabelledCollectionProtocol.get_collator("labelled_collection")
    estim_prevs, true_prevs = [], []
    try:
        for sample in protocol():
            e_sample = estimator.extend(sample)
            estim_prevs.append(estimator.estimate(e_sample.X, ext=True))
            true_prevs.append(e_sample.prevalence())
    finally:
        protocol.collator = original_collator

    return error_metric(np.array(true_prevs), np.array(estim_prevs))