1
0
Fork 0
QuaPy/quapy/evaluation.py

90 lines
3.6 KiB
Python
Raw Normal View History

from typing import Union, Callable, Iterable
2021-01-15 18:32:32 +01:00
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm
2021-01-15 18:32:32 +01:00
import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.base import BaseQuantifier
from quapy.util import temp_seed
def artificial_sampling_prediction(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=-1,
        random_seed=42,
        verbose=True
):
    """
    Produces the true and estimated prevalences for every sample drawn from `test` following the artificial
    sampling protocol.

    :param model: the model in charge of generating the class prevalence estimations
    :param test: the test set on which to perform artificial sampling
    :param sample_size: the size of the samples
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel (forwarded to joblib's Parallel)
    :param random_seed: allows to replicate the samplings. The seed is local to the method and does not affect
        any other random process.
    :param verbose: if True, shows a progress bar
    :return: two ndarrays of shape (m,n) with m the number of samples (n_prevpoints*n_repetitions) and n the
        number of classes. The first one contains the true prevalences for the samples generated while the
        second one contains the prevalence estimations
    """
    # The sample indexes are materialized inside the seeded context, so only the sampling step
    # is driven by `random_seed`.
    with temp_seed(random_seed):
        sample_indexes = list(
            test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions)
        )

    aggregative = qp.method.aggregative
    if not isinstance(model, aggregative.AggregativeQuantifier):
        # Non-aggregative quantifier: each sample is quantified directly from its instances.
        predict = model.quantify
    else:
        # Aggregative quantifier: the whole test set is classified once, and every sample is then
        # drawn from the classifier outputs, so that only the aggregation step runs per sample.
        predict = model.aggregate
        if isinstance(model, aggregative.AggregativeProbabilisticQuantifier):
            classifier_output = model.posterior_probabilities(test.instances)
        else:
            classifier_output = model.classify(test.instances)
        test = LabelledCollection(classifier_output, test.labels)

    def _true_and_estimated(sample_index):
        # Returns the pair (true prevalence, estimated prevalence) for one sample.
        drawn = test.sampling_from_index(sample_index)
        observed = drawn.prevalence()
        predicted = predict(drawn.instances)
        return observed, predicted

    if verbose:
        iterable = tqdm(sample_indexes, desc='[artificial sampling protocol] predicting')
    else:
        iterable = sample_indexes

    pairs = Parallel(n_jobs=n_jobs)(
        delayed(_true_and_estimated)(sample_index) for sample_index in iterable
    )

    # Transpose the list of (true, estimated) pairs into two parallel sequences.
    true_column, estim_column = zip(*pairs)
    return np.asarray(true_column), np.asarray(estim_column)
def evaluate(model: BaseQuantifier, test_samples:Iterable[LabelledCollection], err:Union[str, Callable], n_jobs:int=-1):
    """
    Computes the mean error obtained by a quantifier across a series of test samples.

    :param model: the quantifier to be evaluated on each sample
    :param test_samples: an iterable with the samples to evaluate on
    :param err: the error measure, either as a callable or as the name of an attribute of `qp.error`.
        Its `__name__` has to be listed in `qp.error.QUANTIFICATION_ERROR_NAMES`
    :param n_jobs: number of jobs to be run in parallel (forwarded to joblib's Parallel)
    :return: the mean (as computed by `np.mean`) of the scores obtained for the samples
    """
    # A string is taken to be the name of an error function defined in qp.error.
    error_fn = getattr(qp.error, err) if isinstance(err, str) else err
    assert error_fn.__name__ in qp.error.QUANTIFICATION_ERROR_NAMES, \
        f'error={error_fn} does not seem to be a quantification error'

    runner = Parallel(n_jobs=n_jobs)
    jobs = (delayed(_delayed_eval)(model, sample, error_fn) for sample in test_samples)
    per_sample_scores = runner(jobs)
    return np.mean(per_sample_scores)
def _delayed_eval(model:BaseQuantifier, test:LabelledCollection, error:Callable):
    """
    Scores the model on one sample: the prevalence estimated by the model for the instances in `test` is
    compared against the prevalence of `test` by means of `error`.

    :param model: the quantifier, whose `quantify` method is invoked on `test.instances`
    :param test: the sample to evaluate on
    :param error: a callable invoked as `error(true_prevalence, estimated_prevalence)`
    :return: whatever `error` returns for this sample
    """
    estimated = model.quantify(test.instances)
    return error(test.prevalence(), estimated)