forked from moreo/QuaPy
Small fixes in KDEy (it should now work with string labels) and in EMQ (it broke when some training prior probability was 0)
This commit is contained in:
parent
9542eaee61
commit
320b3eac38
|
@ -52,7 +52,7 @@ class KDEBase:
|
|||
"""
|
||||
return np.exp(kde.score_samples(X))
|
||||
|
||||
def get_mixture_components(self, X, y, n_classes, bandwidth):
|
||||
def get_mixture_components(self, X, y, classes, bandwidth):
|
||||
"""
|
||||
Returns an array containing the mixture components, i.e., the KDE functions for each class.
|
||||
|
||||
|
@ -62,7 +62,7 @@ class KDEBase:
|
|||
:param bandwidth: float, the bandwidth of the kernel
|
||||
:return: a list of KernelDensity objects, each fitted with the corresponding class-specific covariates
|
||||
"""
|
||||
return [self.get_kde_function(X[y == cat], bandwidth) for cat in range(n_classes)]
|
||||
return [self.get_kde_function(X[y == cat], bandwidth) for cat in classes]
|
||||
|
||||
|
||||
|
||||
|
@ -114,7 +114,7 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):
|
|||
self.random_state=random_state
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
|
||||
return self
|
||||
|
||||
def aggregate(self, posteriors: np.ndarray):
|
||||
|
@ -196,7 +196,7 @@ class KDEyHD(AggregativeSoftQuantifier, KDEBase):
|
|||
self.montecarlo_trials = montecarlo_trials
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
|
||||
|
||||
N = self.montecarlo_trials
|
||||
rs = self.random_state
|
||||
|
|
|
@ -640,6 +640,8 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
||||
'"nbvs", "bcts", "ts", and "vs".')
|
||||
|
||||
if not np.issubdtype(y.dtype, np.number):
|
||||
y = np.searchsorted(data.classes_, y)
|
||||
self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)
|
||||
|
||||
if self.exact_train_prev:
|
||||
|
@ -681,6 +683,11 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
"""
|
||||
Px = posterior_probabilities
|
||||
Ptr = np.copy(tr_prev)
|
||||
|
||||
if np.product(Ptr) == 0: # some entry is 0; we should smooth the values to avoid 0 division
|
||||
Ptr += epsilon
|
||||
Ptr /= Ptr.sum()
|
||||
|
||||
qs = np.copy(Ptr) # qs (the running estimate) is initialized as the training prevalence
|
||||
|
||||
s, converged = 0, False
|
||||
|
|
Loading…
Reference in New Issue