QuaPy/quapy/evaluation.py

85 lines
3.4 KiB
Python
Raw Normal View History

from typing import Union, Callable, Iterable
from data import LabelledCollection
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
from method.base import BaseQuantifier
from util import temp_seed
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm
import error
def artificial_sampling_prediction(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=-1,
        random_seed=42,
        verbose=True
):
    """
    Runs the artificial sampling protocol on a test collection and gathers, for every generated sample, both
    its true class prevalence and the prevalence estimated by the model.

    Aggregative quantifiers are treated specially: the whole test collection is classified once up front, and
    the samples are then drawn from those classifier outputs, so that only the aggregation step is run per
    sample. Any other quantifier is asked to quantify the raw instances of each sample.

    :param model: the quantifier in charge of generating the class prevalence estimations
    :param test: the test collection on which to perform artificial sampling
    :param sample_size: the size of the samples
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: allows to replicate the samplings; the seed is only in effect while the sample
        indexes are being generated
    :param verbose: if True, shows a progress bar
    :return: two ndarrays of shape (m,n) with m the number of samples (n_prevpoints*n_repetitions) and n the
        number of classes. The first one contains the true prevalences of the generated samples, while the
        second one contains the prevalence estimations
    """
    # Materialise every sample index while the temporary seed is active, so the
    # set of samples is reproducible regardless of what happens afterwards.
    with temp_seed(random_seed):
        sample_indexes = list(
            test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions)
        )

    if not isinstance(model, AggregativeQuantifier):
        estimate = model.quantify
    else:
        estimate = model.aggregate
        # Classify the full test collection once; samples are later drawn from
        # these outputs instead of from the original instances.
        if isinstance(model, AggregativeProbabilisticQuantifier):
            classifier_outputs = model.posterior_probabilities(test.instances)
        else:
            classifier_outputs = model.classify(test.instances)
        test = LabelledCollection(classifier_outputs, test.labels)

    def _true_and_estimated(index):
        # Returns the (true, estimated) prevalence pair for one sample index.
        drawn = test.sampling_from_index(index)
        return drawn.prevalence(), estimate(drawn.instances)

    if verbose:
        iterable = tqdm(sample_indexes, desc='[artificial sampling protocol] predicting')
    else:
        iterable = sample_indexes

    jobs = (delayed(_true_and_estimated)(index) for index in iterable)
    results = Parallel(n_jobs=n_jobs)(jobs)

    # Transpose the list of (true, estimated) pairs into two parallel sequences.
    trues, estims = zip(*results)
    return np.asarray(trues), np.asarray(estims)
def evaluate(
        model: BaseQuantifier,
        test_samples: Iterable[LabelledCollection],
        err: Union[str, Callable],
        n_jobs: int = -1
):
    """
    Computes the mean of a quantification error obtained by a model across a series of test samples.

    :param model: the quantifier whose ``quantify`` method is applied to the instances of each sample
    :param test_samples: an iterable of labelled collections, each one being evaluated independently
    :param err: the error measure, either as a callable taking (true prevalence, estimated prevalence) or
        as the name of an attribute of the ``error`` module
    :param n_jobs: number of jobs to be run in parallel
    :return: the result of ``np.mean`` over the per-sample error scores
    """
    # A name is resolved to the homonymous attribute of the error module.
    err = getattr(error, err) if isinstance(err, str) else err
    assert err.__name__ in error.QUANTIFICATION_ERROR_NAMES, \
        f'error={err} does not seem to be a quantification error'

    jobs = (delayed(_delayed_eval)(model, sample, err) for sample in test_samples)
    return np.mean(Parallel(n_jobs=n_jobs)(jobs))
def _delayed_eval(model: BaseQuantifier, test: LabelledCollection, error: Callable):
    """
    Scores a model on a single sample: the model quantifies the sample's instances, and the given error
    function is applied to the sample's true prevalence and that estimation (in this order).

    :param model: the quantifier to evaluate
    :param test: the sample on which the model is evaluated
    :param error: callable invoked as error(true_prevalence, estimated_prevalence)
    :return: whatever the error callable returns
    """
    estimated = model.quantify(test.instances)
    return error(test.prevalence(), estimated)