This commit is contained in:
Alejandro Moreo Fernandez 2025-11-15 18:03:06 +01:00
commit 6388d9b549
3 changed files with 13 additions and 6 deletions

View File

@ -2,6 +2,7 @@
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods. Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
""" """
import numpy as np import numpy as np
import importlib.resources
try: try:
import jax import jax
@ -82,6 +83,9 @@ def sample_posterior(
def load_stan_file():
return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')
def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed): def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
""" """
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification. Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.

View File

@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
from scipy.special import softmax, factorial from scipy.special import softmax, factorial
import copy import copy
from functools import lru_cache from functools import lru_cache
from pathlib import Path
""" """
This module provides implementation of different types of confidence regions, and the implementation of Bootstrap This module provides implementation of different types of confidence regions, and the implementation of Bootstrap
@ -625,10 +624,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
self.num_samples = num_samples self.num_samples = num_samples
self.region = region self.region = region
self.stan_seed = stan_seed self.stan_seed = stan_seed
# with open('quapy/method/stan/pq.stan', 'r') as f: self.stan_code = _bayesian.load_stan_file()
stan_path = Path(__file__).resolve().parent / "stan" / "pq.stan"
with stan_path.open("r") as f:
self.stan_code = str(f.read())
def aggregation_fit(self, classif_predictions, labels): def aggregation_fit(self, classif_predictions, labels):
y_pred = classif_predictions[:, self.pos_label] y_pred = classif_predictions[:, self.pos_label]
@ -662,7 +658,8 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
return F.as_binary_prevalence(self.prev_distribution.mean()) return F.as_binary_prevalence(self.prev_distribution.mean())
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC): def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
point_estimate = self.predict(instances) classif_predictions = self.classify(instances)
point_estimate = self.aggregate(classif_predictions)
samples = self.prev_distribution samples = self.prev_distribution
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region) region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
return point_estimate, region return point_estimate, region

View File

@ -111,6 +111,12 @@ setup(
# #
packages=find_packages(include=['quapy', 'quapy.*']), # Required packages=find_packages(include=['quapy', 'quapy.*']), # Required
package_data={
# For the 'quapy.method' package, include all files
# in the 'stan' subdirectory that end with .stan
'quapy.method': ['stan/*.stan']
},
python_requires='>=3.8, <4', python_requires='>=3.8, <4',
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'], install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],