Merge branch 'pglez82-precisequant' into devel
This commit is contained in:
commit
5f6a151263
|
|
@ -44,7 +44,7 @@ class LabelledCollection:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def index(self):
|
def index(self):
|
||||||
if self._index is None:
|
if not hasattr(self, '_index') or self._index is None:
|
||||||
self._index = {class_: np.arange(len(self))[self.labels == class_] for class_ in self.classes_}
|
self._index = {class_: np.arange(len(self))[self.labels == class_] for class_ in self.classes_}
|
||||||
return self._index
|
return self._index
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ AGGREGATIVE_METHODS = {
|
||||||
aggregative.KDEyHD,
|
aggregative.KDEyHD,
|
||||||
# aggregative.OneVsAllAggregative,
|
# aggregative.OneVsAllAggregative,
|
||||||
confidence.BayesianCC,
|
confidence.BayesianCC,
|
||||||
|
confidence.PQ,
|
||||||
}
|
}
|
||||||
|
|
||||||
BINARY_METHODS = {
|
BINARY_METHODS = {
|
||||||
|
|
@ -40,6 +41,7 @@ BINARY_METHODS = {
|
||||||
aggregative.MAX,
|
aggregative.MAX,
|
||||||
aggregative.MS,
|
aggregative.MS,
|
||||||
aggregative.MS2,
|
aggregative.MS2,
|
||||||
|
confidence.PQ,
|
||||||
}
|
}
|
||||||
|
|
||||||
MULTICLASS_METHODS = {
|
MULTICLASS_METHODS = {
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,14 @@
|
||||||
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import importlib.resources
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpyro
|
import numpyro
|
||||||
import numpyro.distributions as dist
|
import numpyro.distributions as dist
|
||||||
|
import stan
|
||||||
|
|
||||||
DEPENDENCIES_INSTALLED = True
|
DEPENDENCIES_INSTALLED = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
@ -15,6 +17,7 @@ except ImportError:
|
||||||
jnp = None
|
jnp = None
|
||||||
numpyro = None
|
numpyro = None
|
||||||
dist = None
|
dist = None
|
||||||
|
stan = None
|
||||||
|
|
||||||
DEPENDENCIES_INSTALLED = False
|
DEPENDENCIES_INSTALLED = False
|
||||||
|
|
||||||
|
|
@ -77,3 +80,56 @@ def sample_posterior(
|
||||||
rng_key = jax.random.PRNGKey(seed)
|
rng_key = jax.random.PRNGKey(seed)
|
||||||
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
|
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
|
||||||
return mcmc.get_samples()
|
return mcmc.get_samples()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_stan_file():
|
||||||
|
return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')
|
||||||
|
|
||||||
|
def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
|
||||||
|
"""
|
||||||
|
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.
|
||||||
|
|
||||||
|
This function builds and samples from a Stan model that implements a bin-based Bayesian
|
||||||
|
quantifier. It uses the class-conditional histograms of the classifier
|
||||||
|
outputs for positive and negative examples, along with the test histogram, to estimate
|
||||||
|
the posterior distribution of prevalence in the test set.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
stan_code : str
|
||||||
|
The Stan model code as a string.
|
||||||
|
n_bins : int
|
||||||
|
Number of bins used to build the histograms for positive and negative examples.
|
||||||
|
pos_hist : array-like of shape (n_bins,)
|
||||||
|
Histogram counts of the classifier outputs for the positive class.
|
||||||
|
neg_hist : array-like of shape (n_bins,)
|
||||||
|
Histogram counts of the classifier outputs for the negative class.
|
||||||
|
test_hist : array-like of shape (n_bins,)
|
||||||
|
Histogram counts of the classifier outputs for the test set, binned using the same bins.
|
||||||
|
number_of_samples : int
|
||||||
|
Number of post-warmup samples to draw from the Stan posterior.
|
||||||
|
num_warmup : int
|
||||||
|
Number of warmup iterations for the sampler.
|
||||||
|
stan_seed : int
|
||||||
|
Random seed for Stan model compilation and sampling, ensuring reproducibility.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
prev_samples : numpy.ndarray
|
||||||
|
An array of posterior samples of the prevalence (`prev`) in the test set.
|
||||||
|
Each element corresponds to one draw from the posterior distribution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
stan_data = {
|
||||||
|
'n_bucket': n_bins,
|
||||||
|
'train_neg': neg_hist.tolist(),
|
||||||
|
'train_pos': pos_hist.tolist(),
|
||||||
|
'test': test_hist.tolist(),
|
||||||
|
'posterior': 1
|
||||||
|
}
|
||||||
|
|
||||||
|
stan_model = stan.build(stan_code, data=stan_data, random_seed=stan_seed)
|
||||||
|
fit = stan_model.sample(num_chains=1, num_samples=number_of_samples,num_warmup=num_warmup)
|
||||||
|
|
||||||
|
return fit['prev']
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,8 @@ from sklearn.metrics import confusion_matrix
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
import quapy.functional as F
|
import quapy.functional as F
|
||||||
from quapy.method import _bayesian
|
from quapy.method import _bayesian
|
||||||
from quapy.method.aggregative import AggregativeCrispQuantifier
|
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.aggregative import AggregativeQuantifier
|
from quapy.method.aggregative import AggregativeQuantifier, AggregativeCrispQuantifier, AggregativeSoftQuantifier, BinaryAggregativeQuantifier
|
||||||
from scipy.stats import chi2
|
from scipy.stats import chi2
|
||||||
from sklearn.utils import resample
|
from sklearn.utils import resample
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -566,8 +565,113 @@ class BayesianCC(AggregativeCrispQuantifier, WithConfidenceABC):
|
||||||
return np.asarray(samples.mean(axis=0), dtype=float)
|
return np.asarray(samples.mean(axis=0), dtype=float)
|
||||||
|
|
||||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||||
|
if confidence_level is None:
|
||||||
|
confidence_level = self.confidence_level
|
||||||
classif_predictions = self.classify(instances)
|
classif_predictions = self.classify(instances)
|
||||||
point_estimate = self.aggregate(classif_predictions)
|
point_estimate = self.aggregate(classif_predictions)
|
||||||
samples = self.get_prevalence_samples() # available after calling "aggregate" function
|
samples = self.get_prevalence_samples() # available after calling "aggregate" function
|
||||||
region = WithConfidenceABC.construct_region(samples, confidence_level=self.confidence_level, method=self.region)
|
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||||
return point_estimate, region
|
return point_estimate, region
|
||||||
|
|
||||||
|
|
||||||
|
class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
|
"""
|
||||||
|
`Precise Quantifier: Bayesian distribution matching quantifier <https://arxiv.org/abs/2507.06061>,
|
||||||
|
which is a variant of :class:`HDy` that calculates the posterior probability distribution
|
||||||
|
over the prevalence vectors, rather than providing a point estimate.
|
||||||
|
|
||||||
|
This method relies on extra dependencies, which have to be installed via:
|
||||||
|
`$ pip install quapy[bayes]`
|
||||||
|
|
||||||
|
:param classifier: a scikit-learn's BaseEstimator, or None, in which case the classifier is taken to be
|
||||||
|
the one indicated in `qp.environ['DEFAULT_CLS']`
|
||||||
|
:param val_split: specifies the data used for generating classifier predictions. This specification
|
||||||
|
can be made as float in (0, 1) indicating the proportion of stratified held-out validation set to
|
||||||
|
be extracted from the training set; or as an integer (default 5), indicating that the predictions
|
||||||
|
are to be generated in a `k`-fold cross-validation manner (with this integer indicating the value
|
||||||
|
for `k`); or as a tuple `(X,y)` defining the specific set of data to use for validation. Set to
|
||||||
|
None when the method does not require any validation data, in order to avoid that some portion of
|
||||||
|
the training data be wasted.
|
||||||
|
:param num_warmup: number of warmup iterations for the STAN sampler (default 500)
|
||||||
|
:param num_samples: number of samples to draw from the posterior (default 1000)
|
||||||
|
:param stan_seed: random seed for the STAN sampler (default 0)
|
||||||
|
:param region: string, set to `intervals` for constructing confidence intervals (default), or to
|
||||||
|
`ellipse` for constructing an ellipse in the probability simplex, or to `ellipse-clr` for
|
||||||
|
constructing an ellipse in the Centered-Log Ratio (CLR) unconstrained space.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
classifier: BaseEstimator=None,
|
||||||
|
fit_classifier=True,
|
||||||
|
val_split: int = 5,
|
||||||
|
n_bins: int = 4,
|
||||||
|
fixed_bins: bool = False,
|
||||||
|
num_warmup: int = 500,
|
||||||
|
num_samples: int = 1_000,
|
||||||
|
stan_seed: int = 0,
|
||||||
|
confidence_level: float = 0.95,
|
||||||
|
region: str = 'intervals'):
|
||||||
|
|
||||||
|
if num_warmup <= 0:
|
||||||
|
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
|
||||||
|
if num_samples <= 0:
|
||||||
|
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||||
|
|
||||||
|
if not _bayesian.DEPENDENCIES_INSTALLED:
|
||||||
|
raise ImportError("Auxiliary dependencies are required. "
|
||||||
|
"Run `$ pip install quapy[bayes]` to install them.")
|
||||||
|
|
||||||
|
super().__init__(classifier, fit_classifier, val_split)
|
||||||
|
self.n_bins = n_bins
|
||||||
|
self.fixed_bins = fixed_bins
|
||||||
|
self.num_warmup = num_warmup
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.stan_seed = stan_seed
|
||||||
|
self.stan_code = _bayesian.load_stan_file()
|
||||||
|
self.confidence_level = confidence_level
|
||||||
|
self.region = region
|
||||||
|
|
||||||
|
def aggregation_fit(self, classif_predictions, labels):
|
||||||
|
y_pred = classif_predictions[:, self.pos_label]
|
||||||
|
|
||||||
|
# Compute bin limits
|
||||||
|
if self.fixed_bins:
|
||||||
|
# Uniform bins in [0,1]
|
||||||
|
self.bin_limits = np.linspace(0, 1, self.n_bins + 1)
|
||||||
|
else:
|
||||||
|
# Quantile bins
|
||||||
|
self.bin_limits = np.quantile(y_pred, np.linspace(0, 1, self.n_bins + 1))
|
||||||
|
|
||||||
|
# Assign each prediction to a bin
|
||||||
|
bin_indices = np.digitize(y_pred, self.bin_limits[1:-1], right=True)
|
||||||
|
|
||||||
|
# Positive and negative masks
|
||||||
|
pos_mask = (labels == self.pos_label)
|
||||||
|
neg_mask = ~pos_mask
|
||||||
|
|
||||||
|
# Count positives and negatives per bin
|
||||||
|
self.pos_hist = np.bincount(bin_indices[pos_mask], minlength=self.n_bins)
|
||||||
|
self.neg_hist = np.bincount(bin_indices[neg_mask], minlength=self.n_bins)
|
||||||
|
|
||||||
|
def aggregate(self, classif_predictions):
|
||||||
|
Px_test = classif_predictions[:, self.pos_label]
|
||||||
|
test_hist, _ = np.histogram(Px_test, bins=self.bin_limits)
|
||||||
|
prevs = _bayesian.pq_stan(
|
||||||
|
self.stan_code, self.n_bins, self.pos_hist, self.neg_hist, test_hist,
|
||||||
|
self.num_samples, self.num_warmup, self.stan_seed
|
||||||
|
).flatten()
|
||||||
|
self.prev_distribution = np.vstack([1-prevs, prevs]).T
|
||||||
|
return self.prev_distribution.mean(axis=0)
|
||||||
|
|
||||||
|
def aggregate_conf(self, predictions, confidence_level=None):
|
||||||
|
if confidence_level is None:
|
||||||
|
confidence_level = self.confidence_level
|
||||||
|
point_estimate = self.aggregate(predictions)
|
||||||
|
samples = self.prev_distribution
|
||||||
|
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||||
|
return point_estimate, region
|
||||||
|
|
||||||
|
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||||
|
predictions = self.classify(instances)
|
||||||
|
return self.aggregate_conf(predictions, confidence_level=confidence_level)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
data {
|
||||||
|
int<lower=0> n_bucket;
|
||||||
|
array[n_bucket] int<lower=0> train_pos;
|
||||||
|
array[n_bucket] int<lower=0> train_neg;
|
||||||
|
array[n_bucket] int<lower=0> test;
|
||||||
|
int<lower=0,upper=1> posterior;
|
||||||
|
}
|
||||||
|
|
||||||
|
transformed data{
|
||||||
|
row_vector<lower=0>[n_bucket] train_pos_rv;
|
||||||
|
row_vector<lower=0>[n_bucket] train_neg_rv;
|
||||||
|
row_vector<lower=0>[n_bucket] test_rv;
|
||||||
|
real n_test;
|
||||||
|
|
||||||
|
train_pos_rv = to_row_vector( train_pos );
|
||||||
|
train_neg_rv = to_row_vector( train_neg );
|
||||||
|
test_rv = to_row_vector( test );
|
||||||
|
n_test = sum( test );
|
||||||
|
}
|
||||||
|
|
||||||
|
parameters {
|
||||||
|
simplex[n_bucket] p_neg;
|
||||||
|
simplex[n_bucket] p_pos;
|
||||||
|
real<lower=0,upper=1> prev_prior;
|
||||||
|
}
|
||||||
|
|
||||||
|
model {
|
||||||
|
if( posterior ) {
|
||||||
|
target += train_neg_rv * log( p_neg );
|
||||||
|
target += train_pos_rv * log( p_pos );
|
||||||
|
target += test_rv * log( p_neg * ( 1 - prev_prior) + p_pos * prev_prior );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
generated quantities {
|
||||||
|
real<lower=0,upper=1> prev;
|
||||||
|
prev = sum( binomial_rng(test, 1 / ( 1 + (p_neg./p_pos) *(1-prev_prior)/prev_prior ) ) ) / n_test;
|
||||||
|
}
|
||||||
|
|
||||||
8
setup.py
8
setup.py
|
|
@ -111,6 +111,12 @@ setup(
|
||||||
#
|
#
|
||||||
packages=find_packages(include=['quapy', 'quapy.*']), # Required
|
packages=find_packages(include=['quapy', 'quapy.*']), # Required
|
||||||
|
|
||||||
|
package_data={
|
||||||
|
# For the 'quapy.method' package, include all files
|
||||||
|
# in the 'stan' subdirectory that end with .stan
|
||||||
|
'quapy.method': ['stan/*.stan']
|
||||||
|
},
|
||||||
|
|
||||||
python_requires='>=3.8, <4',
|
python_requires='>=3.8, <4',
|
||||||
|
|
||||||
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
|
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
|
||||||
|
|
@ -124,7 +130,7 @@ setup(
|
||||||
# Similar to `install_requires` above, these must be valid existing
|
# Similar to `install_requires` above, these must be valid existing
|
||||||
# projects.
|
# projects.
|
||||||
extras_require={ # Optional
|
extras_require={ # Optional
|
||||||
'bayes': ['jax', 'jaxlib', 'numpyro'],
|
'bayes': ['jax', 'jaxlib', 'numpyro', 'pystan'],
|
||||||
'neural': ['torch'],
|
'neural': ['torch'],
|
||||||
'tests': ['certifi'],
|
'tests': ['certifi'],
|
||||||
'docs' : ['sphinx-rtd-theme', 'myst-parser'],
|
'docs' : ['sphinx-rtd-theme', 'myst-parser'],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue