2021-01-06 14:58:29 +01:00
|
|
|
from typing import Union, Callable, Iterable
|
2021-01-15 18:32:32 +01:00
|
|
|
|
2020-12-10 19:04:33 +01:00
|
|
|
import numpy as np
|
|
|
|
from joblib import Parallel, delayed
|
|
|
|
from tqdm import tqdm
|
2021-01-15 18:32:32 +01:00
|
|
|
|
|
|
|
import quapy as qp
|
|
|
|
from quapy.data import LabelledCollection
|
|
|
|
from quapy.method.base import BaseQuantifier
|
|
|
|
from quapy.util import temp_seed
|
2021-01-18 16:52:19 +01:00
|
|
|
import quapy.functional as F
|
2021-01-28 18:22:43 +01:00
|
|
|
import pandas as pd
|
2020-12-10 19:04:33 +01:00
|
|
|
|
|
|
|
def artificial_sampling_prediction(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=1,
        random_seed=42,
        verbose=True
):
    """
    Generates samples from `test` following the artificial sampling protocol and returns, for each sample,
    its true class prevalence together with the prevalence estimated by `model`.

    :param model: the quantifier that produces the class prevalence estimations
    :param test: the test collection from which the artificial samples are drawn
    :param sample_size: the size of each sample
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: seed used while generating the sampling indexes, so that the samplings can be
        replicated. The seed is applied only while the indexes are drawn (via `temp_seed`).
    :param verbose: if True, shows a progress bar while predicting
    :return: a pair of ndarrays (true prevalences, estimated prevalences), each with one row per
        generated sample.
    """
    # All sampling indexes are drawn up front, inside the temporary-seed context.
    with temp_seed(random_seed):
        sample_indexes = list(
            test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions)
        )

    if not model.aggregative:
        # Non-aggregative quantifiers work directly on the raw instances of each sample.
        predict_prevalence = model.quantify
        pool = test
    else:
        # Aggregative quantifiers: classify the whole test set once, then draw the samples from the
        # classifier outputs so that only the aggregation step is run per sample.
        predict_prevalence = model.aggregate
        classify_all = model.posterior_probabilities if model.probabilistic else model.classify
        pool = LabelledCollection(classify_all(test.instances), test.labels)

    def _true_and_estimated(index):
        drawn = pool.sampling_from_index(index)
        observed = drawn.prevalence()
        estimated = predict_prevalence(drawn.instances)
        return observed, estimated

    if verbose:
        work_items = tqdm(sample_indexes, desc='[artificial sampling protocol] predicting')
    else:
        work_items = sample_indexes
    outcomes = qp.util.parallel(_true_and_estimated, work_items, n_jobs=n_jobs)

    # Transpose the list of (true, estimated) pairs into two parallel sequences.
    observed_all, estimated_all = zip(*outcomes)
    return np.asarray(observed_all), np.asarray(estimated_all)
