forked from moreo/QuaPy
103 lines
4.8 KiB
Python
103 lines
4.8 KiB
Python
|
from typing import Union, Callable, Iterable
|
||
|
import numpy as np
|
||
|
from tqdm import tqdm
|
||
|
import inspect
|
||
|
import quapy as qp
|
||
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||
|
from quapy.data import LabelledCollection
|
||
|
from quapy.method.base import BaseQuantifier
|
||
|
from quapy.util import temp_seed
|
||
|
import quapy.functional as F
|
||
|
import pandas as pd
|
||
|
|
||
|
|
||
|
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, verbose=False):
    """
    Run `model` over every sample generated by `protocol` and collect, for each
    sample, the true prevalence and the prevalence estimated by the model.

    When the model is an aggregative quantifier and the protocol is an
    :class:`OnLabelledCollectionProtocol`, the underlying collection is classified
    once up front and only the aggregation step is applied to each sample.
    Otherwise `model.quantify` is applied to the raw instances of each sample.

    :param model: the quantifier to evaluate
    :param protocol: the sample generation protocol; calling it yields the samples
    :param verbose: if True, prints which strategy is used and shows a progress bar
    :return: a tuple `(true_prevs, estim_prevs)` of numpy arrays, one row per sample
    """
    # Imported inside the function, as in the original code.
    # NOTE(review): presumably this avoids a circular import -- confirm.
    # The package-qualified path matters: a bare `method.aggregative` is either
    # unresolvable or a second copy of the module, whose class would never
    # satisfy the isinstance check below.
    from quapy.method.aggregative import AggregativeQuantifier

    def sout(msg):
        if verbose:
            print(msg)

    if isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol):
        sout('speeding up the prediction for the aggregative quantifier')
        # classify the whole collection once; samples are then drawn from the
        # pre-classified instances and only need to be aggregated
        pre_classified = model.classify(protocol.get_labelled_collection().instances)
        return __prediction_helper(model.aggregate, protocol.on_preclassified_instances(pre_classified), verbose)

    sout(f'the method is not aggregative, or the protocol is not an instance of '
         f'{OnLabelledCollectionProtocol.__name__}, so no optimization can be carried out')
    return __prediction_helper(model.quantify, protocol, verbose)
|
||
|
|
||
|
|
||
|
def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
    """
    Apply `quantification_fn` to the instances of every sample produced by
    `protocol` and gather the results together with the true prevalences.

    :param quantification_fn: callable invoked as `quantification_fn(sample.instances)`
    :param protocol: callable object yielding samples; `protocol.total()` is only
        consulted (for the progress bar) when `verbose` is True
    :param verbose: if True, wraps the iteration in a tqdm progress bar
    :return: a tuple `(true_prevs, estim_prevs)` of numpy arrays built from the
        per-sample values, in generation order
    """
    sample_stream = protocol()
    if verbose:
        sample_stream = tqdm(sample_stream, total=protocol.total())

    observed, estimated = [], []
    for current in sample_stream:
        # the estimate is computed before the true prevalence is queried
        estimate = quantification_fn(current.instances)
        estimated.append(estimate)
        observed.append(current.prevalence())

    return np.asarray(observed), np.asarray(estimated)
|
||
|
|
||
|
|
||
|
def evaluation_report(model: BaseQuantifier,
                      protocol: AbstractProtocol,
                      error_metrics:Iterable[Union[str,Callable]]='mae',
                      verbose=False):
    """
    Evaluate `model` on the samples generated by `protocol` and return a
    per-sample report.

    :param model: the quantifier to evaluate
    :param protocol: the sample generation protocol
    :param error_metrics: a metric name, or an iterable of metric names and/or
        callables; forwarded unchanged to the report builder
    :param verbose: forwarded to :func:`prediction`
    :return: the pandas DataFrame produced by `_prevalence_report`
    """
    observed, estimated = prediction(model, protocol, verbose)
    report = _prevalence_report(observed, estimated, error_metrics)
    return report
