Sketch of the Bayesian quantification

This commit is contained in:
Paweł Czyż 2024-03-15 14:01:24 +01:00
parent 3705264529
commit 2cc4908326
4 changed files with 188 additions and 12 deletions

View File

@ -28,22 +28,34 @@ def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01)
return p return p
def prevalence_from_labels(labels, classes): def counts_from_labels(labels, classes):
""" """
Computed the prevalence values from a vector of labels. Computes the count values from a vector of labels.
:param labels: array-like of shape `(n_instances)` with the label for each instance :param labels: array-like of shape `(n_instances,)` with the label for each instance
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when :param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
some classes have no examples. some classes have no examples.
:return: an ndarray of shape `(len(classes))` with the class prevalence values :return: an ndarray of shape `(len(classes),)` with the occurrence counts of each class
""" """
if labels.ndim != 1: if labels.ndim != 1:
raise ValueError(f'param labels does not seem to be a ndarray of label predictions') raise ValueError(f'param labels does not seem to be a ndarray of label predictions')
unique, counts = np.unique(labels, return_counts=True) unique, counts = np.unique(labels, return_counts=True)
by_class = defaultdict(lambda:0, dict(zip(unique, counts))) by_class = defaultdict(lambda:0, dict(zip(unique, counts)))
prevalences = np.asarray([by_class[class_] for class_ in classes], dtype=float) counts = np.asarray([by_class[class_] for class_ in classes], dtype=int)
prevalences /= prevalences.sum() return counts
return prevalences
def prevalence_from_labels(labels, classes):
"""
Computes the prevalence values from a vector of labels.
:param labels: array-like of shape `(n_instances,)` with the label for each instance
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
some classes have no examples.
:return: an ndarray of shape `(len(classes))` with the class prevalence values
"""
counts = np.array(counts_from_labels(labels, classes), dtype=float)
return counts / np.sum(counts)
def prevalence_from_probabilities(posteriors, binarize: bool = False): def prevalence_from_probabilities(posteriors, binarize: bool = False):

78
quapy/method/_bayesian.py Normal file
View File

@ -0,0 +1,78 @@
"""
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
"""
import numpy as np
try:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
DEPENDENCIES_INSTALLED = True
except ImportError:
jax = None
jnp = None
numpyro = None
dist = None
DEPENDENCIES_INSTALLED = False
P_TEST_Y: str = "P_test(Y)"
P_TEST_C: str = "P_test(C)"
P_C_COND_Y: str = "P(C|Y)"
def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
"""
Defines a probabilistic model in `NumPyro <https://num.pyro.ai/>`_.
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
with entry `c` being the number of instances predicted as class `c`.
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
"""
n_y_labeled = n_y_and_c_labeled.sum(axis=1)
K = len(n_c_unlabeled)
L = len(n_y_labeled)
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))
with numpyro.plate('plate', L):
numpyro.sample('F_yc', dist.Multinomial(n_y_labeled, p_c_cond_y), obs=n_y_and_c_labeled)
p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)
def sample_posterior(
n_c_unlabeled: np.ndarray,
n_y_and_c_labeled: np.ndarray,
num_warmup: int,
num_samples: int,
seed: int = 0,
) -> dict:
"""
Samples from the Bayesian quantification model in NumPyro using the
`NUTS <https://arxiv.org/abs/1111.4246>`_ sampler.
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
with entry `c` being the number of instances predicted as class `c`.
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
:param num_warmup: the number of warmup steps.
:param num_samples: the number of samples to draw.
:seed: the random seed.
:return: a `dict` with the samples. The keys are the names of the latent variables.
"""
mcmc = numpyro.infer.MCMC(
numpyro.infer.NUTS(model),
num_warmup=num_warmup,
num_samples=num_samples,
)
rng_key = jax.random.PRNGKey(seed)
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
return mcmc.get_samples()

View File

