Refactoring: removing LabelledCollection coupling, and other improvements in EMQ

This commit is contained in:
Alejandro Moreo Fernandez 2025-04-23 12:25:05 +02:00
parent 5738821d10
commit 960ca5076e
10 changed files with 537 additions and 517 deletions

View File

@ -1,10 +1,32 @@
Change Log 0.1.10
-----------------
- Base code refactor:
- Removed the coupling between LabelledCollection and the quantification methods, e.g.:
def fit(data: LabelledCollection): -> def fit(X, y):
- Added function "predict" (function "quantify" is still present as an alias). A usage sketch of the
new API is given at the end of this change log.
- Aggregative methods' behavior in terms of fit_classifier and how to treat the val_split is now
indicated exclusively at construction time; it is no longer possible to indicate it at fit time.
The reason is that, in v<=0.1.9, one could create a method (e.g., ACC) and then indicate:
my_acc.fit(tr_data, fit_classifier=False, val_split=val_data)
in which case the first argument went unused, and this was ambiguous with
my_acc.fit(the_data, fit_classifier=False)
in which case the_data was to be used for validation purposes. Moreover, val_split could be set as a fraction,
indicating that only part of the_data was to be used for validation, with the rest wasted; this was confusing.
- EMQ has been modified so that the representation function "classify" now only provides posterior
probabilities and, if required, these are recalibrated (e.g., by "bcts") during the aggregation function.
- A new parameter "on_calib_error" is passed to the constructor, indicating the policy to follow
in case the calibration function fails. Options include:
- 'raise': raises a RuntimeError (default)
- 'backup': skips the calibration
- Parameter "recalib" has been renamed "calib"
- Added aggregative bootstrap for deriving confidence regions (confidence intervals, ellipses in the simplex, or
ellipses in the CLR space). This method is efficient in that it leverages the two phases of the aggregative
quantifiers: resampling is applied only to the aggregation phase, thus avoiding training many quantifiers
or classifying the instances of a sample multiple times. See:
- quapy/method/confidence.py (new)
- the new example no. 15
- BayesianCC has been moved to confidence.py, where the methods having to do with confidence intervals now live
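As a minimal usage sketch of the new API (the synthetic data below is illustrative; class and parameter names are those introduced in this version):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import train_test_split
    from quapy.method.aggregative import EMQ

    # any pair of array-likes (X, y) works now; no LabelledCollection is required
    X, y = make_classification(n_samples=1000, random_state=0)
    Xtr, Xte, ytr, yte = train_test_split(X, y, stratify=y, random_state=0)

    # fit_classifier, val_split, and the calibration policy are fixed at construction time
    quantifier = EMQ(LogisticRegression(), val_split=5, exact_train_prev=False,
                     calib='bcts', on_calib_error='backup')
    quantifier.fit(Xtr, ytr)        # was: quantifier.fit(LabelledCollection(Xtr, ytr))
    prev = quantifier.predict(Xte)  # "quantify" remains available as an alias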
Change Log 0.1.9

View File

@ -50,7 +50,7 @@ def quantification_models():
yield 'MAX', MAX(newLR()), lr_params
yield 'MS', MS(newLR()), lr_params
yield 'MS2', MS2(newLR()), lr_params
yield 'sldc', EMQ(newLR(), recalib='platt'), lr_params
yield 'sldc', EMQ(newLR(), calib='platt'), lr_params
yield 'svmmae', newSVMAE(), svmperf_params
yield 'hdy', HDy(newLR()), lr_params

View File

@ -99,26 +99,27 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):
which corresponds to the maximum likelihood estimate.
:param classifier: a sklearn's Estimator that generates a binary classifier.
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param bandwidth: float, the bandwidth of the Kernel
:param random_state: a seed to be set before fitting any base quantifier (default None)
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5, bandwidth=0.1, random_state=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1,
random_state=None):
super().__init__(classifier, fit_classifier, val_split)
self.bandwidth = KDEBase._check_bandwidth(bandwidth)
self.random_state=random_state
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
def aggregation_fit(self, classif_predictions, labels):
self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth)
return self
def aggregate(self, posteriors: np.ndarray):
@ -173,35 +174,35 @@ class KDEyHD(AggregativeSoftQuantifier, KDEBase):
where the datapoints (trials) :math:`x_1,\\ldots,x_t\\sim_{\\mathrm{iid}} r` with :math:`r` the
uniform distribution.
:param classifier: a sklearn's Estimator that generates a binary classifier.
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param bandwidth: float, the bandwidth of the Kernel
:param random_state: a seed to be set before fitting any base quantifier (default None)
:param montecarlo_trials: number of Monte Carlo trials (default 10000)
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5, divergence: str='HD',
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, divergence: str='HD',
bandwidth=0.1, random_state=None, montecarlo_trials=10000):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
super().__init__(classifier, fit_classifier, val_split)
self.divergence = divergence
self.bandwidth = KDEBase._check_bandwidth(bandwidth)
self.random_state=random_state
self.montecarlo_trials = montecarlo_trials
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
def aggregation_fit(self, classif_predictions, labels):
self.mix_densities = self.get_mixture_components(classif_predictions, labels, self.classes_, self.bandwidth)
N = self.montecarlo_trials
rs = self.random_state
n = data.n_classes
n = len(self.classes_)
self.reference_samples = np.vstack([kde_i.sample(N//n, random_state=rs) for kde_i in self.mix_densities])
self.reference_classwise_densities = np.asarray([self.pdf(kde_j, self.reference_samples) for kde_j in self.mix_densities])
self.reference_density = np.mean(self.reference_classwise_densities, axis=0) # equiv. to (uniform @ self.reference_classwise_densities)
@ -265,20 +266,20 @@ class KDEyCS(AggregativeSoftQuantifier):
The authors showed that this distribution matching admits a closed-form solution
:param classifier: a sklearn's Estimator that generates a binary classifier.
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param bandwidth: float, the bandwidth of the Kernel
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5, bandwidth=0.1):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5, bandwidth=0.1):
super().__init__(classifier, fit_classifier, val_split)
self.bandwidth = KDEBase._check_bandwidth(bandwidth)
def gram_matrix_mix_sum(self, X, Y=None):
@ -293,17 +294,17 @@ class KDEyCS(AggregativeSoftQuantifier):
gram = norm_factor * rbf_kernel(X, Y, gamma=gamma)
return gram.sum()
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
P, y = classif_predictions.Xy
n = data.n_classes
P, y = classif_predictions, labels
n = len(self.classes_)
assert all(sorted(np.unique(y)) == np.arange(n)), \
'label name gaps not allowed in current implementation'
# counts_inv keeps track of the relative weight of each datapoint within its class
# (i.e., the weight in its KDE model)
counts_inv = 1 / (data.counts())
counts_inv = 1 / (F.counts_from_labels(y, classes=self.classes_))
# tr_tr_sums corresponds to symbol \overline{B} in the paper
tr_tr_sums = np.zeros(shape=(n,n), dtype=float)

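For reference, a short sketch of the three accepted forms of val_split under the new constructors (toy arrays as placeholders; the import path is assumed):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import KDEyML

    rng = np.random.default_rng(0)
    Xtr, ytr = rng.normal(size=(200, 2)), rng.integers(0, 2, 200)
    Xval, yval = rng.normal(size=(60, 2)), rng.integers(0, 2, 60)
    Xte = rng.normal(size=(100, 2))

    kde = KDEyML(LogisticRegression(), val_split=5)             # k-fold CV predictions (k=5)
    kde = KDEyML(LogisticRegression(), val_split=0.3)           # 30% stratified held-out split
    kde = KDEyML(LogisticRegression(), val_split=(Xval, yval))  # an explicit validation set (X, y)
    prev = kde.fit(Xtr, ytr).predict(Xte)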
View File

@ -21,13 +21,13 @@ class QuaNetTrainer(BaseQuantifier):
Example:
>>> import quapy as qp
>>> from quapy.method_name.meta import QuaNet
>>> from quapy.method.meta import QuaNet
>>> from quapy.classification.neural import NeuralClassifierTrainer, CNNnet
>>>
>>> # use samples of 100 elements
>>> qp.environ['SAMPLE_SIZE'] = 100
>>>
>>> # load the kindle dataset as text, and convert words to numerical indexes
>>> # load the Kindle dataset as text, and convert words to numerical indexes
>>> dataset = qp.datasets.fetch_reviews('kindle', pickle=True)
>>> qp.train.preprocessing.index(dataset, min_df=5, inplace=True)
>>>
@ -37,12 +37,14 @@ class QuaNetTrainer(BaseQuantifier):
>>>
>>> # train QuaNet (QuaNet is an alias to QuaNetTrainer)
>>> model = QuaNet(classifier, qp.environ['SAMPLE_SIZE'], device='cuda')
>>> model.fit(dataset.training)
>>> model.fit(*dataset.training.Xy)
>>> estim_prevalence = model.predict(dataset.test.instances)
:param classifier: an object implementing `fit` (i.e., that can be trained on labelled data),
`predict_proba` (i.e., that can generate posterior probabilities of unlabelled examples) and
`transform` (i.e., that can generate embedded representations of the unlabelled instances).
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param sample_size: integer, the sample size; default is None, meaning that the sample size should be
taken from qp.environ["SAMPLE_SIZE"]
:param n_epochs: integer, maximum number of training epochs
@ -64,6 +66,7 @@ class QuaNetTrainer(BaseQuantifier):
def __init__(self,
classifier,
fit_classifier=True,
sample_size=None,
n_epochs=100,
tr_iter_per_poch=500,
@ -86,6 +89,7 @@ class QuaNetTrainer(BaseQuantifier):
f'the classifier {classifier.__class__.__name__} does not seem to be able to produce posterior probabilities ' \
f'since it does not implement the method "predict_proba"'
self.classifier = classifier
self.fit_classifier = fit_classifier
self.sample_size = qp._get_sample_size(sample_size)
self.n_epochs = n_epochs
self.tr_iter = tr_iter_per_poch
@ -111,20 +115,21 @@ class QuaNetTrainer(BaseQuantifier):
self.__check_params_colision(self.quanet_params, self.classifier.get_params())
self._classes_ = None
def fit(self, data: LabelledCollection, fit_classifier=True):
def fit(self, X, y):
"""
Trains QuaNet.
:param data: the training data on which to train QuaNet. If `fit_classifier=True`, the data will be split in
:param X: the training instances on which to train QuaNet. If `fit_classifier=True`, the data will be split
into 40/40/20 for training the classifier, training QuaNet, and validating QuaNet, respectively. If
`fit_classifier=False`, the data will be split into 66/34 for training QuaNet and validating it, respectively.
:param fit_classifier: if True, trains the classifier on a split containing 40% of the data
:param y: the labels of X
:return: self
"""
data = LabelledCollection(X, y)
self._classes_ = data.classes_
os.makedirs(self.checkpointdir, exist_ok=True)
if fit_classifier:
if self.fit_classifier:
classifier_data, unused_data = data.split_stratified(0.4)
train_data, valid_data = unused_data.split_stratified(0.66) # 0.66 split of 60% makes 40% and 20%
self.classifier.fit(*classifier_data.Xy)

View File

@ -18,18 +18,23 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
that would allow for more true positives and many more false positives, on the grounds this
would deliver larger denominators.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param n_jobs: number of parallel workers
"""
def __init__(self, classifier: BaseEstimator=None, val_split=None, n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=None, n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.n_jobs = qp._get_njobs(n_jobs)
@abstractmethod
@ -115,8 +120,8 @@ class ThresholdOptimization(BinaryAggregativeQuantifier):
return 0
return FP / (FP + TN)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
decision_scores, y = classif_predictions.Xy
def aggregation_fit(self, classif_predictions, labels):
decision_scores, y = classif_predictions, labels
# the standard behavior is to keep the best threshold only
self.tpr, self.fpr, self.threshold = self._eval_candidate_thresholds(decision_scores, y)[0]
return self
@ -134,17 +139,22 @@ class T50(ThresholdOptimization):
for the threshold that makes `tpr` closest to 0.5.
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5):
super().__init__(classifier, val_split)
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def condition(self, tpr, fpr) -> float:
return abs(tpr - 0.5)
@ -158,17 +168,20 @@ class MAX(ThresholdOptimization):
for the threshold that maximizes `tpr-fpr`.
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5):
super().__init__(classifier, val_split)
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def condition(self, tpr, fpr) -> float:
# MAX strives to maximize (tpr - fpr), which is equivalent to minimize (fpr - tpr)
@ -183,17 +196,20 @@ class X(ThresholdOptimization):
for the threshold that yields `tpr=1-fpr`.
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5):
super().__init__(classifier, val_split)
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def condition(self, tpr, fpr) -> float:
return abs(1 - (tpr + fpr))
@ -207,22 +223,25 @@ class MS(ThresholdOptimization):
class prevalence estimates for all decision thresholds and returns the median of them all.
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5):
super().__init__(classifier, val_split)
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def condition(self, tpr, fpr) -> float:
return 1
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
decision_scores, y = classif_predictions.Xy
def aggregation_fit(self, classif_predictions, labels):
decision_scores, y = classif_predictions, labels
# keeps all candidates
tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
self.tprs = tprs_fprs_thresholds[:, 0]
@ -246,16 +265,19 @@ class MS2(MS):
which `tpr-fpr>0.25`
The goal is to bring improved stability to the denominator of the adjustment.
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set in which the
misclassification rates are to be estimated.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the misclassification rates should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator=None, val_split=5):
super().__init__(classifier, val_split)
def __init__(self, classifier: BaseEstimator=None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def discard(self, tpr, fpr) -> bool:
return (tpr-fpr) <= 0.25

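As a sketch, the construction-time configuration shared by these methods, including the pre-trained-classifier case that motivated the change (toy data as placeholders; MS2 stands in for any method above):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import MS2

    rng = np.random.default_rng(0)
    Xtr, ytr = rng.normal(size=(300, 2)), rng.integers(0, 2, 300)
    Xval, yval = rng.normal(size=(100, 2)), rng.integers(0, 2, 100)
    Xte = rng.normal(size=(100, 2))

    # classifier trained outside the quantifier: disable fitting and pass an explicit validation set
    clf = LogisticRegression().fit(Xtr, ytr)
    ms2 = MS2(clf, fit_classifier=False, val_split=(Xval, yval))
    ms2.fit(Xtr, ytr)   # with fit_classifier=False, (Xtr, ytr) is not used to retrain the classifier
    prev = ms2.predict(Xte)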
View File

@ -3,6 +3,7 @@ from copy import deepcopy
from typing import Callable, Literal, Union
import numpy as np
from abstention.calibration import NoBiasVectorScaling, TempScaling, VectorScaling
from sklearn.base import BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.exceptions import NotFittedError
@ -36,6 +37,18 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
The method :meth:`quantify` comes with a default implementation based on :meth:`classify`
and :meth:`aggregate`.
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
None when the method does not require any validation data, in order to avoid wasting a portion of
the training data.
"""
def __init__(self, classifier: Union[None,BaseEstimator], fit_classifier:bool=True, val_split:Union[int,float,tuple,None]=5):
@ -116,34 +129,13 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Trains the aggregative quantifier. This comes down to training a classifier (if requested) and an
aggregation function.
:param X: array-like, the training instances
:param y: array-like, the labels
:param X: array-like of shape `(n_samples, n_features)`, the training instances
:param y: array-like of shape `(n_samples,)`, the labels
:return: self
"""
self._check_init_parameters()
classif_predictions = self.classifier_fit_predict(X, y)
self.aggregation_fit(classif_predictions)
return self
def fit_depr(self, data: LabelledCollection, fit_classifier=True, val_split=None):
"""
Trains the aggregative quantifier. This comes down to training a classifier and an aggregation function.
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
:return: self
"""
self._check_init_parameters()
classif_predictions = self.classifier_fit_predict_depr(data, fit_classifier, predict_on=val_split)
self.aggregation_fit_depr(classif_predictions, data)
classif_predictions, labels = self.classifier_fit_predict(X, y)
self.aggregation_fit(classif_predictions, labels)
return self
def classifier_fit_predict(self, X, y):
@ -151,19 +143,20 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Trains the classifier if requested (`fit_classifier=True`) and generate the necessary predictions to
train the aggregation function.
:param X: array-like, the training instances
:param y: array-like, the labels
:param X: array-like of shape `(n_samples, n_features)`, the training instances
:param y: array-like of shape `(n_samples,)`, the labels
"""
self._check_classifier()
# self._check_non_empty_classes(y)
predictions, labels = None, None
if isinstance(self.val_split, int):
assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
num_folds = self.val_split
n_jobs = self.n_jobs if hasattr(self, 'n_jobs') else qp._get_njobs(None)
predictions = cross_val_predict(self.classifier, X, y, cv=num_folds, n_jobs=n_jobs, method=self._classifier_method())
yval = y
labels = y
self.classifier.fit(X, y)
elif isinstance(self.val_split, float):
assert self.fit_classifier, f'unexpected value for {self.fit_classifier=}'
@ -171,26 +164,30 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Xtr, Xval, ytr, yval = train_test_split(X, y, train_size=train_prop, stratify=y)
self.classifier.fit(Xtr, ytr)
predictions = self.classify(Xval)
labels = yval
elif isinstance(self.val_split, tuple):
Xval, yval = self.val_split
if self.fit_classifier:
self.classifier.fit(X, y)
predictions = self.classify(Xval)
labels = yval
elif self.val_split is None:
if self.fit_classifier:
self.classifier.fit(X, y)
predictions, yval = None, None
predictions, labels = None, None
else:
raise ValueError(f'unexpected type for {self.val_split=}')
return predictions, yval
return predictions, labels
@abstractmethod
def aggregation_fit(self, classif_predictions, **kwargs):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function.
:param classif_predictions: the classification predictions; whatever the method
:meth:`classify` returns
:param classif_predictions: array-like with the classification predictions
(whatever the method :meth:`classify` returns)
:param labels: array-like with the true labels associated to each classifier prediction
"""
...
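For illustration, a minimal custom aggregative quantifier under this contract might look as follows (a toy, PCC-like method, not part of the library):

    import numpy as np
    from sklearn.base import BaseEstimator
    from quapy.method.aggregative import AggregativeSoftQuantifier

    class MeanPosteriors(AggregativeSoftQuantifier):
        """Toy method: estimates prevalence as the mean of the posterior probabilities."""

        def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
            # no validation predictions are needed, hence val_split=None
            super().__init__(classifier, fit_classifier, val_split=None)

        def aggregation_fit(self, classif_predictions, labels):
            # nothing to learn from the (empty) validation predictions
            pass

        def aggregate(self, posteriors: np.ndarray):
            return posteriors.mean(axis=0)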
@ -218,8 +215,8 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
:meth:`aggregate`, e.g., posterior probabilities for probabilistic quantifiers, or crisp predictions for
non-probabilistic quantifiers. The default one is "decision_function".
:param X: array-like of shape `(n_instances, n_features,)`
:return: np.ndarray of shape `(n_instances,)` with label predictions
:param X: array-like of shape `(n_samples, n_features)`, the data instances
:return: np.ndarray of shape `(n_instances,)` with classifier predictions
"""
return getattr(self.classifier, self._classifier_method())(X)
@ -236,7 +233,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Guarantees that the underlying classifier implements the method required for issuing predictions, i.e.,
the method indicated by the :meth:`_classifier_method`
:param adapt_if_necessary: unused unless overriden
:param adapt_if_necessary: unused unless overridden
"""
assert hasattr(self.classifier, self._classifier_method()), \
f"the method does not implement the required {self._classifier_method()} method"
@ -246,7 +243,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Generate class prevalence estimates for the sample's instances by aggregating the label predictions generated
by the classifier.
:param X: array-like
:param X: array-like of shape `(n_samples, n_features)`, the data instances
:return: `np.ndarray` of shape `(n_classes)` with class prevalence estimates.
"""
classif_predictions = self.classify(X)
@ -257,7 +254,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
"""
Implements the aggregation of the classifier predictions.
:param classif_predictions: `np.ndarray` of label predictions
:param classif_predictions: `np.ndarray` of classifier predictions
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
"""
...
@ -268,7 +265,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
Class labels, in the same order in which class prevalence values are to be computed.
This default implementation actually returns the class labels of the learner.
:return: array-like
:return: array-like, the class labels
"""
return self.classifier.classes_
@ -356,11 +353,12 @@ class CC(AggregativeCrispQuantifier):
def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
super().__init__(classifier, fit_classifier, val_split=None)
def aggregation_fit(self, classif_predictions):
def aggregation_fit(self, classif_predictions, labels):
"""
Nothing to do here!
:param classif_predictions: not used
:param classif_predictions: unused
:param labels: unused
"""
pass
@ -368,7 +366,7 @@ class CC(AggregativeCrispQuantifier):
"""
Computes class prevalence estimates by counting the prevalence of each of the predicted labels.
:param classif_predictions: array-like with label predictions
:param classif_predictions: array-like with classifier predictions
:return: `np.ndarray` of shape `(n_classes,)` with class prevalence estimates.
"""
return F.prevalence_from_labels(classif_predictions, self.classes_)
@ -385,11 +383,12 @@ class PCC(AggregativeSoftQuantifier):
def __init__(self, classifier: BaseEstimator = None, fit_classifier: bool = True):
super().__init__(classifier, fit_classifier, val_split=None)
def aggregation_fit(self, classif_predictions):
def aggregation_fit(self, classif_predictions, labels):
"""
Nothing to do here!
:param classif_predictions: not used
:param classif_predictions: unused
:param labels: unused
"""
pass
@ -403,15 +402,17 @@ class ACC(AggregativeCrispQuantifier):
the "adjusted" variant of :class:`CC`, that corrects the predictions of CC
according to the `misclassification rates`.
:param classifier: a sklearn's Estimator that generates a classifier
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param str method: adjustment method to be used:
@ -470,16 +471,20 @@ class ACC(AggregativeCrispQuantifier):
`Vaz et al. 2018 <https://jmlr.org/papers/v20/18-456.html>`_. This amounts
to setting method to 'invariant-ratio' and clipping to 'project'.
:param classifier: a sklearn's Estimator that generates a classifier
:param fit_classifier: bool, whether to fit the classifier or not
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param n_jobs: number of parallel workers
:return: an instance of ACC configured so that it implements the Invariant Ratio Estimator
"""
return ACC(classifier, fit_classifier=fit_classifier, val_split=val_split, method='invariant-ratio', norm='mapsimplex', n_jobs=n_jobs)
@ -492,15 +497,14 @@ class ACC(AggregativeCrispQuantifier):
if self.norm not in ACC.NORMALIZATIONS:
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
def aggregation_fit(self, classif_predictions):
def aggregation_fit(self, classif_predictions, labels):
"""
Estimates the misclassification rates.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the label predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the predicted labels
:param labels: array-like with the true labels associated to each predicted label
"""
pred_labels, true_labels = classif_predictions
true_labels = labels
pred_labels = classif_predictions
self.cc = CC(self.classifier, fit_classifier=False)
self.Pte_cond_estim_ = ACC.getPteCondEstim(self.classifier.classes_, true_labels, pred_labels)
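As a short sketch, the adjustment variant is now selected entirely at construction time (the classifier below is a placeholder choice):

    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import ACC

    # standard ACC, with misclassification rates estimated via 5-fold cross-validation
    acc = ACC(LogisticRegression(), val_split=5)

    # the Invariant Ratio Estimator, equivalent to the classmethod shown above
    ire = ACC(LogisticRegression(), val_split=5, method='invariant-ratio', norm='mapsimplex')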
@ -541,16 +545,17 @@ class PACC(AggregativeSoftQuantifier):
`Probabilistic Adjusted Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
the probabilistic variant of ACC that relies on the posterior probabilities returned by a probabilistic classifier.
:param classifier: a sklearn's Estimator that generates a classifier
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: bool, whether to fit the classifier or not
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`). Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param str method: adjustment method to be used:
@ -606,15 +611,15 @@ class PACC(AggregativeSoftQuantifier):
if self.norm not in ACC.NORMALIZATIONS:
raise ValueError(f"unknown normalization; valid ones are {ACC.NORMALIZATIONS}")
def aggregation_fit(self, classif_predictions):
def aggregation_fit(self, classif_predictions, labels):
"""
Estimates the misclassification rates
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with posterior probabilities
:param labels: array-like with the true labels associated to each vector of posterior probabilities
"""
posteriors, true_labels = classif_predictions
posteriors = classif_predictions
true_labels = labels
self.pcc = PCC(self.classifier, fit_classifier=False)
self.Pte_cond_estim_ = PACC.getPteCondEstim(self.classifier.classes_, true_labels, posteriors)
@ -656,61 +661,98 @@ class EMQ(AggregativeSoftQuantifier):
prevalence, an estimate of it obtained via k-fold cross validation (instead of the true training prevalence),
and to recalibrate the posterior probabilities of the classifier.
:param classifier: a sklearn's Estimator that generates a classifier
:param fit_classifier: bool, whether to fit the classifier or not
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer, indicating that the predictions
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`, default 5); or as a collection defining the specific set of data to use for validation.
Alternatively, this set can be specified at fit time by indicating the exact set of data
on which the predictions are to be generated. This hyperparameter is only meant to be used when the
heuristics are to be applied, i.e., if a recalibration is required. The default value is None (meaning
the recalibration is not required). In case this hyperparameter is set to a value other than None, but
the recalibration is not required (recalib=None), a warning message will be raised.
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
This hyperparameter is only meant to be used when the heuristics are to be applied, i.e., if a
calibration is required. The default value is None (meaning that no calibration is required). If
this hyperparameter is set to a value other than None but no calibration is required
(calib=None), a warning message will be issued.
:param exact_train_prev: set to True (default) for using the true training prevalence as the initial observation;
set to False for computing the training prevalence as an estimate of it, i.e., as the expected
value of the posterior probabilities of the training instances.
:param recalib: a string indicating the method of recalibration.
:param calib: a string indicating the method of calibration.
Available choices include "nbvs" (No-Bias Vector Scaling), "bcts" (Bias-Corrected Temperature Scaling,
default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no recalibration).
default), "ts" (Temperature Scaling), and "vs" (Vector Scaling). Default is None (no calibration).
:param on_calib_error: a string indicating the policy to follow in case the calibrator fails at runtime.
Options include "raise" (default), in which case a RuntimeError is raised; and "backup", in which
case the calibrator is silently skipped.
:param n_jobs: number of parallel workers. Only used for calibrating the classifier if `val_split` is
set to an integer `k` (the number of folds).
"""
MAX_ITER = 1000
EPSILON = 1e-4
ON_CALIB_ERROR_VALUES = ['raise', 'backup']
CALIB_OPTIONS = [None, 'nbvs', 'bcts', 'ts', 'vs']
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True,
calib=None, on_calib_error='raise', n_jobs=None):
assert calib in EMQ.CALIB_OPTIONS, \
f'invalid value for {calib=}; valid ones are {EMQ.CALIB_OPTIONS}'
assert on_calib_error in EMQ.ON_CALIB_ERROR_VALUES, \
f'invalid value for {on_calib_error=}; valid ones are {EMQ.ON_CALIB_ERROR_VALUES}'
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=None, exact_train_prev=True, recalib=None,
n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.exact_train_prev = exact_train_prev
self.recalib = recalib
self.calib = calib
self.on_calib_errors = on_calib_error
self.n_jobs = n_jobs
@classmethod
def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, n_jobs=None):
def EMQ_BCTS(cls, classifier: BaseEstimator, fit_classifier=True, val_split=5, on_calib_error="raise", n_jobs=None):
"""
Constructs an instance of EMQ using the best configuration found in the `Alexandari et al. paper
<http://proceedings.mlr.press/v119/alexandari20a.html>`_, i.e., one that relies on Bias-Corrected Temperature
Scaling (BCTS) as a recalibration function, and that uses an estimate of the training prevalence instead of
Scaling (BCTS) as a calibration function, and that uses an estimate of the training prevalence instead of
the true training prevalence.
:param classifier: a sklearn's Estimator that generates a classifier
:param n_jobs: number of parallel workers.
:param classifier: a scikit-learn BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param on_calib_error: a string indicating the policy to follow in case the calibrator fails at runtime.
Options include "raise" (default), in which case a RuntimeError is raised; and "backup", in which
case the calibrator is silently skipped.
:param n_jobs: number of parallel workers. Only used for calibrating the classifier if `val_split` is
set to an integer `k` (the number of folds).
:return: An instance of EMQ with BCTS
"""
return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False, recalib='bcts', n_jobs=n_jobs)
return EMQ(classifier, fit_classifier=fit_classifier, val_split=val_split, exact_train_prev=False,
calib='bcts', on_calib_error=on_calib_error, n_jobs=n_jobs)
def _check_init_parameters(self):
if self.val_split is not None:
if self.exact_train_prev and self.recalib is None:
if self.exact_train_prev and self.calib is None:
raise RuntimeWarning(f'The parameter {self.val_split=} was specified for EMQ, while the parameters '
f'{self.exact_train_prev=} and {self.recalib=}. This has no effect and causes an unnecessary '
f'{self.exact_train_prev=} and {self.calib=}. This has no effect and causes unnecessary '
f'overhead.')
else:
if self.recalib is not None:
print(f'[warning] The parameter {self.recalib=} requires the val_split be different from None. '
if self.calib is not None:
print(f'[warning] The parameter {self.calib=} requires val_split to be different from None. '
f'This parameter will be set to 5. To avoid this warning, set this value to a float value '
f'indicating the proportion of training data to be used as validation, or to an integer '
f'indicating the number of folds for kFCV.')
@ -718,56 +760,78 @@ class EMQ(AggregativeSoftQuantifier):
def classify(self, X):
"""
Provides the posterior probabilities for the given instances. If the classifier was required
to be recalibrated, then these posteriors are recalibrated accordingly.
Provides the posterior probabilities for the given instances. The calibration function, if required,
has no effect in this step, and is only involved in the aggregate method.
:param X: array-like of shape `(n_instances, n_dimensions)`
:return: np.ndarray of shape `(n_instances, n_classes,)` with posterior probabilities
"""
posteriors = self.classifier.predict_proba(X)
if hasattr(self, 'calibration_function') and self.calibration_function is not None:
posteriors = self.calibration_function(posteriors)
return posteriors
return self.classifier.predict_proba(X)
def classifier_fit_predict(self, X, y):
classif_predictions = super().classifier_fit_predict(X, y)
self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_)
return classif_predictions
def aggregation_fit(self, classif_predictions):
def _fit_calibration(self, calibrator, P, y):
n_classes = len(self.classes_)
if not np.issubdtype(y.dtype, np.number):
y = np.searchsorted(self.classes_, y)
try:
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
except Exception as e:
if self.on_calib_errors == 'raise':
raise RuntimeError(f'calibration {self.calib} failed at fit time: {e}')
elif self.on_calib_errors == 'backup':
self.calibration_function = lambda P: P
def _calibrate_if_requested(self, uncalib_posteriors):
if hasattr(self, 'calibration_function') and self.calibration_function is not None:
try:
calib_posteriors = self.calibration_function(uncalib_posteriors)
except Exception as e:
if self.on_calib_errors == 'raise':
raise RuntimeError(f'calibration {self.calib} failed at predict time: {e}')
elif self.on_calib_errors == 'backup':
calib_posteriors = uncalib_posteriors
else:
raise ValueError(f'unexpected {self.on_calib_errors=}; '
f'valid options are {EMQ.ON_CALIB_ERROR_VALUES}')
return calib_posteriors
return uncalib_posteriors
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of EMQ. This comes down to calibrating the posterior probabilities,
if requested.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param classif_predictions: array-like with the raw (i.e., uncalibrated) posterior probabilities
returned by the classifier
:param labels: array-like with the true labels associated to each classifier prediction
"""
P, y = classif_predictions
n_classes = len(self.classes_)
if self.recalib is not None:
if self.recalib == 'nbvs':
calibrator = NoBiasVectorScaling()
elif self.recalib == 'bcts':
calibrator = TempScaling(bias_positions='all')
elif self.recalib == 'ts':
calibrator = TempScaling()
elif self.recalib == 'vs':
calibrator = VectorScaling()
else:
raise ValueError('invalid param argument for recalibration method; available ones are '
'"nbvs", "bcts", "ts", and "vs".')
P = classif_predictions
y = labels
if self.calib is not None:
calibrator = {
'nbvs': NoBiasVectorScaling(),
'bcts': TempScaling(bias_positions='all'),
'ts': TempScaling(),
'vs': VectorScaling()
}.get(self.calib, None)
if not np.issubdtype(y.dtype, np.number):
y = np.searchsorted(self.classes_, y)
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
if calibrator is None:
raise ValueError(f'invalid value for {self.calib=}; valid ones are {EMQ.CALIB_OPTIONS}')
self._fit_calibration(calibrator, P, y)
if not self.exact_train_prev:
train_posteriors = classif_predictions.X
if self.recalib is not None:
train_posteriors = self.calibration_function(train_posteriors)
self.train_prevalence = F.prevalence_from_probabilities(train_posteriors)
P = self._calibrate_if_requested(P)
self.train_prevalence = F.prevalence_from_probabilities(P)
def aggregate(self, classif_posteriors, epsilon=EPSILON):
classif_posteriors = self._calibrate_if_requested(classif_posteriors)
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
return priors
@ -780,6 +844,7 @@ class EMQ(AggregativeSoftQuantifier):
:return: np.ndarray of shape `(n_instances, n_classes)`
"""
classif_posteriors = self.classify(instances)
classif_posteriors = self._calibrate_if_requested(classif_posteriors)
priors, posteriors = self.EM(self.train_prevalence, classif_posteriors, epsilon)
return posteriors
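To make the new division of labor concrete, a small sketch (toy data as placeholders; behavior as per the methods above):

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import EMQ

    rng = np.random.default_rng(0)
    Xtr, ytr = rng.normal(size=(300, 2)), rng.integers(0, 2, 300)
    Xte = rng.normal(size=(100, 2))

    emq = EMQ(LogisticRegression(), val_split=5, calib='bcts', on_calib_error='backup')
    emq.fit(Xtr, ytr)

    raw_posteriors = emq.classify(Xte)  # uncalibrated: "classify" no longer applies the calibration
    prevalence = emq.predict(Xte)       # the calibration (if any) is applied in the aggregation phase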
@ -827,101 +892,6 @@ class EMQ(AggregativeSoftQuantifier):
return qs, ps
class BayesianCC(AggregativeCrispQuantifier):
"""
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
which is a variant of :class:`ACC` that calculates the posterior probability distribution
over the prevalence vectors, rather than providing a point estimate obtained
by matrix inversion.
Can be used to diagnose degeneracy in the predictions visible when the confusion
matrix has high condition number or to quantify uncertainty around the point estimate.
This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
as a stratified held-out validation set, for generating classifier predictions.
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
:param num_samples: number of samples to draw from the posterior (default 1000)
:param mcmc_seed: random seed for the MCMC sampler (default 0)
"""
def __init__(self,
classifier: BaseEstimator = None,
val_split: float = 0.75,
num_warmup: int = 500,
num_samples: int = 1_000,
mcmc_seed: int = 0):
if num_warmup <= 0:
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
if num_samples <= 0:
raise ValueError(f'parameter {num_samples=} must be a positive integer')
if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
if _bayesian.DEPENDENCIES_INSTALLED is False:
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
self.num_warmup = num_warmup
self.num_samples = num_samples
self.mcmc_seed = mcmc_seed
# Array of shape (n_classes, n_predicted_classes,) where entry (y, c) is the number of instances
# labeled as class y and predicted as class c.
# By default, this array is set to None and later defined as part of the `aggregation_fit` phase
self._n_and_c_labeled = None
# Dictionary with posterior samples, set when `aggregate` is provided.
self._samples = None
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
"""
Estimates the misclassification rates.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the label predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
"""
pred_labels, true_labels = classif_predictions.Xy
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
labels=self.classifier.classes_)
def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:
raise ValueError("aggregation_fit must be called before sample_from_posterior")
n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
self._samples = _bayesian.sample_posterior(
n_c_unlabeled=n_c_unlabeled,
n_y_and_c_labeled=self._n_and_c_labeled,
num_warmup=self.num_warmup,
num_samples=self.num_samples,
seed=self.mcmc_seed,
)
return self._samples
def get_prevalence_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
return self._samples[_bayesian.P_TEST_Y]
def get_conditional_probability_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
return self._samples[_bayesian.P_C_COND_Y]
def aggregate(self, classif_predictions):
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
return np.asarray(samples.mean(axis=0), dtype=float)
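As a usage sketch (hypothetical data names Xtr, ytr, Xte; assumes `quapy[bayes]` is installed and the refactored fit(X, y)/predict API described in the changelog), the posterior samples can be turned into per-class credible intervals:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    bcc = BayesianCC(classifier=LogisticRegression(), val_split=0.75)
    bcc.fit(Xtr, ytr)
    point_estimate = bcc.predict(Xte)       # posterior mean over the prevalence vectors
    samples = bcc.get_prevalence_samples()  # shape (num_samples, n_classes)
    lo, hi = np.percentile(samples, [2.5, 97.5], axis=0)  # 95% credible interval per class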
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
"""
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
@ -932,24 +902,30 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
class-conditional distributions of the posterior probabilities returned for the positive and negative validation
examples, respectively. The parameters of the mixture thus represent the estimates of the class prevalence values.
:param classifier: a sklearn's Estimator that generates a binary classifier
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5)..
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of HDy.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
P, y = classif_predictions.Xy
P, y = classif_predictions, labels
Px = P[:, self.pos_label] # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
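To illustrate the refactored construction-time API, a sketch of the three admissible forms of val_split (hypothetical data names; pretrained_lr denotes a classifier trained outside the quantifier):

    from sklearn.linear_model import LogisticRegression

    # hold out 40% of the training data (stratified) for fitting the aggregation function
    hdy = HDy(LogisticRegression(), fit_classifier=True, val_split=0.4)

    # generate the classifier predictions via 10-fold cross-validation
    hdy = HDy(LogisticRegression(), fit_classifier=True, val_split=10)

    # use an explicit validation set (Xval, yval) with an already-trained classifier
    hdy = HDy(pretrained_lr, fit_classifier=False, val_split=(Xval, yval))

    hdy.fit(X, y)
    prevalence = hdy.predict(Xte)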
@ -1003,20 +979,31 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
minimizes the distance between distributions.
Details of the ternary search are taken from <https://dl.acm.org/doi/pdf/10.1145/3219819.3220059>
:param classifier: a sklearn's Estimator that generates a binary classifier
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5)..
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param n_bins: an int with the number of bins to use to compute the histograms.
:param divergence: a str indicating the name of the divergence (currently supported ones are "HD" or "topsoe"), or a
callable function that computes the divergence between two distributions (two equally sized arrays).
:param tol: a float with the tolerance for the ternary search algorithm.
:param n_jobs: number of parallel workers.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5, n_bins=8, divergence: Union[str, Callable] = 'HD',
tol=1e-05, n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5, n_bins=8,
divergence: Union[str, Callable] = 'HD', tol=1e-05, n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.tol = tol
self.divergence = divergence
self.n_bins = n_bins
@ -1038,15 +1025,14 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
# Left and right are the current bounds; the optimum lies between them
return (left + right) / 2
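For reference, a generic ternary search of this kind, written here as a minimal sketch for minimizing a unimodal objective f (not the class's exact code), shrinks the bracket until it is narrower than tol:

    def ternary_search(f, left, right, tol=1e-5):
        # minimize a unimodal function f over the interval [left, right]
        while abs(right - left) > tol:
            m1 = left + (right - left) / 3
            m2 = right - (right - left) / 3
            if f(m1) < f(m2):
                right = m2  # the minimum cannot lie in (m2, right]
            else:
                left = m1   # the minimum cannot lie in [left, m1)
        return (left + right) / 2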
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of DyS.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
Px, y = classif_predictions.Xy
Px, y = classif_predictions, labels
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
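The aggregation then reduces to a one-dimensional search over the mixture parameter; as a hedged sketch (hypothetical helper names, not the class's exact code):

    import numpy as np

    def hellinger(p, q):
        # Hellinger distance between two discrete distributions (equally sized arrays)
        return np.sqrt(np.sum((np.sqrt(p) - np.sqrt(q)) ** 2)) / np.sqrt(2)

    def mixture_divergence(alpha, hist_pos, hist_neg, hist_test):
        # divergence between the test histogram and the alpha-mixture of the
        # class-conditional histograms; DyS seeks the alpha minimizing this value
        mixture = alpha * hist_pos + (1 - alpha) * hist_neg
        return hellinger(mixture, hist_test)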
@ -1074,24 +1060,30 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
SMM is a simplification of matching distribution methods where the representation of the examples
is created using the mean instead of a histogram (conceptually equivalent to PACC).
:param classifier: a sklearn's Estimator that generates a binary classifier.
:param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself), or an integer indicating the number of folds (default 5)..
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5):
super().__init__(classifier, fit_classifier, val_split)
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of SMM.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
Px, y = classif_predictions.Xy
Px, y = classif_predictions, labels
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
self.Pxy1 = Px[y == self.pos_label]
self.Pxy0 = Px[y == self.neg_label]
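Because the representation is a single scalar (the mean of P(y=+1|x)), the matching admits a closed form; a sketch under the standard mean-matching derivation (hypothetical function name):

    import numpy as np

    def smm_prevalence(Px_test, Pxy1, Pxy0):
        mu_test = np.mean(Px_test)  # mean posterior on the test sample
        mu_pos, mu_neg = np.mean(Pxy1), np.mean(Pxy0)
        # solve mu_test = alpha * mu_pos + (1 - alpha) * mu_neg for alpha
        alpha = (mu_test - mu_neg) / (mu_pos - mu_neg)
        return np.clip(alpha, 0., 1.)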
@ -1113,25 +1105,32 @@ class DMy(AggregativeSoftQuantifier):
probabilities. This implementation exposes, as hyperparameters, the number of bins, the divergence, and whether to
work on CDFs rather than PDFs.
:param classifier: a `sklearn`'s Estimator that generates a probabilistic classifier
:param val_split: indicates the proportion of data to be used as a stratified held-out validation set to model the
validation distribution.
This parameter can be indicated as a real value (between 0 and 1), representing a proportion of
validation data, or as an integer, indicating that the validation distribution should be estimated via
`k`-fold cross validation (this integer stands for the number of folds `k`, defaults 5), or as a
:class:`quapy.data.base.LabelledCollection` (the split itself).
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param fit_classifier: whether to train the learner (default is True). Set to False if the
learner has been trained outside the quantifier.
:param val_split: specifies the data used for generating classifier predictions. This specification
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set; or as an integer (default 5), indicating that the predictions
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
for `k`); or as a tuple (X,y) defining the specific set of data to use for validation.
:param nbins: number of bins used to discretize the distributions (default 8)
:param divergence: a string representing a divergence measure (currently, "HD" and "topsoe" are implemented)
or a callable function taking two ndarrays of the same dimension as input (default "HD", meaning Hellinger
Distance)
:param cdf: whether to use CDF instead of PDF (default False)
:param n_jobs: number of parallel workers (default None)
"""
def __init__(self, classifier: BaseEstimator = None, val_split=5, nbins=8, divergence: Union[str, Callable] = 'HD',
cdf=False, search='optim_minimize', n_jobs=None):
self.classifier = qp._get_classifier(classifier)
self.val_split = val_split
def __init__(self, classifier: BaseEstimator = None, fit_classifier=True, val_split=5, nbins=8,
divergence: Union[str, Callable] = 'HD', cdf=False, search='optim_minimize', n_jobs=None):
super().__init__(classifier, fit_classifier, val_split)
self.nbins = nbins
self.divergence = divergence
self.cdf = cdf
@ -1162,7 +1161,7 @@ class DMy(AggregativeSoftQuantifier):
distributions = np.cumsum(distributions, axis=1)
return distributions
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Trains the aggregation function of a distribution matching method. This comes down to generating the
validation distributions out of the training data.
@ -1172,11 +1171,10 @@ class DMy(AggregativeSoftQuantifier):
distribution of posterior probabilities `P(Y=j|X=x)` for training data labelled with class `i`, and `dij[k]`
is the fraction of instances with a value in the `k`-th bin.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the posterior probabilities returned by the classifier
:param labels: array-like with the true labels associated to each posterior
"""
posteriors, true_labels = classif_predictions.Xy
posteriors, true_labels = classif_predictions, labels
n_classes = len(self.classifier.classes_)
self.validation_distribution = qp.util.parallel(
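The parallel call is truncated by the hunk; in essence it computes, for each true class, one normalized histogram per posterior dimension. A simplified sequential sketch (hypothetical helper name):

    import numpy as np

    def validation_distributions(posteriors, y, classes, nbins=8):
        # returns an array of shape (n_classes, n_classes, nbins): for each true class i,
        # the histograms of the posteriors P(Y=j|X=x) over nbins bins in [0, 1]
        dists = []
        for i in classes:
            Pi = posteriors[y == i]
            hists = np.vstack([np.histogram(Pi[:, j], bins=nbins, range=(0, 1))[0]
                               for j in range(posteriors.shape[1])]).astype(float)
            hists /= hists.sum(axis=1, keepdims=True)  # normalize each histogram
            dists.append(hists)
        return np.asarray(dists)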
@ -1432,7 +1430,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
:param base_quantifier: the base, binary quantifier
:param random_state: a seed to be set before fitting any base quantifier (default None)
:param param_grid: the grid of parameters over which the median will be computed
:param n_jobs: number of parllel workes
:param n_jobs: number of parallel workers
"""
def __init__(self, base_quantifier: AggregativeQuantifier, param_grid: dict, random_state=None, n_jobs=None):
@ -1449,32 +1447,32 @@ class AggregativeMedianEstimator(BinaryQuantifier):
def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model
def _delayed_fit_classifier(self, args):
with qp.util.temp_seed(self.random_state):
cls_params, training, kwargs = args
cls_params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**cls_params)
predictions = model.classifier_fit_predict(training, **kwargs)
return (model, predictions)
predictions, labels = model.classifier_fit_predict(X, y)
return (model, predictions, labels)
def _delayed_fit_aggregation(self, args):
with qp.util.temp_seed(self.random_state):
((model, predictions), q_params), training = args
((model, predictions, y), q_params) = args
model = deepcopy(model)
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
model.aggregation_fit(predictions, y)
return model
def fit(self, training: LabelledCollection, **kwargs):
def fit(self, X, y):
import itertools
self._check_binary(training, self.__class__.__name__)
self._check_binary(y, self.__class__.__name__)
if isinstance(self.base_quantifier, AggregativeQuantifier):
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)
@ -1482,7 +1480,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
if len(cls_configs) > 1:
models_preds = qp.util.parallel(
self._delayed_fit_classifier,
((params, training, kwargs) for params in cls_configs),
((params, X, y) for params in cls_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False,
@ -1491,12 +1489,12 @@ class AggregativeMedianEstimator(BinaryQuantifier):
else:
model = self.base_quantifier
model.set_params(**cls_configs[0])
predictions = model.classifier_fit_predict(training, **kwargs)
models_preds = [(model, predictions)]
predictions, labels = model.classifier_fit_predict(X, y)
models_preds = [(model, predictions, labels)]
self.models = qp.util.parallel(
self._delayed_fit_aggregation,
((setup, training) for setup in itertools.product(models_preds, q_configs)),
itertools.product(models_preds, q_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
backend='threading'
@ -1505,7 +1503,7 @@ class AggregativeMedianEstimator(BinaryQuantifier):
configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
backend='threading'
@ -1514,9 +1512,9 @@ class AggregativeMedianEstimator(BinaryQuantifier):
def _delayed_predict(self, args):
model, instances = args
return model.quantify(instances)
return model.predict(instances)
def quantify(self, instances):
def predict(self, instances):
prev_preds = qp.util.parallel(
self._delayed_predict,
((model, instances) for model in self.models),
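The hunk is truncated here; conceptually, the per-configuration estimates are then reduced component-wise (presumably via np.median(prev_preds, axis=0)). A hedged usage sketch (hypothetical parameter grid and data names):

    from sklearn.linear_model import LogisticRegression

    base = PACC(LogisticRegression())
    grid = {'classifier__C': [0.1, 1, 10]}   # hypothetical grid over the classifier
    median_q = AggregativeMedianEstimator(base, param_grid=grid, n_jobs=-1)
    median_q.fit(X, y)                       # X, y: binary training data
    prevalence = median_q.predict(Xte)       # median of the per-configuration estimates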

View File

@ -375,18 +375,20 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
self.region = region
self.random_state = random_state
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
data = LabelledCollection(classif_predictions, labels, classes=self.classes_)
self.quantifiers = []
if self.n_train_samples==1:
self.quantifier.aggregation_fit(classif_predictions, data)
self.quantifier.aggregation_fit(classif_predictions, labels)
self.quantifiers.append(self.quantifier)
else:
# model-based bootstrap (only on the aggregative part)
full_index = np.arange(len(data))
n_examples = len(data)
full_index = np.arange(n_examples)
with qp.util.temp_seed(self.random_state):
for i in range(self.n_train_samples):
quantifier = copy.deepcopy(self.quantifier)
index = resample(full_index, n_samples=len(data))
index = resample(full_index, n_samples=n_examples)
data_i = data.sampling_from_index(index)
quantifier.aggregation_fit(*data_i.Xy)
@ -415,10 +417,10 @@ class AggregativeBootstrap(WithConfidenceABC, AggregativeQuantifier):
return prev_estim, conf
def fit(self, data: LabelledCollection, fit_classifier=True, val_split=None):
def fit(self, X, y):
self.quantifier._check_init_parameters()
classif_predictions = self.quantifier.classifier_fit_predict(data, fit_classifier, predict_on=val_split)
self.aggregation_fit(classif_predictions, data)
classif_predictions, labels = self.quantifier.classifier_fit_predict(X, y)
self.aggregation_fit(classif_predictions, labels)
return self
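A brief usage sketch (hypothetical names; n_train_samples appears above, while n_test_samples is an assumption on the constructor's signature):

    from sklearn.linear_model import LogisticRegression

    quantifier = AggregativeBootstrap(PACC(LogisticRegression()),
                                      n_train_samples=100, n_test_samples=100)
    quantifier.fit(Xtr, ytr)
    prev_estim, conf_region = quantifier.quantify_conf(Xte, confidence_level=0.95)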
def quantify_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
@ -446,7 +448,8 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`
:param classifier: a sklearn's Estimator that generates a classifier
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
the one indicated in `qp.environ['DEFAULT_CLS']`
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
as a stratified held-out validation set, for generating classifier predictions.
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
@ -493,16 +496,17 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
# Dictionary with posterior samples, set when `aggregate` is invoked.
self._samples = None
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
"""
Estimates the misclassification rates.
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
as instances, the label predictions issued by the classifier and, as labels, the true labels
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
:param classif_predictions: array-like with the label predictions returned by the classifier
:param labels: array-like with the true labels associated to each classifier prediction
"""
pred_labels, true_labels = classif_predictions.Xy
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_).astype(float)
pred_labels = classif_predictions
true_labels = labels
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels,
labels=self.classifier.classes_)
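For concreteness, a small numeric sketch of what `_n_and_c_labeled` holds, entry (y, c) being the number of validation instances of true class y predicted as class c:

    from sklearn.metrics import confusion_matrix

    y_true = [0, 0, 0, 1, 1, 2]
    y_pred = [0, 0, 1, 1, 1, 0]
    print(confusion_matrix(y_true=y_true, y_pred=y_pred, labels=[0, 1, 2]))
    # [[2 1 0]
    #  [0 2 0]
    #  [1 0 0]]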
def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:

View File

@ -52,19 +52,19 @@ class MedianEstimator2(BinaryQuantifier):
def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model
def fit(self, training: LabelledCollection):
self._check_binary(training, self.__class__.__name__)
def fit(self, X, y):
self._check_binary(y, self.__class__.__name__)
configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
@ -95,7 +95,7 @@ class MedianEstimator(BinaryQuantifier):
:param base_quantifier: the base, binary quantifier
:param random_state: a seed to be set before fitting any base quantifier (default None)
:param param_grid: the grid of parameters over which the median will be computed
:param n_jobs: number of parllel workes
:param n_jobs: number of parallel workers
"""
def __init__(self, base_quantifier: BinaryQuantifier, param_grid: dict, random_state=None, n_jobs=None):
self.base_quantifier = base_quantifier
@ -111,61 +111,19 @@ class MedianEstimator(BinaryQuantifier):
def _delayed_fit(self, args):
with qp.util.temp_seed(self.random_state):
params, training = args
params, X, y = args
model = deepcopy(self.base_quantifier)
model.set_params(**params)
model.fit(training)
model.fit(X, y)
return model
def _delayed_fit_classifier(self, args):
with qp.util.temp_seed(self.random_state):
cls_params, training = args
model = deepcopy(self.base_quantifier)
model.set_params(**cls_params)
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
return (model, predictions)
def fit(self, X, y):
self._check_binary(y, self.__class__.__name__)
def _delayed_fit_aggregation(self, args):
with qp.util.temp_seed(self.random_state):
((model, predictions), q_params), training = args
model = deepcopy(model)
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
return model
def fit(self, training: LabelledCollection):
self._check_binary(training, self.__class__.__name__)
if isinstance(self.base_quantifier, AggregativeQuantifier):
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)
if len(cls_configs) > 1:
models_preds = qp.util.parallel(
self._delayed_fit_classifier,
((params, training) for params in cls_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
)
else:
model = self.base_quantifier
model.set_params(**cls_configs[0])
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
models_preds = [(model, predictions)]
self.models = qp.util.parallel(
self._delayed_fit_aggregation,
((setup, training) for setup in itertools.product(models_preds, q_configs)),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
)
else:
configs = qp.model_selection.expand_grid(self.param_grid)
self.models = qp.util.parallel(
self._delayed_fit,
((params, training) for params in configs),
((params, X, y) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False
@ -257,12 +215,13 @@ class Ensemble(BaseQuantifier):
if self.verbose:
print('[Ensemble]' + msg)
def fit(self, data: qp.data.LabelledCollection, val_split: Union[qp.data.LabelledCollection, float] = None):
def fit(self, X, y):
data = LabelledCollection(X, y)
if self.policy == 'ds' and not data.binary:
raise ValueError(f'ds policy is only defined for binary quantification, but this dataset is not binary')
if val_split is None:
val_split = self.val_split
# randomly chooses the prevalences for each member of the ensemble (preventing classes with less than
@ -704,10 +663,10 @@ class SCMQ(AggregativeSoftQuantifier):
self.merge_fun = merge_fun
self.val_split = val_split
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
def aggregation_fit(self, classif_predictions, labels):
for quantifier in self.quantifiers:
quantifier.classifier = self.classifier
quantifier.aggregation_fit(classif_predictions, data)
quantifier.aggregation_fit(classif_predictions, labels)
return self
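The aggregate step (truncated below) presumably merges the member quantifiers' estimates according to merge_fun; a sketch under that assumption (hypothetical helper):

    import numpy as np

    def merge_prevalences(prevs, merge_fun='median'):
        # prevs: array of shape (n_quantifiers, n_classes) with the individual estimates
        prevs = np.asarray(prevs)
        merged = np.median(prevs, axis=0) if merge_fun == 'median' else np.mean(prevs, axis=0)
        return merged / merged.sum()  # re-normalize to a valid prevalence vector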
def aggregate(self, classif_predictions: np.ndarray):

View File

@ -100,7 +100,7 @@ class DMx(BaseQuantifier):
return distributions
def fit(self, data: LabelledCollection):
def fit(self, X, y):
"""
Generates the validation distributions out of the training data (covariates).
The validation distributions have shape `(n, nfeats, nbins)`, with `n` the number of classes, `nfeats`
@ -109,10 +109,9 @@ class DMx(BaseQuantifier):
training data labelled with class `i`; while `dij = di[j]` is the discrete distribution for feature j in
training data labelled with class `i`, and `dij[k]` is the fraction of instances with a value in the `k`-th bin.
:param data: the training set
:param X: array-like of shape `(n_samples, n_features)`, the training instances
:param y: array-like of shape `(n_samples,)`, the labels
"""
X, y = data.Xy
self.nfeats = X.shape[1]
self.feat_ranges = _get_features_range(X)
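The helper `_get_features_range` (defined at the end of this file, truncated below) plausibly computes per-feature (min, max) ranges that delimit the histogram bins; a sketch consistent with that reading (hypothetical name, not the library's exact code):

    import numpy as np

    def feature_ranges(X):
        # per-feature (min, max) pairs over the training covariates
        return [(np.min(X[:, j]), np.max(X[:, j])) for j in range(X.shape[1])]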
@ -147,53 +146,53 @@ class DMx(BaseQuantifier):
return F.argmin_prevalence(loss, n_classes, method=self.search)
class ReadMe(BaseQuantifier):
def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
raise NotImplementedError('under development ...')
self.bootstrap_trials = bootstrap_trials
self.bootstrap_range = bootstrap_range
self.bagging_trials = bagging_trials
self.bagging_range = bagging_range
self.vectorizer_kwargs = vectorizer_kwargs
def fit(self, data: LabelledCollection):
X, y = data.Xy
self.vectorizer = CountVectorizer(binary=True, **self.vectorizer_kwargs)
X = self.vectorizer.fit_transform(X)
self.class_conditional_X = {i: X[y==i] for i in range(data.classes_)}
def predict(self, X):
X = self.vectorizer.transform(X)
# number of features
num_docs, num_feats = X.shape
# bootstrap
p_boots = []
for _ in range(self.bootstrap_trials):
docs_idx = np.random.choice(num_docs, size=self.bootstra_range, replace=False)
class_conditional_X = {i: X[docs_idx] for i, X in self.class_conditional_X.items()}
Xboot = X[docs_idx]
# bagging
p_bags = []
for _ in range(self.bagging_trials):
feat_idx = np.random.choice(num_feats, size=self.bagging_range, replace=False)
class_conditional_Xbag = {i: X[:, feat_idx] for i, X in class_conditional_X.items()}
Xbag = Xboot[:,feat_idx]
p = self.std_constrained_linear_ls(Xbag, class_conditional_Xbag)
p_bags.append(p)
p_boots.append(np.mean(p_bags, axis=0))
p_mean = np.mean(p_boots, axis=0)
p_std = np.std(p_bags, axis=0)
return p_mean
def std_constrained_linear_ls(self, X, class_cond_X: dict):
pass
# class ReadMe(BaseQuantifier):
#
# def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
# raise NotImplementedError('under development ...')
# self.bootstrap_trials = bootstrap_trials
# self.bootstrap_range = bootstrap_range
# self.bagging_trials = bagging_trials
# self.bagging_range = bagging_range
# self.vectorizer_kwargs = vectorizer_kwargs
#
# def fit(self, data: LabelledCollection):
# X, y = data.Xy
# self.vectorizer = CountVectorizer(binary=True, **self.vectorizer_kwargs)
# X = self.vectorizer.fit_transform(X)
# self.class_conditional_X = {i: X[y == i] for i in data.classes_}
#
# def predict(self, X):
# X = self.vectorizer.transform(X)
#
# # number of features
# num_docs, num_feats = X.shape
#
# # bootstrap
# p_boots = []
# for _ in range(self.bootstrap_trials):
# docs_idx = np.random.choice(num_docs, size=self.bootstrap_range, replace=False)
# class_conditional_X = {i: X[docs_idx] for i, X in self.class_conditional_X.items()}
# Xboot = X[docs_idx]
#
# # bagging
# p_bags = []
# for _ in range(self.bagging_trials):
# feat_idx = np.random.choice(num_feats, size=self.bagging_range, replace=False)
# class_conditional_Xbag = {i: X[:, feat_idx] for i, X in class_conditional_X.items()}
# Xbag = Xboot[:,feat_idx]
# p = self.std_constrained_linear_ls(Xbag, class_conditional_Xbag)
# p_bags.append(p)
# p_boots.append(np.mean(p_bags, axis=0))
#
# p_mean = np.mean(p_boots, axis=0)
# p_std = np.std(p_boots, axis=0)
#
# return p_mean
#
#
# def std_constrained_linear_ls(self, X, class_cond_X: dict):
# pass
def _get_features_range(X):

View File

@ -1,3 +1,4 @@
from sklearn.linear_model import LogisticRegression
import quapy as qp
from quapy.method.aggregative import *
@ -5,11 +6,20 @@ datasets = qp.datasets.UCI_MULTICLASS_DATASETS[1]
data = qp.datasets.fetch_UCIMulticlassDataset(datasets)
train, test = data.train_test
quant = EMQ()
quant.fit(*train.Xy)
prev = quant.predict(test.X)
Xtr, ytr = train.Xy
Xte = test.X
quant = EMQ(LogisticRegression(), calib='bcts')
quant.fit(Xtr, ytr)
prev = quant.predict(Xte)
print(prev)
post = quant.predict_proba(Xte)
print(post)
post = quant.classify(Xte)
print(post)
# AggregativeMedianEstimator()
# test CC, prevent from doing 5FCV for nothing
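Following the commented hint above, a hedged sketch of how the median estimator could be exercised (hypothetical grid; note that AggregativeMedianEstimator is binary, so it would require a binary dataset rather than the multiclass one loaded above):

# med = AggregativeMedianEstimator(
#     base_quantifier=PACC(LogisticRegression()),
#     param_grid={'classifier__C': [0.1, 1.0, 10.0]},
#     n_jobs=-1)
# med.fit(Xtr_bin, ytr_bin)
# print(med.predict(Xte_bin))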