Sketch of the Bayesian quantification
This commit is contained in:
parent
3705264529
commit
2cc4908326
|
@ -28,22 +28,34 @@ def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01)
|
|||
return p
|
||||
|
||||
|
||||
def prevalence_from_labels(labels, classes):
|
||||
def counts_from_labels(labels, classes):
|
||||
"""
|
||||
Computed the prevalence values from a vector of labels.
|
||||
Computes the count values from a vector of labels.
|
||||
|
||||
:param labels: array-like of shape `(n_instances)` with the label for each instance
|
||||
:param labels: array-like of shape `(n_instances,)` with the label for each instance
|
||||
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
|
||||
some classes have no examples.
|
||||
:return: an ndarray of shape `(len(classes))` with the class prevalence values
|
||||
:return: an ndarray of shape `(len(classes),)` with the occurrence counts of each class
|
||||
"""
|
||||
if labels.ndim != 1:
|
||||
raise ValueError(f'param labels does not seem to be a ndarray of label predictions')
|
||||
unique, counts = np.unique(labels, return_counts=True)
|
||||
by_class = defaultdict(lambda:0, dict(zip(unique, counts)))
|
||||
prevalences = np.asarray([by_class[class_] for class_ in classes], dtype=float)
|
||||
prevalences /= prevalences.sum()
|
||||
return prevalences
|
||||
counts = np.asarray([by_class[class_] for class_ in classes], dtype=int)
|
||||
return counts
|
||||
|
||||
|
||||
def prevalence_from_labels(labels, classes):
|
||||
"""
|
||||
Computes the prevalence values from a vector of labels.
|
||||
|
||||
:param labels: array-like of shape `(n_instances,)` with the label for each instance
|
||||
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
|
||||
some classes have no examples.
|
||||
:return: an ndarray of shape `(len(classes))` with the class prevalence values
|
||||
"""
|
||||
counts = np.array(counts_from_labels(labels, classes), dtype=float)
|
||||
return counts / np.sum(counts)
|
||||
|
||||
|
||||
def prevalence_from_probabilities(posteriors, binarize: bool = False):
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpyro
|
||||
import numpyro.distributions as dist
|
||||
|
||||
DEPENDENCIES_INSTALLED = True
|
||||
except ImportError:
|
||||
jax = None
|
||||
jnp = None
|
||||
numpyro = None
|
||||
dist = None
|
||||
|
||||
DEPENDENCIES_INSTALLED = False
|
||||
|
||||
|
||||
P_TEST_Y: str = "P_test(Y)"
|
||||
P_TEST_C: str = "P_test(C)"
|
||||
P_C_COND_Y: str = "P(C|Y)"
|
||||
|
||||
|
||||
def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
|
||||
"""
|
||||
Defines a probabilistic model in `NumPyro <https://num.pyro.ai/>`_.
|
||||
|
||||
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
|
||||
with entry `c` being the number of instances predicted as class `c`.
|
||||
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
|
||||
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
|
||||
"""
|
||||
n_y_labeled = n_y_and_c_labeled.sum(axis=1)
|
||||
|
||||
K = len(n_c_unlabeled)
|
||||
L = len(n_y_labeled)
|
||||
|
||||
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
|
||||
p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))
|
||||
|
||||
with numpyro.plate('plate', L):
|
||||
numpyro.sample('F_yc', dist.Multinomial(n_y_labeled, p_c_cond_y), obs=n_y_and_c_labeled)
|
||||
|
||||
p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
|
||||
numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)
|
||||
|
||||
|
||||
def sample_posterior(
|
||||
n_c_unlabeled: np.ndarray,
|
||||
n_y_and_c_labeled: np.ndarray,
|
||||
num_warmup: int,
|
||||
num_samples: int,
|
||||
seed: int = 0,
|
||||
) -> dict:
|
||||
"""
|
||||
Samples from the Bayesian quantification model in NumPyro using the
|
||||
`NUTS <https://arxiv.org/abs/1111.4246>`_ sampler.
|
||||
|
||||
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
|
||||
with entry `c` being the number of instances predicted as class `c`.
|
||||
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
|
||||
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
|
||||
:param num_warmup: the number of warmup steps.
|
||||
:param num_samples: the number of samples to draw.
|
||||
:seed: the random seed.
|
||||
:return: a `dict` with the samples. The keys are the names of the latent variables.
|
||||
"""
|
||||
mcmc = numpyro.infer.MCMC(
|
||||
numpyro.infer.NUTS(model),
|
||||
num_warmup=num_warmup,
|
||||
num_samples=num_samples,
|
||||
)
|
||||
rng_key = jax.random.PRNGKey(seed)
|
||||
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
|
||||
return mcmc.get_samples()
|
|
@ -11,6 +11,7 @@ from sklearn.model_selection import cross_val_predict
|
|||
|
||||
import quapy as qp
|
||||
import quapy.functional as F
|
||||
import quapy._bayesian as _bayesian
|
||||
from quapy.functional import get_divergence
|
||||
from quapy.classification.calibration import NBVSCalibration, BCTSCalibration, TSCalibration, VSCalibration
|
||||
from quapy.classification.svmperf import SVMperf
|
||||
|
@ -384,7 +385,8 @@ class ACC(AggregativeCrispQuantifier):
|
|||
self.solver = solver
|
||||
|
||||
def _check_init_parameters(self):
|
||||
assert self.solver in ['exact', 'minimize'], "unknown solver; valid ones are 'exact', 'minimize'"
|
||||
if self.solver not in ['exact', 'minimize']:
|
||||
raise ValueError("unknown solver; valid ones are 'exact', 'minimize'")
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
"""
|
||||
|
@ -453,6 +455,91 @@ class ACC(AggregativeCrispQuantifier):
|
|||
return F.optim_minimize(loss, n_classes=A.shape[0])
|
||||
|
||||
|
||||
class BayesianCC(AggregativeCrispQuantifier):
|
||||
"""
|
||||
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods,
|
||||
which is a variant of :class`ACC` that calculates the posterior probability distribution
|
||||
over the prevalence vectors, rather than providing a point estimate obtained
|
||||
by matrix inversion.
|
||||
|
||||
Can be used to diagnose degeneracy in the predictions visible when the confusion
|
||||
matrix has high condition number or to quantify uncertainty around the point estimate.
|
||||
|
||||
This method relies on extra dependencies, which have to be installed via:
|
||||
`$ pip install quapy[bayes]`
|
||||
|
||||
:param classifier: a sklearn's Estimator that generates a classifier
|
||||
:param val_split: specifies the data used for generating classifier predictions. This specification
|
||||
should be a float in (0, 1) indicating the proportion of stratified held-out validation set to
|
||||
be extracted from the training set
|
||||
:num_warmup: number of warmup iterations for the MCMC sampler
|
||||
:num_samples: number of samples to draw from the posterior
|
||||
:mcmc_seed: random seed for the MCMC sampler
|
||||
"""
|
||||
def __init__(self, classifier: BaseEstimator, val_split: float = 0.75, num_warmup: int = 500, num_samples: int = 1_000, mcmc_seed: int = 0) -> None:
|
||||
if num_warmup <= 0:
|
||||
raise ValueError(f'num_warmup must be a positive integer, got {num_warmup}')
|
||||
if num_samples <= 0:
|
||||
raise ValueError(f'num_samples must be a positive integer, got {num_samples}')
|
||||
|
||||
if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
|
||||
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
|
||||
|
||||
if _bayesian.DEPENDENCIES_INSTALLED is False:
|
||||
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
|
||||
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.num_warmup = num_warmup
|
||||
self.num_samples = num_samples
|
||||
self.mcmc_seed = mcmc_seed
|
||||
|
||||
# Array of shape (n_classes, n_predicted_classes) where entry (y, c) is the number of instances labeled as class y and predicted as class c
|
||||
# By default it's None and it's set during the `aggregation_fit` phase
|
||||
self._n_and_c_labeled = None
|
||||
|
||||
# Dictionary with posterior samples, set when `aggregate` is provided.
|
||||
self._samples = None
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
"""
|
||||
Estimates the misclassification rates.
|
||||
|
||||
:param classif_predictions: classifier predictions with true labels
|
||||
"""
|
||||
pred_labels, true_labels = classif_predictions.Xy
|
||||
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_)
|
||||
|
||||
def sample_from_posterior(self, classif_predictions):
|
||||
if self._n_and_c_labeled is None:
|
||||
raise ValueError("aggregation_fit must be called before sample_from_posterior")
|
||||
|
||||
n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
|
||||
|
||||
self._samples = _bayesian.sample_posterior(
|
||||
n_c_unlabeled=n_c_unlabeled,
|
||||
n_y_and_c_labeled=self._n_and_c_labeled,
|
||||
num_warmup=self.num_warmup,
|
||||
num_samples=self.num_samples,
|
||||
seed=self.mcmc_seed,
|
||||
)
|
||||
return self._samples
|
||||
|
||||
def get_prevalence_samples(self):
|
||||
if self._samples is None:
|
||||
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
|
||||
return self._samples[_bayesian.P_TEST_Y]
|
||||
|
||||
def get_conditional_probability_samples(self):
|
||||
if self._samples is None:
|
||||
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
|
||||
return self._samples[_bayesian.P_C_COND_Y]
|
||||
|
||||
def aggregate(self, classif_predictions):
|
||||
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
|
||||
return np.asarray(samples.mean(axis=0), dtype=float)
|
||||
|
||||
|
||||
class PCC(AggregativeSoftQuantifier):
|
||||
"""
|
||||
`Probabilistic Classify & Count <https://ieeexplore.ieee.org/abstract/document/5694031>`_,
|
||||
|
|
7
setup.py
7
setup.py
|
@ -123,10 +123,9 @@ setup(
|
|||
#
|
||||
# Similar to `install_requires` above, these must be valid existing
|
||||
# projects.
|
||||
# extras_require={ # Optional
|
||||
# 'dev': ['check-manifest'],
|
||||
# 'test': ['coverage'],
|
||||
# },
|
||||
extras_require={ # Optional
|
||||
'bayes': ['jax', 'jaxlib', 'numpyro'],
|
||||
},
|
||||
|
||||
# If there are data files included in your packages that need to be
|
||||
# installed, specify them here.
|
||||
|
|
Loading…
Reference in New Issue