optimization conditional in the prediction function
This commit is contained in:
parent
4bc9d19635
commit
eba6fd8123
|
@ -1,9 +1,9 @@
|
|||
# main changes in 0.1.7
|
||||
|
||||
- Protocols is now an abstraction, AbstractProtocol. There is a new class extending AbstractProtocol called
|
||||
- Protocols are now abstracted as AbstractProtocol. There is a new class extending AbstractProtocol called
|
||||
AbstractStochasticSeededProtocol, which implements a seeding policy to allow replicate the series of samplings.
|
||||
There are some examples of protocols, APP, NPP, USimplexPP, CovariateShiftPP (experimental).
|
||||
The idea is to start the sampling by simpli calling the __call__ method.
|
||||
The idea is to start the sampling by simply calling the __call__ method.
|
||||
This change has a great impact in the framework, since many functions in qp.evaluation, qp.model_selection,
|
||||
and sampling functions in LabelledCollection make use of the old functions.
|
||||
|
||||
|
@ -11,7 +11,6 @@
|
|||
|
||||
|
||||
Things to fix:
|
||||
- eval budget policy?
|
||||
- clean functions like binary, aggregative, probabilistic, etc; those should be resolved via isinstance()
|
||||
- clean classes_ and n_classes from methods (maybe not from aggregative ones, but those have to be used only
|
||||
internally and not imposed in any abstract class)
|
||||
|
@ -32,3 +31,19 @@ Things to fix:
|
|||
classify & aggregate, irrespective of the class). However, this has caused a problem with OneVsAll. This has to
|
||||
be checked, since it is now innecessarily complicated (it also has old references to .probabilistic, and all this
|
||||
stuff).
|
||||
- Check method def __parallel(self, func, *args, **kwargs) in aggregative.OneVsAll
|
||||
|
||||
# 0.1.7
|
||||
# change the LabelledCollection API (removing protocol-related samplings)
|
||||
# need to change the two references to the above in the wiki / doc, and code examples...
|
||||
# removed artificial_prevalence_sampling from functional
|
||||
|
||||
# also: some parameters in the init could be used to indicate that the method should return a tuple with
|
||||
# unlabelled instances and the vector of prevalence values (and not a LabelledCollection).
|
||||
# Or: this can be done in a different function; i.e., we use one function (now __call__) to return
|
||||
# LabelledCollections, and another new one for returning the other output, which is more general for
|
||||
# evaluation purposes.
|
||||
|
||||
# the so-called "gen" function has to be implemented as a protocol. The problem here is that this function
|
||||
# should be able to return only unlabelled instances plus a vector of prevalences (and not LabelledCollections).
|
||||
# This was coded as different functions in 0.1.6
|
||||
|
|
|
@ -11,16 +11,35 @@ import quapy.functional as F
|
|||
import pandas as pd
|
||||
|
||||
|
||||
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, verbose=False):
|
||||
def prediction(model: BaseQuantifier, protocol: AbstractProtocol, aggr_speedup='auto', verbose=False):
|
||||
assert aggr_speedup in [False, True, 'auto', 'force'], 'invalid value for aggr_speedup'
|
||||
|
||||
sout = lambda x: print(x) if verbose else None
|
||||
|
||||
apply_optimization = False
|
||||
|
||||
if aggr_speedup in [True, 'auto', 'force']:
|
||||
# checks whether the prediction can be made more efficiently; this check consists in verifying if the model is
|
||||
# of type aggregative, if the protocol is based on LabelledCollection, and if the total number of documents to
|
||||
# classify using the protocol would exceed the number of test documents in the original collection
|
||||
from method.aggregative import AggregativeQuantifier
|
||||
if isinstance(model, AggregativeQuantifier) and isinstance(protocol, OnLabelledCollectionProtocol):
|
||||
sout('speeding up the prediction for the aggregative quantifier')
|
||||
if aggr_speedup == 'force':
|
||||
apply_optimization = True
|
||||
sout(f'forcing aggregative speedup')
|
||||
elif hasattr(protocol, 'sample_size'):
|
||||
nD = len(protocol.get_labelled_collection())
|
||||
samplesD = protocol.total() * protocol.sample_size
|
||||
if nD < samplesD:
|
||||
apply_optimization = True
|
||||
sout(f'speeding up the prediction for the aggregative quantifier, '
|
||||
f'total classifications {nD} instead of {samplesD}')
|
||||
|
||||
if apply_optimization:
|
||||
pre_classified = model.classify(protocol.get_labelled_collection().instances)
|
||||
return __prediction_helper(model.aggregate, protocol.on_preclassified_instances(pre_classified), verbose)
|
||||
protocol_with_predictions = protocol.on_preclassified_instances(pre_classified)
|
||||
return __prediction_helper(model.aggregate, protocol_with_predictions, verbose)
|
||||
else:
|
||||
sout(f'the method is not aggregative, or the protocol is not an instance of '
|
||||
f'{OnLabelledCollectionProtocol.__name__}, so no optimization can be carried out')
|
||||
return __prediction_helper(model.quantify, protocol, verbose)
|
||||
|
||||
|
||||
|
@ -38,10 +57,11 @@ def __prediction_helper(quantification_fn, protocol: AbstractProtocol, verbose=F
|
|||
|
||||
def evaluation_report(model: BaseQuantifier,
|
||||
protocol: AbstractProtocol,
|
||||
error_metrics:Iterable[Union[str,Callable]]='mae',
|
||||
error_metrics: Iterable[Union[str,Callable]] = 'mae',
|
||||
aggr_speedup='auto',
|
||||
verbose=False):
|
||||
|
||||
true_prevs, estim_prevs = prediction(model, protocol, verbose)
|
||||
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
|
||||
return _prevalence_report(true_prevs, estim_prevs, error_metrics)
|
||||
|
||||
|
||||
|
@ -65,38 +85,18 @@ def _prevalence_report(true_prevs, estim_prevs, error_metrics: Iterable[Union[st
|
|||
return df
|
||||
|
||||
|
||||
def evaluate(model: BaseQuantifier, protocol: AbstractProtocol, error_metric:Union[str, Callable], verbose=False):
|
||||
def evaluate(
|
||||
model: BaseQuantifier,
|
||||
protocol: AbstractProtocol,
|
||||
error_metric:Union[str, Callable],
|
||||
aggr_speedup='auto',
|
||||
verbose=False):
|
||||
|
||||
if isinstance(error_metric, str):
|
||||
error_metric = qp.error.from_name(error_metric)
|
||||
true_prevs, estim_prevs = prediction(model, protocol, verbose)
|
||||
true_prevs, estim_prevs = prediction(model, protocol, aggr_speedup=aggr_speedup, verbose=verbose)
|
||||
return error_metric(true_prevs, estim_prevs)
|
||||
|
||||
|
||||
|
||||
def _check_num_evals(n_classes, n_prevpoints=None, eval_budget=None, repeats=1, verbose=False):
|
||||
if n_prevpoints is None and eval_budget is None:
|
||||
raise ValueError('either n_prevpoints or eval_budget has to be specified')
|
||||
elif n_prevpoints is None:
|
||||
assert eval_budget > 0, 'eval_budget must be a positive integer'
|
||||
n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
|
||||
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
|
||||
if verbose:
|
||||
print(f'setting n_prevpoints={n_prevpoints} so that the number of '
|
||||
f'evaluations ({eval_computations}) does not exceed the evaluation '
|
||||
f'budget ({eval_budget})')
|
||||
elif eval_budget is None:
|
||||
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
|
||||
if verbose:
|
||||
print(f'{eval_computations} evaluations will be performed for each '
|
||||
f'combination of hyper-parameters')
|
||||
else:
|
||||
eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
|
||||
if eval_computations > eval_budget:
|
||||
n_prevpoints = F.get_nprevpoints_approximation(eval_budget, n_classes, repeats)
|
||||
new_eval_computations = F.num_prevalence_combinations(n_prevpoints, n_classes, repeats)
|
||||
if verbose:
|
||||
print(f'the budget of evaluations would be exceeded with '
|
||||
f'n_prevpoints={n_prevpoints}. Chaning to n_prevpoints={n_prevpoints}. This will produce '
|
||||
f'{new_eval_computations} evaluation computations for each hyper-parameter combination.')
|
||||
return n_prevpoints, eval_computations
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import scipy
|
|||
import numpy as np
|
||||
|
||||
|
||||
|
||||
def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01):
|
||||
"""
|
||||
Produces an array of uniformly separated values of prevalence.
|
||||
|
|
|
@ -1023,15 +1023,18 @@ class OneVsAll(AggregativeQuantifier):
|
|||
"""
|
||||
|
||||
def __init__(self, binary_quantifier, n_jobs=-1):
|
||||
assert isinstance(self.binary_quantifier, BaseQuantifier), \
|
||||
f'{self.binary_quantifier} does not seem to be a Quantifier'
|
||||
assert isinstance(self.binary_quantifier, AggregativeQuantifier), \
|
||||
f'{self.binary_quantifier} does not seem to be of type Aggregative'
|
||||
self.binary_quantifier = binary_quantifier
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True):
|
||||
assert not data.binary, \
|
||||
f'{self.__class__.__name__} expect non-binary data'
|
||||
assert isinstance(self.binary_quantifier, BaseQuantifier), \
|
||||
f'{self.binary_quantifier} does not seem to be a Quantifier'
|
||||
assert fit_learner == True, 'fit_learner must be True'
|
||||
assert fit_learner == True, \
|
||||
'fit_learner must be True'
|
||||
|
||||
self.dict_binary_quantifiers = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
|
||||
self.__parallel(self._delayed_binary_fit, data)
|
||||
|
@ -1057,42 +1060,11 @@ class OneVsAll(AggregativeQuantifier):
|
|||
return np.swapaxes(classif_predictions, 0, 1)
|
||||
else:
|
||||
return classif_predictions.T
|
||||
#
|
||||
# def posterior_probabilities(self, instances):
|
||||
# """
|
||||
# Returns a matrix of shape `(n,m,2)` with `n` the number of instances and `m` the number of classes. The entry
|
||||
# `(i,j,1)` (resp. `(i,j,0)`) is a value in [0,1] indicating the posterior probability that instance `i` belongs
|
||||
# (resp. does not belong) to class `j`.
|
||||
# The posterior probabilities are independent of each other, meaning that, in general, they do not sum
|
||||
# up to one.
|
||||
#
|
||||
# :param instances: array-like
|
||||
# :return: `np.ndarray`
|
||||
# """
|
||||
#
|
||||
# if not isinstance(self.binary_quantifier, AggregativeProbabilisticQuantifier):
|
||||
# raise NotImplementedError(f'{self.__class__.__name__} does not implement posterior_probabilities because '
|
||||
# f'the base quantifier {self.binary_quantifier.__class__.__name__} is not '
|
||||
# f'probabilistic')
|
||||
# posterior_predictions_bin = self.__parallel(self._delayed_binary_posteriors, instances)
|
||||
# return np.swapaxes(posterior_predictions_bin, 0, 1)
|
||||
|
||||
def aggregate(self, classif_predictions):
|
||||
# if self.probabilistic:
|
||||
# assert classif_predictions.shape[1] == self.n_classes and classif_predictions.shape[2] == 2, \
|
||||
# 'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of posterior ' \
|
||||
# 'probabilities (2 dimensions) for each document (row) and class (columns)'
|
||||
# else:
|
||||
# assert set(np.unique(classif_predictions)).issubset({0, 1}), \
|
||||
# 'param classif_predictions_bin does not seem to be a valid matrix (ndarray) of binary ' \
|
||||
# 'predictions for each document (row) and class (columns)'
|
||||
prevalences = self.__parallel(self._delayed_binary_aggregate, classif_predictions)
|
||||
return F.normalize_prevalence(prevalences)
|
||||
|
||||
# def quantify(self, X):
|
||||
# predictions = self.classify(X)
|
||||
# return self.aggregate(predictions)
|
||||
|
||||
def __parallel(self, func, *args, **kwargs):
|
||||
return np.asarray(
|
||||
# some quantifiers (in particular, ELM-based ones) cannot be run with multiprocess, since the temp dir they
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABCMeta, abstractmethod
|
||||
from copy import deepcopy
|
||||
|
||||
from quapy.data import LabelledCollection
|
||||
|
||||
|
@ -62,52 +63,50 @@ class BinaryQuantifier(BaseQuantifier):
|
|||
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
||||
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# class OneVsAll:
|
||||
# """
|
||||
# Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
||||
# quantifier for each class, and then l1-normalizes the outputs so that the class prevelences sum up to 1.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, binary_method, n_jobs=-1):
|
||||
# self.binary_method = binary_method
|
||||
# self.n_jobs = n_jobs
|
||||
#
|
||||
# def fit(self, data: LabelledCollection, **kwargs):
|
||||
# assert not data.binary, f'{self.__class__.__name__} expect non-binary data'
|
||||
# assert isinstance(self.binary_method, BaseQuantifier), f'{self.binary_method} does not seem to be a Quantifier'
|
||||
# self.class_method = {c: deepcopy(self.binary_method) for c in data.classes_}
|
||||
# Parallel(n_jobs=self.n_jobs, backend='threading')(
|
||||
# delayed(self._delayed_binary_fit)(c, self.class_method, data, **kwargs) for c in data.classes_
|
||||
# )
|
||||
# return self
|
||||
#
|
||||
# def quantify(self, X, *args):
|
||||
# prevalences = np.asarray(
|
||||
# Parallel(n_jobs=self.n_jobs, backend='threading')(
|
||||
# delayed(self._delayed_binary_predict)(c, self.class_method, X) for c in self.classes
|
||||
# )
|
||||
# )
|
||||
# return F.normalize_prevalence(prevalences)
|
||||
#
|
||||
# @property
|
||||
# def classes(self):
|
||||
# return sorted(self.class_method.keys())
|
||||
#
|
||||
# def set_params(self, **parameters):
|
||||
# self.binary_method.set_params(**parameters)
|
||||
#
|
||||
# def get_params(self, deep=True):
|
||||
# return self.binary_method.get_params()
|
||||
#
|
||||
# def _delayed_binary_predict(self, c, learners, X):
|
||||
# return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
|
||||
#
|
||||
# def _delayed_binary_fit(self, c, learners, data, **kwargs):
|
||||
# bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
|
||||
# learners[c].fit(bindata, **kwargs)
|
||||
class OneVsAllGeneric:
|
||||
"""
|
||||
Allows any binary quantifier to perform quantification on single-label datasets. The method maintains one binary
|
||||
quantifier for each class, and then l1-normalizes the outputs so that the class prevelences sum up to 1.
|
||||
"""
|
||||
|
||||
def __init__(self, binary_quantifier, n_jobs=1):
|
||||
assert isinstance(binary_quantifier, BaseQuantifier), \
|
||||
f'{binary_quantifier} does not seem to be a Quantifier'
|
||||
self.binary_quantifier = binary_quantifier
|
||||
self.n_jobs = n_jobs
|
||||
|
||||
def fit(self, data: LabelledCollection, **kwargs):
|
||||
assert not data.binary, \
|
||||
f'{self.__class__.__name__} expect non-binary data'
|
||||
self.class_quatifier = {c: deepcopy(self.binary_quantifier) for c in data.classes_}
|
||||
Parallel(n_jobs=self.n_jobs, backend='threading')(
|
||||
delayed(self._delayed_binary_fit)(c, self.class_quatifier, data, **kwargs) for c in data.classes_
|
||||
)
|
||||
return self
|
||||
|
||||
def quantify(self, X, *args):
|
||||
prevalences = np.asarray(
|
||||
Parallel(n_jobs=self.n_jobs, backend='threading')(
|
||||
delayed(self._delayed_binary_predict)(c, self.class_quatifier, X) for c in self.classes
|
||||
)
|
||||
)
|
||||
return F.normalize_prevalence(prevalences)
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
return sorted(self.class_quatifier.keys())
|
||||
|
||||
def set_params(self, **parameters):
|
||||
self.binary_quantifier.set_params(**parameters)
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return self.binary_quantifier.get_params()
|
||||
|
||||
def _delayed_binary_predict(self, c, learners, X):
|
||||
return learners[c].quantify(X)[:,1] # the mean is the estimation for the positive class prevalence
|
||||
|
||||
def _delayed_binary_fit(self, c, learners, data, **kwargs):
|
||||
bindata = LabelledCollection(data.instances, data.labels == c, n_classes=2)
|
||||
learners[c].fit(bindata, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -13,24 +13,6 @@ from os.path import exists
|
|||
from glob import glob
|
||||
|
||||
|
||||
# 0.1.7
|
||||
# change the LabelledCollection API (removing protocol-related samplings)
|
||||
# need to change the two references to the above in the wiki / doc, and code examples...
|
||||
# removed artificial_prevalence_sampling from functional
|
||||
|
||||
# maybe add some parameters in the init of the protocols (or maybe only for IndexableWhateverProtocols
|
||||
# indicating that the protocol should return indexes, and not samples themselves?
|
||||
# also: some parameters in the init could be used to indicate that the method should return a tuple with
|
||||
# unlabelled instances and the vector of prevalence values (and not a LabelledCollection).
|
||||
# Or: this can be done in a different function; i.e., we use one function (now __call__) to return
|
||||
# LabelledCollections, and another new one for returning the other output, which is more general for
|
||||
# evaluation purposes.
|
||||
|
||||
# the so-called "gen" function has to be implemented as a protocol. The problem here is that this function
|
||||
# should be able to return only unlabelled instances plus a vector of prevalences (and not LabelledCollections).
|
||||
# This was coded as different functions in 0.1.6
|
||||
|
||||
|
||||
class AbstractProtocol(metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
|
@ -133,11 +115,21 @@ class LoadSamplesFromDirectory(AbstractProtocol):
|
|||
self.loader_fn = loader_fn
|
||||
self.classes = classes
|
||||
self.loader_kwargs = loader_kwargs
|
||||
self._list_files = None
|
||||
|
||||
def __call__(self):
|
||||
for file in sorted(glob(self.folder_path, '*')):
|
||||
for file in self.list_files:
|
||||
yield LabelledCollection.load(file, loader_func=self.loader_fn, classes=self.classes, **self.loader_kwargs)
|
||||
|
||||
@property
|
||||
def list_files(self):
|
||||
if self._list_files is None:
|
||||
self._list_files = sorted(glob(self.folder_path, '*'))
|
||||
return self._list_files
|
||||
|
||||
def total(self):
|
||||
return len(self.list_files)
|
||||
|
||||
|
||||
class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue