1
0
Fork 0
QuaPy/quapy/evaluation.py

103 lines
4.2 KiB
Python
Raw Normal View History

2022-05-25 19:14:33 +02:00
from typing import Union, Callable, Iterable
import numpy as np
from tqdm import tqdm
import inspect
import quapy as qp
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier
from quapy.util import temp_seed
import quapy.functional as F
import pandas as pd
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='auto', verbose=False):
    """
    Uses a quantifier to estimate the prevalence of every sample generated by a protocol.

    When the aggregative speedup is applied, the instances of the protocol's underlying labelled collection are
    classified only once. Each sample is then obtained from those pre-classified instances and only aggregated.
    This avoids classifying the same instance many times when the protocol would otherwise process more instances
    than the collection contains.

    :param model: the quantifier; `model.quantify` is used in the standard path, while `model.classify` and
        `model.aggregate` are used when the aggregative speedup is applied
    :param protocol: the protocol generating the test samples (it is called to obtain the samples)
    :param aggr_speedup: one of False, True, 'auto', 'force'.
        - False never applies the speedup.
        - True and 'auto' apply it only when the model is an `AggregativeQuantifier`, the protocol is an
          `OnLabelledCollectionProtocol` exposing `sample_size`, and `protocol.total() * protocol.sample_size`
          exceeds the size of the labelled collection.
        - 'force' applies it whenever the model/protocol types allow it.
    :param verbose: if True, prints a message when the speedup is applied and shows a progress bar
    :return: a tuple `(true_prevs, estim_prevs)` of numpy arrays, one entry per generated sample
    :raises AssertionError: if `aggr_speedup` is not one of the accepted values
    """
    assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'

    def sout(msg):
        # prints only in verbose mode
        if verbose:
            print(msg)

    apply_optimization = False

    if aggr_speedup in [True, 'auto', 'force']:
        # checks whether the prediction can be made more efficiently; this check consists in verifying if the model is
        # of type aggregative, if the protocol is based on LabelledCollection, and if the total number of documents to
        # classify using the protocol would exceed the number of test documents in the original collection.
        # The import must be absolute: `from method.aggregative import ...` only resolves when a top-level `method`
        # package happens to be importable (e.g., when running from inside the quapy source directory).
        # It is kept local to the function, presumably to avoid a circular import with quapy.method.
        from quapy.method.aggregative import AggregativeQuantifier
        if isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol):
            if aggr_speedup == 'force':
                apply_optimization = True
                sout('forcing aggregative speedup')
            elif hasattr(protocol, 'sample_size'):
                nD = len(protocol.get_labelled_collection())
                samplesD = protocol.total() * protocol.sample_size
                if nD < samplesD:
                    apply_optimization = True
                    sout(f'speeding up the prediction for the aggregative quantifier, '
                         f'total classifications {nD} instead of {samplesD}')

    if apply_optimization:
        # classify every instance of the underlying collection once; `on_preclassified_instances` presumably returns
        # an equivalent protocol whose samples carry the classifier outputs, so each sample only needs aggregation
        pre_classified = model.classify(protocol.get_labelled_collection().instances)
        protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
        return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)

    return __prediction_helper(model.quantify, protocol, verbose)


def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
    """
    Applies a quantification function to every sample generated by a protocol.

    :param quantification_fn: callable receiving the `instances` of a sample and returning its estimated prevalence
    :param protocol: the protocol; calling it yields samples exposing `instances` and `prevalence()`
    :param verbose: if True, the samples are iterated through a tqdm progress bar of length `protocol.total()`
    :return: a tuple `(true_prevs, estim_prevs)` of numpy arrays, in the order the samples were generated
    """
    samples = protocol()
    if verbose:
        samples = tqdm(samples, total=protocol.total())

    estimated, actual = [], []
    for sample in samples:
        # the estimate is computed before the true prevalence is read
        estimated.append(quantification_fn(sample.instances))
        actual.append(sample.prevalence())

    return np.asarray(actual), np.asarray(estimated)


def evaluation_report(model: BaseQuantifier,
                      protocol: AbstractProtocol,
                      error_metrics: Iterable[Union[str, Callable]] = 'mae',
                      aggr_speedup='auto',
                      verbose=False):
    """
    Evaluates a quantifier on the samples of a protocol and returns a per-sample report.

    :param model: the quantifier to evaluate
    :param protocol: the protocol generating the test samples
    :param error_metrics: a metric name, or an iterable of metric names and/or callables
        `f(true_prev, estim_prev)`; names are resolved through `qp.error.from_name`
    :param aggr_speedup: forwarded to `prediction` (one of False, True, 'auto', 'force')
    :param verbose: forwarded to `prediction`
    :return: a pandas DataFrame with one row per sample and the columns 'true-prev', 'estim-prev', plus one column
        per error metric
    """
    predicted = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
    true_prevs, estim_prevs = predicted
    return _prevalence_report(true_prevs, estim_prevs, error_metrics)


def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[str, Callable]] = 'mae'):
if isinstance(error_metrics, str):
error_metrics = [error_metrics]
error_funcs = [qp.error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'
error_names = [e.__name__ for e in error_funcs]
df = pd.DataFrame(columns=['true-prev', 'estim-prev'] + error_names)
for true_prev, estim_prev in zip(true_prevs, estim_prevs):
series = {'true-prev': true_prev, 'estim-prev': estim_prev}
for error_name, error_metric in zip(error_names, error_funcs):
score = error_metric(true_prev, estim_prev)
series[error_name] = score
df = df.append(series, ignore_index=True)
return df
def evaluate(
        model: BaseQuantifier,
        protocol: AbstractProtocol,
        error_metric: Union[str, Callable],
        aggr_speedup='auto',
        verbose=False):
    """
    Evaluates a quantifier on the samples of a protocol and returns a single error score.

    :param model: the quantifier to evaluate
    :param protocol: the protocol generating the test samples
    :param error_metric: a metric name (resolved through `qp.error.from_name`) or a callable
        `f(true_prevs, estim_prevs)`
    :param aggr_speedup: forwarded to `prediction` (one of False, True, 'auto', 'force')
    :param verbose: forwarded to `prediction`
    :return: the value returned by the metric when called once on the arrays of all true and estimated
        prevalences (presumably an average error across samples — depends on the metric)
    """
    # the metric is resolved before running any prediction, so an unknown name fails fast
    metric_fn = qp.error.from_name(error_metric) if isinstance(error_metric, str) else error_metric
    true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
    return metric_fn(true_prevs, estim_prevs)