import itertools
from typing import Callable, Iterable, Union

import numpy as np
import pandas as pd
from quapy.protocol import (
    AbstractStochasticSeededProtocol,
    OnLabelledCollectionProtocol,
)

import quacc.error as error

from .estimator import AccuracyEstimator


def estimate(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
):
    """Draw samples from the protocol, extend each one through the estimator,
    and collect the base, true and estimated prevalences."""
    # ensure that the protocol returns a LabelledCollection for each iteration
    protocol.collator = OnLabelledCollectionProtocol.get_collator(
        "labelled_collection"
    )

    base_prevs, true_prevs, estim_prevs = [], [], []
    for sample in protocol():
        e_sample = estimator.extend(sample)
        estim_prev = estimator.estimate(e_sample.X, ext=True)
        base_prevs.append(sample.prevalence())
        true_prevs.append(e_sample.prevalence())
        estim_prevs.append(estim_prev)

    return base_prevs, true_prevs, estim_prevs


# column labels for the report's two-level index: per-class base prevalence,
# the four cells of the true/estimated extended prevalence, and one column
# per error metric
_bprev_col_0 = ["base"]
_bprev_col_1 = ["0", "1"]
_prev_col_0 = ["true", "estim"]
_prev_col_1 = ["T0", "F1", "F0", "T1"]
_err_col_0 = ["errors"]


def _report_columns(err_names):
    """Build the MultiIndex of columns for the evaluation report."""
    bprev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1))
    prev_cols = list(itertools.product(_prev_col_0, _prev_col_1))
    err_cols = list(itertools.product(_err_col_0, err_names))
    return pd.MultiIndex.from_tuples(bprev_cols + prev_cols + err_cols)


def _dict_prev(base_prev, true_prev, estim_prev):
    """Map each prevalence value to its (level-0, level-1) column key."""
    prev_cols = list(itertools.product(_bprev_col_0, _bprev_col_1)) + list(
        itertools.product(_prev_col_0, _prev_col_1)
    )
    return dict(
        zip(prev_cols, np.concatenate((base_prev, true_prev, estim_prev), axis=0))
    )


def evaluation_report(
    estimator: AccuracyEstimator,
    protocol: AbstractStochasticSeededProtocol,
    error_metrics: Iterable[Union[str, Callable]] = "all",
) -> pd.DataFrame:
    """Evaluate the estimator over the protocol and return a DataFrame with
    one row per sample, holding prevalences and error scores."""
    base_prevs, true_prevs, estim_prevs = estimate(estimator, protocol)

    if error_metrics == "all":
        error_metrics = ["mae", "rae", "mrae", "kld", "nkld", "f1e"]

    error_funcs = [
        error.from_name(e) if isinstance(e, str) else e for e in error_metrics
    ]
    assert all(callable(e) for e in error_funcs), "invalid error function"
    error_names = [e.__name__ for e in error_funcs]

    # f1e is computed separately on the true and on the estimated prevalence,
    # so it contributes two columns to the report
    error_cols = error_names.copy()
    if "f1e" in error_cols:
        error_cols.remove("f1e")
        error_cols.extend(["f1e_true", "f1e_estim"])

    df_cols = _report_columns(error_cols)

    rows = []
    for base_prev, true_prev, estim_prev in zip(base_prevs, true_prevs, estim_prevs):
        row = _dict_prev(base_prev, true_prev, estim_prev)
        for error_name, error_metric in zip(error_names, error_funcs):
            if error_name == "f1e":
                row[("errors", "f1e_true")] = error_metric(true_prev)
                row[("errors", "f1e_estim")] = error_metric(estim_prev)
                continue
            row[("errors", error_name)] = error_metric(true_prev, estim_prev)
        rows.append(row)

    return pd.DataFrame(rows, columns=df_cols)
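

# Illustration (derived from the column constants above, not executed): with
# error_cols = ["mae", "f1e_true", "f1e_estim"], _report_columns yields the
# two-level column index
#
#   ("base", "0"), ("base", "1"),
#   ("true", "T0"), ("true", "F1"), ("true", "F0"), ("true", "T1"),
#   ("estim", "T0"), ("estim", "F1"), ("estim", "F0"), ("estim", "T1"),
#   ("errors", "mae"), ("errors", "f1e_true"), ("errors", "f1e_estim")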
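
# Usage sketch (illustrative only): the AccuracyEstimator constructor and fit
# signature below are assumptions, not confirmed by this module; quapy's APP
# protocol and reviews loader are used as plausible inputs.
#
#   import quapy as qp
#   from sklearn.linear_model import LogisticRegression
#   from quacc.estimator import AccuracyEstimator
#
#   train, test = qp.datasets.fetch_reviews("imdb", tfidf=True).train_test
#   estimator = AccuracyEstimator(LogisticRegression())  # hypothetical signature
#   estimator.fit(train)                                 # hypothetical signature
#   protocol = qp.protocol.APP(test, sample_size=100, repeats=10, random_state=0)
#   report = evaluation_report(estimator, protocol, error_metrics=["mae", "f1e"])
#   print(report.head())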