|
|
|
|
|
|
|
|
|
2021-01-28 18:22:43 +01:00
|
|
|
def artificial_sampling_report(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=1,
        random_seed=42,
        error_metrics:Iterable[Union[str,Callable]]='mae',
        verbose=True):
    """
    Runs the artificial sampling protocol and reports, for every generated sample, the true prevalence,
    the estimated prevalence, and the score of each requested error metric.

    :param model: the quantifier that produces the class prevalence estimations
    :param test: the test collection from which the artificial samples are drawn
    :param sample_size: the size of each sample
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: seed forwarded to `artificial_sampling_prediction`
    :param error_metrics: a metric name, a callable taking (true_prev, estim_prev), or an iterable mixing
        both. Names are resolved through `qp.error.from_name`; callables are reported under their
        `__name__`.
    :param verbose: if True, shows a progress bar while predicting
    :return: a pandas DataFrame with one row per sample and columns 'true-prev', 'estim-prev' followed by
        one column per error metric
    """
    if isinstance(error_metrics, str) or callable(error_metrics):
        # A single metric (by name or as a function) is treated as a one-element list.
        error_metrics = [error_metrics]
    else:
        # Materialize once: the metrics are traversed twice below (names, then functions), which would
        # silently exhaust a generator after the first pass.
        error_metrics = list(error_metrics)

    error_names = [e if isinstance(e, str) else e.__name__ for e in error_metrics]
    error_funcs = [qp.error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
    assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'

    true_prevs, estim_prevs = artificial_sampling_prediction(
        model, test, sample_size, n_prevpoints, n_repetitions, n_jobs, random_seed, verbose
    )

    # Collect the rows in a plain list and build the frame once; DataFrame.append was removed in
    # pandas 2.0 and copied the whole frame on every call.
    rows = []
    for true_prev, estim_prev in zip(true_prevs, estim_prevs):
        row = {'true-prev': true_prev, 'estim-prev': estim_prev}
        for error_name, error_metric in zip(error_names, error_funcs):
            row[error_name] = error_metric(true_prev, estim_prev)
        rows.append(row)

    return pd.DataFrame(rows, columns=['true-prev', 'estim-prev'] + error_names)
|
|
|
|
|
|
|
|
|
|
|
|
def artificial_sampling_eval(
        model: BaseQuantifier,
        test: LabelledCollection,
        sample_size,
        n_prevpoints=210,
        n_repetitions=1,
        n_jobs=1,
        random_seed=42,
        error_metric:Union[str,Callable]='mae',
        verbose=True):
    """
    Runs the artificial sampling protocol and scores all the resulting predictions with a single
    error metric.

    :param model: the quantifier that produces the class prevalence estimations
    :param test: the test collection from which the artificial samples are drawn
    :param sample_size: the size of each sample
    :param n_prevpoints: the number of different prevalences to sample
    :param n_repetitions: the number of repetitions for each prevalence
    :param n_jobs: number of jobs to be run in parallel
    :param random_seed: seed forwarded to `artificial_sampling_prediction`
    :param error_metric: a metric name (resolved through `qp.error.from_name`) or a callable that takes
        the true and the estimated prevalences
    :param verbose: if True, shows a progress bar while predicting
    :return: whatever the error metric returns when applied to the full arrays of true and estimated
        prevalences
    """
    # Resolve a metric name into the corresponding function; callables are used as given.
    score_fn = qp.error.from_name(error_metric) if isinstance(error_metric, str) else error_metric
    assert hasattr(score_fn, '__call__'), 'invalid error function'

    observed, estimated = artificial_sampling_prediction(
        model,
        test,
        sample_size,
        n_prevpoints,
        n_repetitions,
        n_jobs,
        random_seed,
        verbose,
    )
    return score_fn(observed, estimated)
|
|
|
|
|
|
|
|
|
2021-01-06 14:58:29 +01:00
|
|
|
def evaluate(model: BaseQuantifier, test_samples:Iterable[LabelledCollection], err:Union[str, Callable], n_jobs:int=-1):
    """
    Scores `model` on each of the given test samples and returns the mean score.

    :param model: the quantifier to evaluate
    :param test_samples: the labelled collections on which the quantifier is evaluated, one score each
    :param err: a metric name (resolved through `qp.error.from_name`) or a callable that takes the true
        and the estimated prevalences
    :param n_jobs: number of jobs to be run in parallel
    :return: the mean (via `np.mean`) of the per-sample scores
    """
    error_fn = qp.error.from_name(err) if isinstance(err, str) else err
    # One (model, sample, metric) task per sample, produced lazily for the parallel runner.
    tasks = ((model, sample, error_fn) for sample in test_samples)
    per_sample_scores = qp.util.parallel(_delayed_eval, tasks, n_jobs=n_jobs)
    return np.mean(per_sample_scores)
|
|
|
|
|
2020-12-10 19:04:33 +01:00
|
|
|
|
2021-01-27 09:54:41 +01:00
|
|
|
def _delayed_eval(args):
    """
    Worker that scores a quantifier on one sample; takes a single tuple so it can be mapped over tasks.

    :param args: a tuple (model, test, error), where `model` exposes `quantify`, `test` exposes
        `instances` and `prevalence()`, and `error` is a callable taking (true, estimated) prevalences
    :return: the value returned by `error` for this sample
    """
    quantifier, sample, error_fn = args
    estimated = quantifier.quantify(sample.instances)
    observed = sample.prevalence()
    return error_fn(observed, estimated)
|
2020-12-10 19:04:33 +01:00
|
|
|
|