|
||
|
|
||
|
|
||
|
def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[str, Callable]] = 'mae'):
    """
    Build a per-sample report of true prevalences, estimated prevalences, and
    the score of every requested error metric.

    :param true_prevs: iterable of true prevalence vectors, one per sample
    :param estim_prevs: iterable of estimated prevalence vectors, paired
        positionally with `true_prevs`
    :param error_metrics: a single metric name, or an iterable whose items are
        metric names (resolved through `qp.error.from_name`) and/or callables
        invoked as `metric(true_prev, estim_prev)`
    :return: a pandas DataFrame with columns `'true-prev'`, `'estim-prev'` and one
        column per metric (named after the callable's `__name__`), one row per
        sample; an empty DataFrame with those columns if there are no samples
    :raises AssertionError: if any resolved metric is not callable
    """
    if isinstance(error_metrics, str):
        error_metrics = [error_metrics]

    error_funcs = [qp.error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
    assert all(callable(e) for e in error_funcs), 'invalid error functions'
    error_names = [e.__name__ for e in error_funcs]

    columns = ['true-prev', 'estim-prev'] + error_names

    # Collect the rows first and build the frame once: DataFrame.append was
    # removed in pandas 2.0, and appending row by row copies the frame each time.
    rows = []
    for true_prev, estim_prev in zip(true_prevs, estim_prevs):
        row = {'true-prev': true_prev, 'estim-prev': estim_prev}
        for error_name, error_metric in zip(error_names, error_funcs):
            row[error_name] = error_metric(true_prev, estim_prev)
        rows.append(row)

    return pd.DataFrame(rows, columns=columns)
|
||
|
|
||
|
|
||
|
def evaluate(model: BaseQuantifier, protocol: AbstractProtocol, error_metric:Union[str, Callable], verbose=False):
    """
    Evaluate `model` on the samples generated by `protocol` according to a
    single error metric.

    :param model: the quantifier to evaluate
    :param protocol: the sample generation protocol
    :param error_metric: a metric name (resolved through `qp.error.from_name`) or
        a callable invoked as `error_metric(true_prevs, estim_prevs)` on the full
        arrays returned by :func:`prediction`
    :param verbose: forwarded to :func:`prediction`
    :return: whatever the error metric returns for the collected predictions
    """
    # resolve the metric before running the (potentially expensive) predictions,
    # so an unknown metric name fails early
    metric_fn = qp.error.from_name(error_metric) if isinstance(error_metric, str) else error_metric
    observed, estimated = prediction(model, protocol, verbose)
    return metric_fn(observed, estimated)
|
||
|
|
||
|
|
||
|
|
||
|
def _check_num_evals(n_classes, n_prevpoints=None, eval_budget=None, repeats=1, verbose=False):
    """
    Reconcile the number of prevalence points with an optional evaluation budget.

    At least one of `n_prevpoints` and `eval_budget` must be given:

    * only `eval_budget`: `n_prevpoints` is obtained from
      `F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)`;
    * only `n_prevpoints`: it is kept as is;
    * both: `n_prevpoints` is kept unless the number of evaluations it implies
      exceeds `eval_budget`, in which case it is replaced by the approximation
      computed from the budget.

    :param n_classes: number of classes, forwarded to the helpers in `F`
    :param n_prevpoints: requested number of prevalence points, or None
    :param eval_budget: maximum number of evaluations allowed, or None
    :param repeats: number of repetitions, forwarded to the helpers in `F`
    :param verbose: if True, prints the decision taken
    :return: a tuple `(n_prevpoints, eval_computations)` where `eval_computations`
        is `F.num_prevalence_combinations` evaluated for the returned `n_prevpoints`
    :raises ValueError: if both `n_prevpoints` and `eval_budget` are None
    :raises AssertionError: if only `eval_budget` is given and it is not positive
    """
    if n_prevpoints is None and eval_budget is None:
        raise ValueError('either n_prevpoints or eval_budget has to be specified')

    if n_prevpoints is None:
        assert eval_budget > 0, 'eval_budget must be a positive integer'
        n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
        eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
        if verbose:
            print(f'setting n_prevpoints={n_prevpoints} so that the number of '
                  f'evaluations ({eval_computations}) does not exceed the evaluation '
                  f'budget ({eval_budget})')
        return n_prevpoints, eval_computations

    eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)

    if eval_budget is None:
        if verbose:
            print(f'{eval_computations} evaluations will be performed for each '
                  f'combination of hyper-parameters')
    elif eval_computations > eval_budget:
        # remember the requested value so the message can report both the
        # rejected and the adopted number of prevalence points
        requested_prevpoints = n_prevpoints
        n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
        # the returned count must describe the adopted n_prevpoints, not the
        # rejected one that exceeded the budget
        eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
        if verbose:
            print(f'the budget of evaluations would be exceeded with '
                  f'n_prevpoints={requested_prevpoints}. Changing to n_prevpoints={n_prevpoints}. This will produce '
                  f'{eval_computations} evaluation computations for each hyper-parameter combination.')

    return n_prevpoints, eval_computations
|
||
|
|