from typing import Union, Callable, Iterable
from data import LabelledCollection
from method.aggregative import AggregativeQuantifier, AggregativeProbabilisticQuantifier
from method.base import BaseQuantifier
from util import temp_seed
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm
import error


def artificial_sampling_prediction(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=-1,
        random_seed=42,
        verbose=True
):
    """
    Performs the predictions for all samples generated according to the artificial sampling protocol.

    :param model: the model in charge of generating the class prevalence estimations
    :param test: the test set on which to perform artificial sampling
    :param sample_size: the size of the samples
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: allows to replicate the samplings. The seed is local to the method and does not affect
        any other random process.
    :param verbose: if True, shows a progress bar
    :return: two ndarrays of shape (m,n) with m the number of samples (n_prevpoints*n_repetitions) and n the
        number of classes. The first one contains the true prevalences for the samples generated while the
        second one contains the prevalence estimations
    """
    # Materialise every sampling index while the temporary seed is active, so
    # the set of samples is fixed before any (parallel) prediction starts.
    with temp_seed(random_seed):
        indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))

    if isinstance(model, AggregativeQuantifier):
        # Aggregative quantifiers work on classifier outputs: classify the whole
        # test set once here, and only run the aggregation step on each sample.
        quantification_func = model.aggregate
        if isinstance(model, AggregativeProbabilisticQuantifier):
            preclassified_instances = model.posterior_probabilities(test.instances)
        else:
            preclassified_instances = model.classify(test.instances)
        # From here on, the "instances" of `test` are the classifier outputs.
        test = LabelledCollection(preclassified_instances, test.labels)
    else:
        quantification_func = model.quantify

    def _predict_prevalences(index):
        # Returns the (true, estimated) prevalence pair for a single sample.
        sample = test.sampling_from_index(index)
        true_prevalence = sample.prevalence()
        estim_prevalence = quantification_func(sample.instances)
        return true_prevalence, estim_prevalence

    pbar = tqdm(indexes, desc='[artificial sampling protocol] predicting') if verbose else indexes
    results = Parallel(n_jobs=n_jobs)(
        delayed(_predict_prevalences)(index) for index in pbar
    )

    # Transpose the list of (true, estimated) pairs into two parallel sequences.
    true_prevalences, estim_prevalences = zip(*results)
    true_prevalences = np.asarray(true_prevalences)
    estim_prevalences = np.asarray(estim_prevalences)

    return true_prevalences, estim_prevalences


def evaluate(model: BaseQuantifier, test_samples: Iterable[LabelledCollection], err: Union[str, Callable],
             n_jobs: int = -1):
    """
    Evaluates a quantifier on a collection of test samples and returns the mean error.

    :param model: the quantifier to evaluate; its `quantify` method is invoked on the instances of each sample
    :param test_samples: an iterable of samples on which the model is evaluated
    :param err: the error function, or the name of an attribute of the `error` module to be used as the
        error function. In both cases, the `__name__` of the function must be listed in
        `error.QUANTIFICATION_ERROR_NAMES`
    :param n_jobs: number of jobs to be run in parallel
    :return: the mean (`np.mean`) of the error scores obtained across all samples
    :raises AttributeError: if `err` is a string that is not the name of an attribute of the `error` module
    :raises AssertionError: if the name of the error function is not in `error.QUANTIFICATION_ERROR_NAMES`
    """
    if isinstance(err, str):
        err = getattr(error, err)
    if err.__name__ not in error.QUANTIFICATION_ERROR_NAMES:
        # Raised explicitly instead of through an `assert` statement so that the
        # validation is not stripped when Python runs with -O; the exception
        # type is kept as AssertionError for backward compatibility.
        raise AssertionError(f'error={err} does not seem to be a quantification error')
    scores = Parallel(n_jobs=n_jobs)(
        delayed(_delayed_eval)(model, Ti, err) for Ti in test_samples
    )
    return np.mean(scores)


def _delayed_eval(model: BaseQuantifier, test: LabelledCollection, error: Callable):
    """
    Computes the error of a quantifier on one single sample (worker function dispatched by `evaluate`).

    :param model: the quantifier; `model.quantify` is called on `test.instances`
    :param test: the sample to evaluate on
    :param error: callable invoked as `error(true_prevalence, estimated_prevalence)`. Note that, inside this
        function, this parameter shadows the module-level `error` import
    :return: whatever `error` returns for this sample
    """
    prev_estim = model.quantify(test.instances)
    prev_true = test.prevalence()
    return error(prev_true, prev_estim)