@ -11,6 +11,7 @@ from sklearn.model_selection import cross_val_predict
import quapy as qp import quapy as qp
import quapy.functional as F import quapy.functional as F
import quapy._bayesian as _bayesian
from quapy.functional import get_divergence from quapy.functional import get_divergence
from quapy.classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration from quapy.classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration
from quapy.classification.svmperf import SVMperf from quapy.classification.svmperf import SVMperf
@ -384,7 +385,8 @@ class ACC(AggregativeCrispQuantifier):
self.solver = solver self.solver = solver
def _check_init_parameters(self): def _check_init_parameters(self):
assert self.solver in ['exact', 'minimize'], "unknown solver; valid ones are 'exact', 'minimize'" if self.solver not in ['exact', 'minimize']:
raise ValueError("unknown solver; valid ones are 'exact', 'minimize'")
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection): def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
""" """
@ -453,6 +455,91 @@ class ACC(AggregativeCrispQuantifier):
return F.optim_minimize(loss, n_classes=A.shape[0]) return F.optim_minimize(loss, n_classes=A.shape[0])
class BayesianCC(AggregativeCrispQuantifier):
"""
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods,
which is a variant of :class`ACC` that calculates the posterior probability distribution
over the prevalence vectors, rather than providing a point estimate obtained
by matrix inversion.
Can be used to diagnose degeneracy in the predictions visible when the confusion
matrix has high condition number or to quantify uncertainty around the point estimate.
This method relies on extra dependencies, which have to be installed via:
`$ pip install quapy[bayes]`
:param classifier: a sklearn's Estimator that generates a classifier
:param val_split: specifies the data used for generating classifier predictions. This specification
should be a float in (0, 1) indicating the proportion of stratified held-out validation set to
be extracted from the training set
:num_warmup: number of warmup iterations for the MCMC sampler
:num_samples: number of samples to draw from the posterior
:mcmc_seed: random seed for the MCMC sampler
"""
def __init__(self, classifier: BaseEstimator, val_split: float = 0.75, num_warmup: int = 500, num_samples: int = 1_000, mcmc_seed: int = 0) -> None:
if num_warmup <= 0:
raise ValueError(f'num_warmup must be a positive integer, got {num_warmup}')
if num_samples <= 0:
raise ValueError(f'num_samples must be a positive integer, got {num_samples}')
if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
if _bayesian.DEPENDENCIES_INSTALLED is False:
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
self.classifier = classifier
self.val_split = val_split
self.num_warmup = num_warmup
self.num_samples = num_samples
self.mcmc_seed = mcmc_seed
# Array of shape (n_classes, n_predicted_classes) where entry (y, c) is the number of instances labeled as class y and predicted as class c
# By default it's None and it's set during the `aggregation_fit` phase
self._n_and_c_labeled = None
# Dictionary with posterior samples, set when `aggregate` is provided.
self._samples = None
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
"""
Estimates the misclassification rates.
:param classif_predictions: classifier predictions with true labels
"""
pred_labels, true_labels = classif_predictions.Xy
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_)
def sample_from_posterior(self, classif_predictions):
if self._n_and_c_labeled is None:
raise ValueError("aggregation_fit must be called before sample_from_posterior")
n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
self._samples = _bayesian.sample_posterior(
n_c_unlabeled=n_c_unlabeled,
n_y_and_c_labeled=self._n_and_c_labeled,
num_warmup=self.num_warmup,
num_samples=self.num_samples,
seed=self.mcmc_seed,
)
return self._samples
def get_prevalence_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
return self._samples[_bayesian.P_TEST_Y]
def get_conditional_probability_samples(self):
if self._samples is None:
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
return self._samples[_bayesian.P_C_COND_Y]
def aggregate(self, classif_predictions):
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
return np.asarray(samples.mean(axis=0), dtype=float)
class PCC(AggregativeSoftQuantifier): class PCC(AggregativeSoftQuantifier):
""" """
`Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_, `Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,

View File

@ -123,10 +123,9 @@ setup(
# #
# Similar to `install_requires` above, these must be valid existing # Similar to `install_requires` above, these must be valid existing
# projects. # projects.
# extras_require={ # Optional extras_require={ # Optional
# 'dev': ['check-manifest'], 'bayes': ['jax', 'jaxlib', 'numpyro'],
# 'test': ['coverage'], },
# },
# If there are data files included in your packages that need to be # If there are data files included in your packages that need to be
# installed, specify them here. # installed, specify them here.