QuaPy/quapy/evaluation.py

188 lines
10 KiB
Python
Raw Normal View History

2022-05-25 19:14:33 +02:00
from typing import Union, Callable, Iterable
import numpy as np
from tqdm import tqdm
import quapy as qp
2023-02-13 19:27:48 +01:00
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol, IterateProtocol
2022-05-25 19:14:33 +02:00
from quapy.method.base import BaseQuantifier
import pandas as pd
2023-02-14 11:14:38 +01:00
def prediction(
        model: BaseQuantifier,
        protocol: AbstractProtocol,
        aggr_speedup: Union[str, bool] = 'auto',
        verbose=False):
    """
    Runs a quantifier over every sample issued by a protocol and collects the true and the estimated prevalence
    values. All the evaluation routines rely on this function.

    For aggregative quantifiers evaluated through an OnLabelledCollectionProtocol, the function can avoid
    classifying the same instances repeatedly: the whole underlying collection is classified once, the samples are
    then drawn over those classification predictions, and only the aggregation step of the quantifier is run on each
    sample. With `aggr_speedup` set to True or 'auto', this shortcut is taken only when the protocol exposes a
    `sample_size` attribute and the collection holds fewer instances than the total number of instances across all
    the samples (i.e., when it actually saves classifier calls). With 'force', it is taken regardless of that
    comparison.

    :param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
    :param protocol: :class:`quapy.protocol.AbstractProtocol` in charge of generating the samples for which the
        model has to issue class prevalence predictions; the speed-up is only considered if it is also an instance
        of :class:`quapy.protocol.OnLabelledCollectionProtocol`
    :param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
        instances in the original collection on which the protocol acts is larger than the number of instances
        in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
        convenient or not. Set to False to deactivate.
    :param verbose: boolean, show or not information in stdout
    :return: a tuple `(true_prevs, estim_prevs)` in which each element in the tuple is an array of shape
        `(n_samples, n_classes)` containing the true, or predicted, prevalence values for each sample
    """
    assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'

    def log(message):
        # messages are only shown when the caller asked for verbose output
        if verbose:
            print(message)

    use_speedup = False
    if aggr_speedup in [True, 'auto', 'force']:
        # imported lazily, and only when the speed-up is actually being considered
        from quapy.method.aggregative import AggregativeQuantifier

        # the shortcut requires an aggregative model and a protocol that samples from a labelled collection
        eligible = isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol)
        if eligible and aggr_speedup == 'force':
            use_speedup = True
            log('forcing aggregative speedup')
        elif eligible and hasattr(protocol, 'sample_size'):
            # compare the cost of classifying the collection once against classifying every sample
            collection_size = len(protocol.get_labelled_collection())
            sampled_size = protocol.total() * protocol.sample_size
            if collection_size < sampled_size:
                use_speedup = True
                log(f'speeding up the prediction for the aggregative quantifier, '
                    f'total classifications {collection_size} instead of {sampled_size}')

    if not use_speedup:
        return __prediction_helper(model.quantify, protocol, verbose)

    # classify the whole collection once; the samples are then drawn over these predictions
    classified = model.classify(protocol.get_labelled_collection().instances)
    preclassified_protocol = protocol.on_preclassified_instances(classified)
    return __prediction_helper(model.aggregate, preclassified_protocol, verbose)
def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=False):
    """
    Applies a quantification function to every sample generated by the protocol.

    :param quantification_fn: callable invoked on the instances of each sample; its output is stored as the
        estimated prevalence of that sample
    :param protocol: the protocol to iterate over; calling it must yield pairs `(sample_instances, sample_prev)`
    :param verbose: boolean, if True a tqdm progress bar is displayed while iterating
    :return: a tuple `(true_prevs, estim_prevs)` of numpy arrays, one entry per sample
    """
    samples = protocol()
    if verbose:
        # wrap the sample generator in a progress bar
        samples = tqdm(samples, total=protocol.total(), desc='predicting')

    expected, predicted = [], []
    for instances, prevalence in samples:
        predicted.append(quantification_fn(instances))
        expected.append(prevalence)

    return np.asarray(expected), np.asarray(predicted)
