Merge branch 'pawel-czyz-bayesian-quantification' into devel
This commit is contained in:
commit
f674151eba
|
@ -14,7 +14,6 @@ for facilitating the analysis and interpretation of the experimental results.
|
||||||
### Last updates:
|
### Last updates:
|
||||||
|
|
||||||
* Version 0.1.8 is released! major changes can be consulted [here](CHANGE_LOG.txt).
|
* Version 0.1.8 is released! major changes can be consulted [here](CHANGE_LOG.txt).
|
||||||
* A detailed documentation is now available [here](https://hlt-isti.github.io/QuaPy/)
|
|
||||||
* The developer API documentation is available [here](https://hlt-isti.github.io/QuaPy/build/html/modules.html)
|
* The developer API documentation is available [here](https://hlt-isti.github.io/QuaPy/build/html/modules.html)
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
|
@ -28,22 +28,34 @@ def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01)
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
def prevalence_from_labels(labels, classes):
|
def counts_from_labels(labels, classes):
|
||||||
"""
|
"""
|
||||||
Computed the prevalence values from a vector of labels.
|
Computes the count values from a vector of labels.
|
||||||
|
|
||||||
:param labels: array-like of shape `(n_instances)` with the label for each instance
|
:param labels: array-like of shape `(n_instances,)` with the label for each instance
|
||||||
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
|
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
|
||||||
some classes have no examples.
|
some classes have no examples.
|
||||||
:return: an ndarray of shape `(len(classes))` with the class prevalence values
|
:return: an ndarray of shape `(len(classes),)` with the occurrence counts of each class
|
||||||
"""
|
"""
|
||||||
if labels.ndim != 1:
|
if labels.ndim != 1:
|
||||||
raise ValueError(f'param labels does not seem to be a ndarray of label predictions')
|
raise ValueError(f'param labels does not seem to be a ndarray of label predictions')
|
||||||
unique, counts = np.unique(labels, return_counts=True)
|
unique, counts = np.unique(labels, return_counts=True)
|
||||||
by_class = defaultdict(lambda:0, dict(zip(unique, counts)))
|
by_class = defaultdict(lambda:0, dict(zip(unique, counts)))
|
||||||
prevalences = np.asarray([by_class[class_] for class_ in classes], dtype=float)
|
counts = np.asarray([by_class[class_] for class_ in classes], dtype=int)
|
||||||
prevalences /= prevalences.sum()
|
return counts
|
||||||
return prevalences
|
|
||||||
|
|
||||||
|
def prevalence_from_labels(labels, classes):
|
||||||
|
"""
|
||||||
|
Computes the prevalence values from a vector of labels.
|
||||||
|
|
||||||
|
:param labels: array-like of shape `(n_instances,)` with the label for each instance
|
||||||
|
:param classes: the class labels. This is needed in order to correctly compute the prevalence vector even when
|
||||||
|
some classes have no examples.
|
||||||
|
:return: an ndarray of shape `(len(classes),)` with the class prevalence values
|
||||||
|
"""
|
||||||
|
counts = np.array(counts_from_labels(labels, classes), dtype=float)
|
||||||
|
return counts / np.sum(counts)
|
||||||
|
|
||||||
|
|
||||||
def prevalence_from_probabilities(posteriors, binarize: bool = False):
|
def prevalence_from_probabilities(posteriors, binarize: bool = False):
|
||||||
|
|
|
@ -20,11 +20,13 @@ AGGREGATIVE_METHODS = {
|
||||||
aggregative.KDEyML,
|
aggregative.KDEyML,
|
||||||
aggregative.KDEyCS,
|
aggregative.KDEyCS,
|
||||||
aggregative.KDEyHD,
|
aggregative.KDEyHD,
|
||||||
|
aggregative.BayesianCC
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
NON_AGGREGATIVE_METHODS = {
|
NON_AGGREGATIVE_METHODS = {
|
||||||
non_aggregative.MaximumLikelihoodPrevalenceEstimation
|
non_aggregative.MaximumLikelihoodPrevalenceEstimation,
|
||||||
|
non_aggregative.DMx
|
||||||
}
|
}
|
||||||
|
|
||||||
META_METHODS = {
|
META_METHODS = {
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""
|
||||||
|
Utility functions for `Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ methods.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpyro
|
||||||
|
import numpyro.distributions as dist
|
||||||
|
|
||||||
|
DEPENDENCIES_INSTALLED = True
|
||||||
|
except ImportError:
|
||||||
|
jax = None
|
||||||
|
jnp = None
|
||||||
|
numpyro = None
|
||||||
|
dist = None
|
||||||
|
|
||||||
|
DEPENDENCIES_INSTALLED = False
|
||||||
|
|
||||||
|
|
||||||
|
P_TEST_Y: str = "P_test(Y)"
|
||||||
|
P_TEST_C: str = "P_test(C)"
|
||||||
|
P_C_COND_Y: str = "P(C|Y)"
|
||||||
|
|
||||||
|
|
||||||
|
def model(n_c_unlabeled: np.ndarray, n_y_and_c_labeled: np.ndarray) -> None:
|
||||||
|
"""
|
||||||
|
Defines a probabilistic model in `NumPyro <https://num.pyro.ai/>`_.
|
||||||
|
|
||||||
|
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
|
||||||
|
with entry `c` being the number of instances predicted as class `c`.
|
||||||
|
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
|
||||||
|
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
|
||||||
|
"""
|
||||||
|
n_y_labeled = n_y_and_c_labeled.sum(axis=1)
|
||||||
|
|
||||||
|
K = len(n_c_unlabeled)
|
||||||
|
L = len(n_y_labeled)
|
||||||
|
|
||||||
|
pi_ = numpyro.sample(P_TEST_Y, dist.Dirichlet(jnp.ones(L)))
|
||||||
|
p_c_cond_y = numpyro.sample(P_C_COND_Y, dist.Dirichlet(jnp.ones(K).repeat(L).reshape(L, K)))
|
||||||
|
|
||||||
|
with numpyro.plate('plate', L):
|
||||||
|
numpyro.sample('F_yc', dist.Multinomial(n_y_labeled, p_c_cond_y), obs=n_y_and_c_labeled)
|
||||||
|
|
||||||
|
p_c = numpyro.deterministic(P_TEST_C, jnp.einsum("yc,y->c", p_c_cond_y, pi_))
|
||||||
|
numpyro.sample('N_c', dist.Multinomial(jnp.sum(n_c_unlabeled), p_c), obs=n_c_unlabeled)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_posterior(
|
||||||
|
n_c_unlabeled: np.ndarray,
|
||||||
|
n_y_and_c_labeled: np.ndarray,
|
||||||
|
num_warmup: int,
|
||||||
|
num_samples: int,
|
||||||
|
seed: int = 0,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Samples from the Bayesian quantification model in NumPyro using the
|
||||||
|
`NUTS <https://arxiv.org/abs/1111.4246>`_ sampler.
|
||||||
|
|
||||||
|
:param n_c_unlabeled: a `np.ndarray` of shape `(n_predicted_classes,)`
|
||||||
|
with entry `c` being the number of instances predicted as class `c`.
|
||||||
|
:param n_y_and_c_labeled: a `np.ndarray` of shape `(n_classes, n_predicted_classes)`
|
||||||
|
with entry `(y, c)` being the number of instances labeled as class `y` and predicted as class `c`.
|
||||||
|
:param num_warmup: the number of warmup steps.
|
||||||
|
:param num_samples: the number of samples to draw.
|
||||||
|
:param seed: the random seed.
|
||||||
|
:return: a `dict` with the samples. The keys are the names of the latent variables.
|
||||||
|
"""
|
||||||
|
mcmc = numpyro.infer.MCMC(
|
||||||
|
numpyro.infer.NUTS(model),
|
||||||
|
num_warmup=num_warmup,
|
||||||
|
num_samples=num_samples,
|
||||||
|
progress_bar=False
|
||||||
|
)
|
||||||
|
rng_key = jax.random.PRNGKey(seed)
|
||||||
|
mcmc.run(rng_key, n_c_unlabeled=n_c_unlabeled, n_y_and_c_labeled=n_y_and_c_labeled)
|
||||||
|
return mcmc.get_samples()
|
|
@ -16,6 +16,8 @@ from quapy.classification.calibration import NBVSCalibration, BCTSCalibration, T
|
||||||
from quapy.classification.svmperf import SVMperf
|
from quapy.classification.svmperf import SVMperf
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
|
from quapy.method.base import BaseQuantifier, BinaryQuantifier, OneVsAllGeneric
|
||||||
|
from quapy.method import _bayesian
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Abstract classes
|
# Abstract classes
|
||||||
|
@ -162,8 +164,8 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
Trains the aggregation function.
|
Trains the aggregation function.
|
||||||
|
|
||||||
:param classif_predictions: a LabelledCollection containing the label predictions issued
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
by the classifier
|
as instances, the predictions issued by the classifier and, as labels, the true labels
|
||||||
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
@ -249,7 +251,7 @@ class AggregativeQuantifier(BaseQuantifier, ABC):
|
||||||
|
|
||||||
class AggregativeCrispQuantifier(AggregativeQuantifier, ABC):
|
class AggregativeCrispQuantifier(AggregativeQuantifier, ABC):
|
||||||
"""
|
"""
|
||||||
Abstract class for quantification methods that base their estimations on the aggregation of crips decisions
|
Abstract class for quantification methods that base their estimations on the aggregation of crisp decisions
|
||||||
as returned by a hard classifier. Aggregative crisp quantifiers thus extend Aggregative
|
as returned by a hard classifier. Aggregative crisp quantifiers thus extend Aggregative
|
||||||
Quantifiers by implementing specifications about crisp predictions.
|
Quantifiers by implementing specifications about crisp predictions.
|
||||||
"""
|
"""
|
||||||
|
@ -335,7 +337,8 @@ class CC(AggregativeCrispQuantifier):
|
||||||
"""
|
"""
|
||||||
Nothing to do here!
|
Nothing to do here!
|
||||||
|
|
||||||
:param classif_predictions: this is actually None
|
:param classif_predictions: not used
|
||||||
|
:param data: not used
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -384,13 +387,16 @@ class ACC(AggregativeCrispQuantifier):
|
||||||
self.solver = solver
|
self.solver = solver
|
||||||
|
|
||||||
def _check_init_parameters(self):
|
def _check_init_parameters(self):
|
||||||
assert self.solver in ['exact', 'minimize'], "unknown solver; valid ones are 'exact', 'minimize'"
|
if self.solver not in ['exact', 'minimize']:
|
||||||
|
raise ValueError("unknown solver; valid ones are 'exact', 'minimize'")
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Estimates the misclassification rates.
|
Estimates the misclassification rates.
|
||||||
|
|
||||||
:param classif_predictions: classifier predictions with true labels
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the label predictions issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
"""
|
"""
|
||||||
pred_labels, true_labels = classif_predictions.Xy
|
pred_labels, true_labels = classif_predictions.Xy
|
||||||
self.cc = CC(self.classifier)
|
self.cc = CC(self.classifier)
|
||||||
|
@ -468,7 +474,8 @@ class PCC(AggregativeSoftQuantifier):
|
||||||
"""
|
"""
|
||||||
Nothing to do here!
|
Nothing to do here!
|
||||||
|
|
||||||
:param classif_predictions: this is actually None
|
:param classif_predictions: not used
|
||||||
|
:param data: not used
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -516,7 +523,9 @@ class PACC(AggregativeSoftQuantifier):
|
||||||
"""
|
"""
|
||||||
Estimates the misclassification rates
|
Estimates the misclassification rates
|
||||||
|
|
||||||
:param classif_predictions: classifier soft predictions with true labels
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
"""
|
"""
|
||||||
posteriors, true_labels = classif_predictions.Xy
|
posteriors, true_labels = classif_predictions.Xy
|
||||||
self.pcc = PCC(self.classifier)
|
self.pcc = PCC(self.classifier)
|
||||||
|
@ -626,6 +635,14 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
return posteriors
|
return posteriors
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
|
"""
|
||||||
|
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
|
||||||
|
if requested.
|
||||||
|
|
||||||
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
|
"""
|
||||||
if self.recalib is not None:
|
if self.recalib is not None:
|
||||||
P, y = classif_predictions.Xy
|
P, y = classif_predictions.Xy
|
||||||
if self.recalib == 'nbvs':
|
if self.recalib == 'nbvs':
|
||||||
|
@ -712,6 +729,99 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
return qs, ps
|
return qs, ps
|
||||||
|
|
||||||
|
|
||||||
|
class BayesianCC(AggregativeCrispQuantifier):
|
||||||
|
"""
|
||||||
|
`Bayesian quantification <https://arxiv.org/abs/2302.09159>`_ method,
|
||||||
|
which is a variant of :class:`ACC` that calculates the posterior probability distribution
|
||||||
|
over the prevalence vectors, rather than providing a point estimate obtained
|
||||||
|
by matrix inversion.
|
||||||
|
|
||||||
|
Can be used to diagnose degeneracy in the predictions visible when the confusion
|
||||||
|
matrix has high condition number or to quantify uncertainty around the point estimate.
|
||||||
|
|
||||||
|
This method relies on extra dependencies, which have to be installed via:
|
||||||
|
`$ pip install quapy[bayes]`
|
||||||
|
|
||||||
|
:param classifier: a sklearn's Estimator that generates a classifier
|
||||||
|
:param val_split: a float in (0, 1) indicating the proportion of the training data to be used,
|
||||||
|
as a stratified held-out validation set, for generating classifier predictions.
|
||||||
|
:param num_warmup: number of warmup iterations for the MCMC sampler (default 500)
|
||||||
|
:param num_samples: number of samples to draw from the posterior (default 1000)
|
||||||
|
:param mcmc_seed: random seed for the MCMC sampler (default 0)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
classifier: BaseEstimator,
|
||||||
|
val_split: float = 0.75,
|
||||||
|
num_warmup: int = 500,
|
||||||
|
num_samples: int = 1_000,
|
||||||
|
mcmc_seed: int = 0):
|
||||||
|
|
||||||
|
if num_warmup <= 0:
|
||||||
|
raise ValueError(f'parameter {num_warmup=} must be a positive integer')
|
||||||
|
if num_samples <= 0:
|
||||||
|
raise ValueError(f'parameter {num_samples=} must be a positive integer')
|
||||||
|
|
||||||
|
if (not isinstance(val_split, float)) or val_split <= 0 or val_split >= 1:
|
||||||
|
raise ValueError(f'val_split must be a float in (0, 1), got {val_split}')
|
||||||
|
|
||||||
|
if _bayesian.DEPENDENCIES_INSTALLED is False:
|
||||||
|
raise ImportError("Auxiliary dependencies are required. Run `$ pip install quapy[bayes]` to install them.")
|
||||||
|
|
||||||
|
self.classifier = classifier
|
||||||
|
self.val_split = val_split
|
||||||
|
self.num_warmup = num_warmup
|
||||||
|
self.num_samples = num_samples
|
||||||
|
self.mcmc_seed = mcmc_seed
|
||||||
|
|
||||||
|
# Array of shape (n_classes, n_predicted_classes,) where entry (y, c) is the number of instances
|
||||||
|
# labeled as class y and predicted as class c.
|
||||||
|
# By default, this array is set to None and later defined as part of the `aggregation_fit` phase
|
||||||
|
self._n_and_c_labeled = None
|
||||||
|
|
||||||
|
# Dictionary with posterior samples, set when `sample_from_posterior` (or `aggregate`, which calls it) is invoked.
|
||||||
|
self._samples = None
|
||||||
|
|
||||||
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
|
"""
|
||||||
|
Estimates the misclassification rates.
|
||||||
|
|
||||||
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the label predictions issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
|
"""
|
||||||
|
pred_labels, true_labels = classif_predictions.Xy
|
||||||
|
self._n_and_c_labeled = confusion_matrix(y_true=true_labels, y_pred=pred_labels, labels=self.classifier.classes_)
|
||||||
|
|
||||||
|
def sample_from_posterior(self, classif_predictions):
|
||||||
|
if self._n_and_c_labeled is None:
|
||||||
|
raise ValueError("aggregation_fit must be called before sample_from_posterior")
|
||||||
|
|
||||||
|
n_c_unlabeled = F.counts_from_labels(classif_predictions, self.classifier.classes_)
|
||||||
|
|
||||||
|
self._samples = _bayesian.sample_posterior(
|
||||||
|
n_c_unlabeled=n_c_unlabeled,
|
||||||
|
n_y_and_c_labeled=self._n_and_c_labeled,
|
||||||
|
num_warmup=self.num_warmup,
|
||||||
|
num_samples=self.num_samples,
|
||||||
|
seed=self.mcmc_seed,
|
||||||
|
)
|
||||||
|
return self._samples
|
||||||
|
|
||||||
|
def get_prevalence_samples(self):
|
||||||
|
if self._samples is None:
|
||||||
|
raise ValueError("sample_from_posterior must be called before get_prevalence_samples")
|
||||||
|
return self._samples[_bayesian.P_TEST_Y]
|
||||||
|
|
||||||
|
def get_conditional_probability_samples(self):
|
||||||
|
if self._samples is None:
|
||||||
|
raise ValueError("sample_from_posterior must be called before get_conditional_probability_samples")
|
||||||
|
return self._samples[_bayesian.P_C_COND_Y]
|
||||||
|
|
||||||
|
def aggregate(self, classif_predictions):
|
||||||
|
samples = self.sample_from_posterior(classif_predictions)[_bayesian.P_TEST_Y]
|
||||||
|
return np.asarray(samples.mean(axis=0), dtype=float)
|
||||||
|
|
||||||
|
|
||||||
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
"""
|
"""
|
||||||
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
|
`Hellinger Distance y <https://www.sciencedirect.com/science/article/pii/S0020025512004069>`_ (HDy).
|
||||||
|
@ -733,14 +843,11 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Trains a HDy quantifier.
|
Trains the aggregation function of HDy.
|
||||||
|
|
||||||
:param data: the training set
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a
|
|
||||||
:class:`quapy.data.base.LabelledCollection` indicating the validation set itself
|
|
||||||
:return: self
|
|
||||||
"""
|
"""
|
||||||
P, y = classif_predictions.Xy
|
P, y = classif_predictions.Xy
|
||||||
Px = P[:, self.pos_label] # takes only the P(y=+1|x)
|
Px = P[:, self.pos_label] # takes only the P(y=+1|x)
|
||||||
|
@ -757,8 +864,6 @@ class HDy(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
self.Pxy1_density = {bins: hist(self.Pxy1, bins) for bins in self.bins}
|
self.Pxy1_density = {bins: hist(self.Pxy1, bins) for bins in self.bins}
|
||||||
self.Pxy0_density = {bins: hist(self.Pxy0, bins) for bins in self.bins}
|
self.Pxy0_density = {bins: hist(self.Pxy0, bins) for bins in self.bins}
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def aggregate(self, classif_posteriors):
|
def aggregate(self, classif_posteriors):
|
||||||
# "In this work, the number of bins b used in HDx and HDy was chosen from 10 to 110 in steps of 10,
|
# "In this work, the number of bins b used in HDx and HDy was chosen from 10 to 110 in steps of 10,
|
||||||
# and the final estimated a priori probability was taken as the median of these 11 estimates."
|
# and the final estimated a priori probability was taken as the median of these 11 estimates."
|
||||||
|
@ -833,6 +938,13 @@ class DyS(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
return (left + right) / 2
|
return (left + right) / 2
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
|
"""
|
||||||
|
Trains the aggregation function of DyS.
|
||||||
|
|
||||||
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
|
"""
|
||||||
Px, y = classif_predictions.Xy
|
Px, y = classif_predictions.Xy
|
||||||
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
|
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
|
||||||
self.Pxy1 = Px[y == self.pos_label]
|
self.Pxy1 = Px[y == self.pos_label]
|
||||||
|
@ -871,6 +983,13 @@ class SMM(AggregativeSoftQuantifier, BinaryAggregativeQuantifier):
|
||||||
self.val_split = val_split
|
self.val_split = val_split
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
|
"""
|
||||||
|
Trains the aggregation function of SMM.
|
||||||
|
|
||||||
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
|
"""
|
||||||
Px, y = classif_predictions.Xy
|
Px, y = classif_predictions.Xy
|
||||||
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
|
Px = Px[:, self.pos_label] # takes only the P(y=+1|x)
|
||||||
self.Pxy1 = Px[y == self.pos_label]
|
self.Pxy1 = Px[y == self.pos_label]
|
||||||
|
@ -944,19 +1063,17 @@ class DMy(AggregativeSoftQuantifier):
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
"""
|
"""
|
||||||
Trains the classifier (if requested) and generates the validation distributions out of the training data.
|
Trains the aggregation function of a distribution matching method. This comes down to generating the
|
||||||
|
validation distributions out of the training data.
|
||||||
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
The validation distributions have shape `(n, ch, nbins)`, with `n` the number of classes, `ch` the number of
|
||||||
channels, and `nbins` the number of bins. In particular, let `V` be the validation distributions; then `di=V[i]`
|
channels, and `nbins` the number of bins. In particular, let `V` be the validation distributions; then `di=V[i]`
|
||||||
are the distributions obtained from training data labelled with class `i`; while `dij = di[j]` is the discrete
|
are the distributions obtained from training data labelled with class `i`; while `dij = di[j]` is the discrete
|
||||||
distribution of posterior probabilities `P(Y=j|X=x)` for training data labelled with class `i`, and `dij[k]`
|
distribution of posterior probabilities `P(Y=j|X=x)` for training data labelled with class `i`, and `dij[k]`
|
||||||
is the fraction of instances with a value in the `k`-th bin.
|
is the fraction of instances with a value in the `k`-th bin.
|
||||||
|
|
||||||
:param data: the training set
|
:param classif_predictions: a :class:`quapy.data.base.LabelledCollection` containing,
|
||||||
:param fit_classifier: set to False to bypass the training (the learner is assumed to be already fit)
|
as instances, the posterior probabilities issued by the classifier and, as labels, the true labels
|
||||||
:param val_split: either a float in (0,1) indicating the proportion of training instances to use for
|
:param data: a :class:`quapy.data.base.LabelledCollection` consisting of the training data
|
||||||
validation (e.g., 0.3 for using 30% of the training set as validation data), or a LabelledCollection
|
|
||||||
indicating the validation set itself, or an int indicating the number k of folds to be used in kFCV
|
|
||||||
to estimate the parameters
|
|
||||||
"""
|
"""
|
||||||
posteriors, true_labels = classif_predictions.Xy
|
posteriors, true_labels = classif_predictions.Xy
|
||||||
n_classes = len(self.classifier.classes_)
|
n_classes = len(self.classifier.classes_)
|
||||||
|
|
|
@ -150,6 +150,7 @@ class DMx(BaseQuantifier):
|
||||||
class ReadMe(BaseQuantifier):
|
class ReadMe(BaseQuantifier):
|
||||||
|
|
||||||
def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
|
def __init__(self, bootstrap_trials=100, bootstrap_range=100, bagging_trials=100, bagging_range=25, **vectorizer_kwargs):
|
||||||
|
raise NotImplementedError('under development ...')
|
||||||
self.bootstrap_trials = bootstrap_trials
|
self.bootstrap_trials = bootstrap_trials
|
||||||
self.bootstrap_range = bootstrap_range
|
self.bootstrap_range = bootstrap_range
|
||||||
self.bagging_trials = bagging_trials
|
self.bagging_trials = bagging_trials
|
||||||
|
|
7
setup.py
7
setup.py
|
@ -123,10 +123,9 @@ setup(
|
||||||
#
|
#
|
||||||
# Similar to `install_requires` above, these must be valid existing
|
# Similar to `install_requires` above, these must be valid existing
|
||||||
# projects.
|
# projects.
|
||||||
# extras_require={ # Optional
|
extras_require={ # Optional
|
||||||
# 'dev': ['check-manifest'],
|
'bayes': ['jax', 'jaxlib', 'numpyro'],
|
||||||
# 'test': ['coverage'],
|
},
|
||||||
# },
|
|
||||||
|
|
||||||
# If there are data files included in your packages that need to be
|
# If there are data files included in your packages that need to be
|
||||||
# installed, specify them here.
|
# installed, specify them here.
|
||||||
|
|
Loading…
Reference in New Issue