merged
This commit is contained in:
commit
6388d9b549
|
|
@ -2,6 +2,7 @@
|
||||||
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
||||||
"""
|
"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import importlib.resources
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import jax
|
import jax
|
||||||
|
|
@ -82,6 +83,9 @@ def sample_posterior(
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def load_stan_file():
|
||||||
|
return importlib.resources.files('quapy.method').joinpath('stan/pq.stan').read_text(encoding='utf-8')
|
||||||
|
|
||||||
def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
|
def pq_stan(stan_code, n_bins, pos_hist, neg_hist, test_hist, number_of_samples, num_warmup, stan_seed):
|
||||||
"""
|
"""
|
||||||
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.
|
Perform Bayesian prevalence estimation using a Stan model for probabilistic quantification.
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,6 @@ from abc import ABC, abstractmethod
|
||||||
from scipy.special import softmax, factorial
|
from scipy.special import softmax, factorial
|
||||||
import copy
|
import copy
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This module provides implementation of different types of confidence regions, and the implementation of Bootstrap
|
This module provides implementation of different types of confidence regions, and the implementation of Bootstrap
|
||||||
|
|
@ -625,10 +624,7 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
self.region = region
|
self.region = region
|
||||||
self.stan_seed = stan_seed
|
self.stan_seed = stan_seed
|
||||||
# with open('quapy/method/stan/pq.stan', 'r') as f:
|
self.stan_code = _bayesian.load_stan_file()
|
||||||
stan_path = Path(__file__).resolve().parent / "stan" / "pq.stan"
|
|
||||||
with stan_path.open("r") as f:
|
|
||||||
self.stan_code = str(f.read())
|
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions, labels):
|
def aggregation_fit(self, classif_predictions, labels):
|
||||||
y_pred = classif_predictions[:, self.pos_label]
|
y_pred = classif_predictions[:, self.pos_label]
|
||||||
|
|
@ -662,7 +658,8 @@ class PQ(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
return F.as_binary_prevalence(self.prev_distribution.mean())
|
return F.as_binary_prevalence(self.prev_distribution.mean())
|
||||||
|
|
||||||
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
def predict_conf(self, instances, confidence_level=None) -> (np.ndarray, ConfidenceRegionABC):
|
||||||
point_estimate = self.predict(instances)
|
classif_predictions = self.classify(instances)
|
||||||
|
point_estimate = self.aggregate(classif_predictions)
|
||||||
samples = self.prev_distribution
|
samples = self.prev_distribution
|
||||||
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
region = WithConfidenceABC.construct_region(samples, confidence_level=confidence_level, method=self.region)
|
||||||
return point_estimate, region
|
return point_estimate, region
|
||||||
|
|
|
||||||
6
setup.py
6
setup.py
|
|
@ -111,6 +111,12 @@ setup(
|
||||||
#
|
#
|
||||||
packages=find_packages(include=['quapy', 'quapy.*']), # Required
|
packages=find_packages(include=['quapy', 'quapy.*']), # Required
|
||||||
|
|
||||||
|
package_data={
|
||||||
|
# For the 'quapy.method' package, include all files
|
||||||
|
# in the 'stan' subdirectory that end with .stan
|
||||||
|
'quapy.method': ['stan/*.stan']
|
||||||
|
},
|
||||||
|
|
||||||
python_requires='>=3.8, <4',
|
python_requires='>=3.8, <4',
|
||||||
|
|
||||||
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
|
install_requires=['scikit-learn', 'pandas', 'tqdm', 'matplotlib', 'joblib', 'xlrd', 'abstention', 'ucimlrepo', 'certifi'],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue