adding natural sampling protocol
This commit is contained in:
parent
3d544135f1
commit
731b54c5ba
|
@ -12,6 +12,7 @@ import quapy.functional as F
|
|||
import pandas as pd
|
||||
|
||||
|
||||
|
||||
def artificial_sampling_prediction(
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
|
@ -21,8 +22,7 @@ def artificial_sampling_prediction(
|
|||
eval_budget: int = None,
|
||||
n_jobs=1,
|
||||
random_seed=42,
|
||||
verbose=True
|
||||
):
|
||||
verbose=False):
|
||||
"""
|
||||
Performs the predictions for all samples generated according to the artificial sampling protocol.
|
||||
:param model: the model in charge of generating the class prevalence estimations
|
||||
|
@ -48,6 +48,45 @@ def artificial_sampling_prediction(
|
|||
with temp_seed(random_seed):
|
||||
indexes = list(test.artificial_sampling_index_generator(sample_size, n_prevpoints, n_repetitions))
|
||||
|
||||
return _predict_from_indexes(indexes, model, test, n_jobs, verbose)
|
||||
|
||||
|
||||
def natural_sampling_prediction(
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
sample_size,
|
||||
n_repetitions=1,
|
||||
n_jobs=1,
|
||||
random_seed=42,
|
||||
verbose=False):
|
||||
"""
|
||||
Performs the predictions for all samples generated according to the artificial sampling protocol.
|
||||
:param model: the model in charge of generating the class prevalence estimations
|
||||
:param test: the test set on which to perform arificial sampling
|
||||
:param sample_size: the size of the samples
|
||||
:param n_repetitions: the number of repetitions for each prevalence
|
||||
:param n_jobs: number of jobs to be run in parallel
|
||||
:param random_seed: allows to replicate the samplings. The seed is local to the method and does not affect
|
||||
any other random process.
|
||||
:param verbose: if True, shows a progress bar
|
||||
:return: two ndarrays of shape (m,n) with m the number of samples (n_repetitions) and n the
|
||||
number of classes. The first one contains the true prevalences for the samples generated while the second one
|
||||
contains the the prevalence estimations
|
||||
"""
|
||||
|
||||
with temp_seed(random_seed):
|
||||
indexes = list(test.natural_sampling_index_generator(sample_size, n_repetitions))
|
||||
|
||||
return _predict_from_indexes(indexes, model, test, n_jobs, verbose)
|
||||
|
||||
|
||||
def _predict_from_indexes(
|
||||
indexes,
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
n_jobs=1,
|
||||
verbose=False):
|
||||
|
||||
if model.aggregative: #isinstance(model, qp.method.aggregative.AggregativeQuantifier):
|
||||
# print('\tinstance of aggregative-quantifier')
|
||||
quantification_func = model.aggregate
|
||||
|
@ -88,19 +127,43 @@ def artificial_sampling_report(
|
|||
n_jobs=1,
|
||||
random_seed=42,
|
||||
error_metrics:Iterable[Union[str,Callable]]='mae',
|
||||
verbose=True):
|
||||
verbose=False):
|
||||
|
||||
true_prevs, estim_prevs = artificial_sampling_prediction(
|
||||
model, test, sample_size, n_prevpoints, n_repetitions, eval_budget, n_jobs, random_seed, verbose
|
||||
)
|
||||
return __sampling_report(true_prevs, estim_prevs, error_metrics)
|
||||
|
||||
|
||||
def natural_sampling_report(
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
sample_size,
|
||||
n_repetitions=1,
|
||||
n_jobs=1,
|
||||
random_seed=42,
|
||||
error_metrics:Iterable[Union[str,Callable]]='mae',
|
||||
verbose=False):
|
||||
|
||||
true_prevs, estim_prevs = natural_sampling_prediction(
|
||||
model, test, sample_size, n_repetitions, n_jobs, random_seed, verbose
|
||||
)
|
||||
return __sampling_report(true_prevs, estim_prevs, error_metrics)
|
||||
|
||||
|
||||
def __sampling_report(
|
||||
true_prevs,
|
||||
estim_prevs,
|
||||
error_metrics: Iterable[Union[str, Callable]] = 'mae'):
|
||||
|
||||
if isinstance(error_metrics, str):
|
||||
error_metrics=[error_metrics]
|
||||
error_metrics = [error_metrics]
|
||||
|
||||
error_names = [e if isinstance(e, str) else e.__name__ for e in error_metrics]
|
||||
error_funcs = [qp.error.from_name(e) if isinstance(e, str) else e for e in error_metrics]
|
||||
assert all(hasattr(e, '__call__') for e in error_funcs), 'invalid error functions'
|
||||
|
||||
df = pd.DataFrame(columns=['true-prev', 'estim-prev']+error_names)
|
||||
true_prevs, estim_prevs = artificial_sampling_prediction(
|
||||
model, test, sample_size, n_prevpoints, n_repetitions, eval_budget, n_jobs, random_seed, verbose
|
||||
)
|
||||
df = pd.DataFrame(columns=['true-prev', 'estim-prev'] + error_names)
|
||||
for true_prev, estim_prev in zip(true_prevs, estim_prevs):
|
||||
series = {'true-prev': true_prev, 'estim-prev': estim_prev}
|
||||
for error_name, error_metric in zip(error_names, error_funcs):
|
||||
|
@ -110,7 +173,6 @@ def artificial_sampling_report(
|
|||
|
||||
return df
|
||||
|
||||
|
||||
def artificial_sampling_eval(
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
|
@ -121,7 +183,7 @@ def artificial_sampling_eval(
|
|||
n_jobs=1,
|
||||
random_seed=42,
|
||||
error_metric:Union[str,Callable]='mae',
|
||||
verbose=True):
|
||||
verbose=False):
|
||||
|
||||
if isinstance(error_metric, str):
|
||||
error_metric = qp.error.from_name(error_metric)
|
||||
|
@ -135,6 +197,28 @@ def artificial_sampling_eval(
|
|||
return error_metric(true_prevs, estim_prevs)
|
||||
|
||||
|
||||
def natural_sampling_eval(
|
||||
model: BaseQuantifier,
|
||||
test: LabelledCollection,
|
||||
sample_size,
|
||||
n_repetitions=1,
|
||||
n_jobs=1,
|
||||
random_seed=42,
|
||||
error_metric:Union[str,Callable]='mae',
|
||||
verbose=False):
|
||||
|
||||
if isinstance(error_metric, str):
|
||||
error_metric = qp.error.from_name(error_metric)
|
||||
|
||||
assert hasattr(error_metric, '__call__'), 'invalid error function'
|
||||
|
||||
true_prevs, estim_prevs = natural_sampling_prediction(
|
||||
model, test, sample_size, n_repetitions, n_jobs, random_seed, verbose
|
||||
)
|
||||
|
||||
return error_metric(true_prevs, estim_prevs)
|
||||
|
||||
|
||||
def evaluate(model: BaseQuantifier, test_samples:Iterable[LabelledCollection], err:Union[str, Callable], n_jobs:int=-1):
|
||||
if isinstance(err, str):
|
||||
err = qp.error.from_name(err)
|
||||
|
@ -149,7 +233,7 @@ def _delayed_eval(args):
|
|||
return error(prev_true, prev_estim)
|
||||
|
||||
|
||||
def _check_num_evals(n_classes, n_prevpoints=None, eval_budget=None, n_repetitions=1, verbose=True):
|
||||
def _check_num_evals(n_classes, n_prevpoints=None, eval_budget=None, n_repetitions=1, verbose=False):
|
||||
if n_prevpoints is None and eval_budget is None:
|
||||
raise ValueError('either n_prevpoints or eval_budget has to be specified')
|
||||
elif n_prevpoints is None:
|
||||
|
|
Loading…
Reference in New Issue