1
0
Fork 0

optimization conditional in the prediction function

This commit is contained in:
Alejandro Moreo Fernandez 2022-05-26 17:59:23 +02:00
parent 4bc9d19635
commit eba6fd8123
6 changed files with 119 additions and 142 deletions

View File

@ -1,9 +1,9 @@
# main changes in 0.1.7
- Protocols is now an abstraction, AbstractProtocol. There is a new class extending AbstractProtocol called
- Protocols are now abstracted as AbstractProtocol. There is a new class extending AbstractProtocol called
AbstractStochasticSeededProtocol, which implements a seeding policy to allow replicate the series of samplings.
There are some examples of protocols, APP, NPP, USimplexPP, CovariateShiftPP (experimental).
The idea is to start the sampling by simpli calling the __call__ method.
The idea is to start the sampling by simply calling the __call__ method.
This change has a great impact in the framework, since many functions in qp.evaluation, qp.model_selection,
and sampling functions in LabelledCollection make use of the old functions.
@ -11,7 +11,6 @@
Things to fix:
- eval budget policy?
- clean functions like binary, aggregative, probabilistic, etc; those should be resolved via isinstance()
- clean classes_ and n_classes from methods (maybe not from aggregative ones, but those have to be used only
internally and not imposed in any abstract class)
@ -31,4 +30,20 @@ Things to fix:
return instead crisp decisions. The idea was to unify the quantification function (i.e., now it is always
classify & aggregate, irrespective of the class). However, this has caused a problem with OneVsAll. This has to
be checked, since it is now innecessarily complicated (it also has old references to .probabilistic, and all this
stuff).
stuff).
- Check method def __parallel(self, func, *args, **kwargs) in aggregative.OneVsAll
# 0.1.7
# change the LabelledCollection API (removing protocol-related samplings)
# need to change the two references to the above in the wiki / doc, and code examples...
# removed artificial_prevalence_sampling from functional
# also: some parameters in the init could be used to indicate that the method should return a tuple with
# unlabelled instances and the vector of prevalence values (and not a LabelledCollection).
# Or: this can be done in a different function; i.e., we use one function (now __call__) to return
# LabelledCollections, and another new one for returning the other output, which is more general for
# evaluation purposes.
# the so-called "gen" function has to be implemented as a protocol. The problem here is that this function
# should be able to return only unlabelled instances plus a vector of prevalences (and not LabelledCollections).
# This was coded as different functions in 0.1.6

View File

@ -11,16 +11,35 @@ import quapy.functional as F
import pandas as pd
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, verbose=False):
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='auto', verbose=False):
assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'
sout = lambda x: print(x) if verbose else None
from method.aggregative import AggregativeQuantifier
if isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol):
sout('speeding up the prediction for the aggregative quantifier')
apply_optimization = False
if aggr_speedup in [True, 'auto', 'force']:
# checks whether the prediction can be made more efficiently; this check consists in verifying if the model is
# of type aggregative, if the protocol is based on LabelledCollection, and if the total number of documents to
# classify using the protocol would exceed the number of test documents in the original collection
from method.aggregative import AggregativeQuantifier
if isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol):
if aggr_speedup == 'force':
apply_optimization = True
sout(f'forcing aggregative speedup')
elif hasattr(protocol, 'sample_size'):
nD = len(protocol.get_labelled_collection())
samplesD = protocol.total() * protocol.sample_size
if nD < samplesD:
apply_optimization = True
sout(f'speeding up the prediction for the aggregative quantifier, '
f'total classifications {nD} instead of {samplesD}')
if apply_optimization:
pre_classified = model.classify(protocol.get_labelled_collection().instances)
return __prediction_helper(model.aggregate, protocol.on_preclassified_instances(pre_classified), verbose)
protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)
else:
sout(f'the method is not aggregative, or the protocol is not an instance of '
f'{OnLabelledCollectionProtocol.__name__}, so no optimization can be carried out')
return __prediction_helper(model.quantify, protocol, verbose)
@ -38,10 +57,11 @@ def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=F
def evaluation_report(model: BaseQuantifier,
protocol: AbstractProtocol,
error_metrics:Iterable[Union[str,Callable]]='mae',
error_metrics: Iterable[Union[str,Callable]] = 'mae',
aggr_speedup='auto',
verbose=False):
true_prevs, estim_prevs = prediction(model, protocol, verbose)
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
return _prevalence_report(true_prevs, estim_prevs, error_metrics)
@ -65,38 +85,18 @@ def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[st
return df
def evaluate(model: BaseQuantifier, protocol: AbstractProtocol, error_metric:Union[str, Callable], verbose=False):
def evaluate(
model: BaseQuantifier,
protocol: AbstractProtocol,
error_metric:Union[str, Callable],
aggr_speedup='auto',
verbose=False):
if isinstance(error_metric, str):
error_metric = qp.error.from_name(error_metric)
true_prevs, estim_prevs = prediction(model, protocol, verbose)
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
return error_metric(true_prevs, estim_prevs)
def _check_num_evals(n_classes, n_prevpoints=None, eval_budget=None, repeats=1, verbose=False):
if n_prevpoints is None and eval_budget is None:
raise ValueError('either n_prevpoints or eval_budget has to be specified')
elif n_prevpoints is None:
assert eval_budget > 0, 'eval_budget must be a positive integer'
n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
if verbose:
print(f'setting n_prevpoints={n_prevpoints} so that the number of '
f'evaluations ({eval_computations}) does not exceed the evaluation '
f'budget ({eval_budget})')
elif eval_budget is None:
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
if verbose:
print(f'{eval_computations} evaluations will be performed for each '
f'combination of hyper-parameters')
else:
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
if eval_computations > eval_budget:
n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
new_eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
if verbose:
print(f'the budget of evaluations would be exceeded with '
f'n_prevpoints={n_prevpoints}. Chaning to n_prevpoints={n_prevpoints}. This will produce '
f'{new_eval_computations} evaluation computations for each hyper-parameter combination.')
return n_prevpoints, eval_computations

View File

@ -4,7 +4,6 @@ import scipy
import numpy as np
def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01):
"""
Produces an array of uniformly separated values of prevalence.

View File

@ -1023,15 +1023,18 @@ class OneVsAll(AggregativeQuantifier):
"""
def __init__(self, binary_quantifier, n_jobs=-1):
assert isinstance(self.binary_quantifier, BaseQuantifier), \
f'{self.binary_quantifier} does not seem to be a Quantifier'
assert isinstance(self.binary_quantifier, AggregativeQuantifier), \
f'{self.binary_quantifier} does not seem to be of type Aggregative'
self.binary_quantifier = binary_quantifier
self.n_jobs = n_jobs
def fit(self, data: LabelledCollection, fit_learner=True):
assert not data.binary, \
f'{self.__class__.__name__} expect non-binary data'
assert isinstance(self.binary_quantifier, BaseQuantifier), \
f'{self.binary_quantifier} does not seem to be a Quantifier'
assert fit_learner == True, 'fit_learner must be True'
assert fit_learner == True, \
'fit_learner must be True'
self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
self.__parallel(self._delayed_binary_fit, data)
@ -1057,42 +1060,11 @@ class OneVsAll(AggregativeQuantifier):
return np.swapaxes(classif_predictions, 0, 1)
else:
return classif_predictions.T
#
# def posterior_probabilities(self, instances):
# """
# Returns a matrix of shape `(n,m,2)` with `n` the number of instances and `m` the number of classes. The entry
# `(i,j,1)` (resp. `(i,j,0)`) is a value in [0,1] indicating the posterior probability that instance `i` belongs
# (resp. does not belong) to class `j`.
# The posterior probabilities are independent of each other, meaning that, in general, they do not sum
# up to one.
#
# :param instances: array-like
# :return: `np.ndarray`
# """
#
# if not isinstance(self.binary_quantifier, AggregativeProbabilisticQuantifier):
# raise NotImplementedError(f'{self.__class__.__name__} does not implement posterior_probabilities because '
# f'the base quantifier {self.binary_quantifier.__class__.__name__} is not '
# f'probabilistic')
# posterior_predictions_bin = self.__parallel(self._delayed_binary_posteriors, instances)
# return np.swapaxes(posterior_predictions_bin, 0, 1)
def aggregate(self, classif_predictions):
# if self.probabilistic:
# assert classif_predictions.shape[1] == self.n_classes and classif_predictions.shape[2] == 2, \
# 'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of posterior ' \
# 'probabilities (2 dimensions) for each document (row) and class (columns)'
# else:
# assert set(np.unique(classif_predictions)).issubset({0, 1}), \
# 'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of binary ' \
# 'predictions for each document (row) and class (columns)'
prevalences = self.__parallel(self._delayed_binary_aggregate, classif_predictions)
return F.normalize_prevalence(prevalences)
# def quantify(self, X):
# predictions = self.classify(X)
# return self.aggregate(predictions)
def __parallel(self, func, *args, **kwargs):
return np.asarray(
# some quantifiers (in particular, ELM-based ones) cannot be run with multiprocess, since the temp dir they

View File

@ -1,4 +1,5 @@
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from quapy.data import LabelledCollection
@ -62,52 +63,50 @@ class BinaryQuantifier(BaseQuantifier):
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
# class OneVsAll:
# """
# Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
# quantifier for each class, and then l1-normalizes the outputs so that the class prevelences sum up to 1.
# """
#
# def __init__(self, binary_method, n_jobs=-1):
# self.binary_method = binary_method
# self.n_jobs = n_jobs
#
# def fit(self, data: LabelledCollection, **kwargs):
# assert not data.binary, f'{self.__class__.__name__} expect non-binary data'
# assert isinstance(self.binary_method, BaseQuantifier), f'{self.binary_method} does not seem to be a Quantifier'
# self.class_method = {c: deepcopy(self.binary_method) for c in data.classes_}
# Parallel(n_jobs=self.n_jobs, backend='threading')(
# delayed(self._delayed_binary_fit)(c, self.class_method, data, **kwargs) for c in data.classes_
# )
# return self
#
# def quantify(self, X, *args):
# prevalences = np.asarray(
# Parallel(n_jobs=self.n_jobs, backend='threading')(
# delayed(self._delayed_binary_predict)(c, self.class_method, X) for c in self.classes
# )
# )
# return F.normalize_prevalence(prevalences)
#
# @property
# def classes(self):
# return sorted(self.class_method.keys())
#
# def set_params(self, **parameters):
# self.binary_method.set_params(**parameters)
#
# def get_params(self, deep=True):
# return self.binary_method.get_params()
#
# def _delayed_binary_predict(self, c, learners, X):
# return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
#
# def _delayed_binary_fit(self, c, learners, data, **kwargs):
# bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
# learners[c].fit(bindata, **kwargs)
class OneVsAllGeneric:
"""
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
quantifier for each class, and then l1-normalizes the outputs so that the class prevelences sum up to 1.
"""
def __init__(self, binary_quantifier, n_jobs=1):
assert isinstance(binary_quantifier, BaseQuantifier), \
f'{binary_quantifier} does not seem to be a Quantifier'
self.binary_quantifier = binary_quantifier
self.n_jobs = n_jobs
def fit(self, data: LabelledCollection, **kwargs):
assert not data.binary, \
f'{self.__class__.__name__} expect non-binary data'
self.class_quatifier = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
Parallel(n_jobs=self.n_jobs, backend='threading')(
delayed(self._delayed_binary_fit)(c, self.class_quatifier, data, **kwargs) for c in data.classes_
)
return self
def quantify(self, X, *args):
prevalences = np.asarray(
Parallel(n_jobs=self.n_jobs, backend='threading')(
delayed(self._delayed_binary_predict)(c, self.class_quatifier, X) for c in self.classes
)
)
return F.normalize_prevalence(prevalences)
@property
def classes(self):
return sorted(self.class_quatifier.keys())
def set_params(self, **parameters):
self.binary_quantifier.set_params(**parameters)
def get_params(self, deep=True):
return self.binary_quantifier.get_params()
def _delayed_binary_predict(self, c, learners, X):
return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
def _delayed_binary_fit(self, c, learners, data, **kwargs):
bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
learners[c].fit(bindata, **kwargs)

View File

@ -13,24 +13,6 @@ from os.path import exists
from glob import glob
# 0.1.7
# change the LabelledCollection API (removing protocol-related samplings)
# need to change the two references to the above in the wiki / doc, and code examples...
# removed artificial_prevalence_sampling from functional
# maybe add some parameters in the init of the protocols (or maybe only for IndexableWhateverProtocols
# indicating that the protocol should return indexes, and not samples themselves?
# also: some parameters in the init could be used to indicate that the method should return a tuple with
# unlabelled instances and the vector of prevalence values (and not a LabelledCollection).
# Or: this can be done in a different function; i.e., we use one function (now __call__) to return
# LabelledCollections, and another new one for returning the other output, which is more general for
# evaluation purposes.
# the so-called "gen" function has to be implemented as a protocol. The problem here is that this function
# should be able to return only unlabelled instances plus a vector of prevalences (and not LabelledCollections).
# This was coded as different functions in 0.1.6
class AbstractProtocol(metaclass=ABCMeta):
@abstractmethod
@ -133,11 +115,21 @@ class LoadSamplesFromDirectory(AbstractProtocol):
self.loader_fn = loader_fn
self.classes = classes
self.loader_kwargs = loader_kwargs
self._list_files = None
def __call__(self):
for file in sorted(glob(self.folder_path, '*')):
for file in self.list_files:
yield LabelledCollection.load(file, loader_func=self.loader_fn, classes=self.classes, **self.loader_kwargs)
@property
def list_files(self):
if self._list_files is None:
self._list_files = sorted(glob(self.folder_path, '*'))
return self._list_files
def total(self):
return len(self.list_files)
class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
"""