def evaluation_report(model: BaseQuantifier,
                      protocol: AbstractProtocol,
                      error_metrics: Iterable[Union[str, Callable]] = 'mae',
                      aggr_speedup: Union[str, bool] = 'auto',
                      verbose=False):
    """
    Evaluates a model on the samples generated by a protocol and returns the outcome as a pandas' DataFrame with
    one row per sample, reporting one or more evaluation metrics (errors).

    :param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
    :param protocol: :class:`quapy.protocol.AbstractProtocol` in charge of generating the samples in which the
        model is evaluated; if this object is also instance of
        :class:`quapy.protocol.OnLabelledCollectionProtocol`, then the aggregation speed-up can be run
    :param error_metrics: a string, or list of strings, representing the name(s) of an error function in `qp.error`
        (e.g., 'mae', the default value), or a callable function, or a list of callable functions, implementing
        the error function itself.
    :param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
        instances in the original collection on which the protocol acts is larger than the number of instances
        in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
        convenient or not. Set to False to deactivate.
    :param verbose: boolean, show or not information in stdout
    :return: a pandas' DataFrame containing the columns 'true-prev' (the true prevalence of each sample),
        'estim-prev' (the prevalence estimated by the model for each sample), and as many columns as error metrics
        have been indicated, each displaying the score in terms of that metric for every sample.
    """
    # first obtain the prevalence values, then tabulate them together with the requested errors
    expected, predicted = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
    report = _prevalence_report(expected, predicted, error_metrics)
    return report
def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[str, Callable]] = 'mae'):
if isinstance(error_metrics, str):
error_metrics = [error_metrics]
error_funcs = [qp.error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'
error_names = [e.__name__ for e in error_funcs]
df = pd.DataFrame(columns=['true-prev', 'estim-prev'] + error_names)
for true_prev, estim_prev in zip(true_prevs, estim_prevs):
series = {'true-prev': true_prev, 'estim-prev': estim_prev}
for error_name, error_metric in zip(error_names, error_funcs):
score = error_metric(true_prev, estim_prev)
series[error_name] = score
df = df.append(series, ignore_index=True)
return df
def evaluate(
        model: BaseQuantifier,
        protocol: AbstractProtocol,
        error_metric: Union[str, Callable],
        aggr_speedup: Union[str, bool] = 'auto',
        verbose=False):
    """
    Computes the error of a quantification model, in terms of a single evaluation metric, over the samples
    generated by a protocol.

    :param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
    :param protocol: :class:`quapy.protocol.AbstractProtocol` in charge of generating the samples in which the
        model is evaluated; if this object is also instance of
        :class:`quapy.protocol.OnLabelledCollectionProtocol`, then the aggregation speed-up can be run
    :param error_metric: a string representing the name(s) of an error function in `qp.error`
        (e.g., 'mae'), or a callable function implementing the error function itself.
    :param aggr_speedup: whether or not to apply the speed-up. Set to "force" for applying it even if the number of
        instances in the original collection on which the protocol acts is larger than the number of instances
        in the samples to be generated. Set to True or "auto" (default) for letting QuaPy decide whether it is
        convenient or not. Set to False to deactivate.
    :param verbose: boolean, show or not information in stdout
    :return: if the error metric is not averaged (e.g., 'ae', 'rae'), returns an array of shape `(n_samples,)` with
        the error scores for each sample; if the error metric is averaged (e.g., 'mae', 'mrae') then returns
        a single float
    """
    # the metric is resolved before any prediction is made, so that an unknown name fails early
    metric_fn = qp.error.from_name(error_metric) if isinstance(error_metric, str) else error_metric
    expected, predicted = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
    return metric_fn(expected, predicted)
2023-02-13 19:27:48 +01:00
def evaluate_on_samples(
        model: BaseQuantifier,
        samples: Iterable[qp.data.LabelledCollection],
        error_metric: Union[str, Callable],
        verbose=False):
    """
    Computes the error of a quantification model, in terms of a single evaluation metric, over an explicit set of
    samples. The samples are wrapped in an :class:`quapy.protocol.IterateProtocol` and the aggregation speed-up is
    deactivated.

    :param model: a quantifier, instance of :class:`quapy.method.base.BaseQuantifier`
    :param samples: a list of samples on which the quantifier is to be evaluated
    :param error_metric: a string representing the name(s) of an error function in `qp.error`
        (e.g., 'mae'), or a callable function implementing the error function itself.
    :param verbose: boolean, show or not information in stdout
    :return: if the error metric is not averaged (e.g., 'ae', 'rae'), returns an array of shape `(n_samples,)` with
        the error scores for each sample; if the error metric is averaged (e.g., 'mae', 'mrae') then returns
        a single float
    """
    samples_protocol = IterateProtocol(samples)
    return evaluate(model, samples_protocol, error_metric, aggr_speedup=False, verbose=verbose)
2022-05-25 19:14:33 +02:00