forked from moreo/QuaPy
Small fixes in KDEy (should now work with string labels) and EMQ (it broke when some training prior probability was 0)
This commit is contained in:
parent
9542eaee61
commit
320b3eac38
|
@ -52,7 +52,7 @@ class KDEBase:
|
||||||
"""
|
"""
|
||||||
return np.exp(kde.score_samples(X))
|
return np.exp(kde.score_samples(X))
|
||||||
|
|
||||||
def get_mixture_components(self, X, y, n_classes, bandwidth):
|
def get_mixture_components(self, X, y, classes, bandwidth):
|
||||||
"""
|
"""
|
||||||
Returns an array containing the mixture components, i.e., the KDE functions for each class.
|
Returns an array containing the mixture components, i.e., the KDE functions for each class.
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class KDEBase:
|
||||||
:param bandwidth: float, the bandwidth of the kernel
|
:param bandwidth: float, the bandwidth of the kernel
|
||||||
:return: a list of KernelDensity objects, each fitted with the corresponding class-specific covariates
|
:return: a list of KernelDensity objects, each fitted with the corresponding class-specific covariates
|
||||||
"""
|
"""
|
||||||
return [self.get_kde_function(X[y == cat], bandwidth) for cat in range(n_classes)]
|
return [self.get_kde_function(X[y == cat], bandwidth) for cat in classes]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ class KDEyML(AggregativeSoftQuantifier, KDEBase):
|
||||||
self.random_state=random_state
|
self.random_state=random_state
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, posteriors: np.ndarray):
|
def aggregate(self, posteriors: np.ndarray):
|
||||||
|
@ -196,7 +196,7 @@ class KDEyHD(AggregativeSoftQuantifier, KDEBase):
|
||||||
self.montecarlo_trials = montecarlo_trials
|
self.montecarlo_trials = montecarlo_trials
|
||||||
|
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.classes_, self.bandwidth)
|
||||||
|
|
||||||
N = self.montecarlo_trials
|
N = self.montecarlo_trials
|
||||||
rs = self.random_state
|
rs = self.random_state
|
||||||
|
|
|
@ -640,6 +640,8 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
raise ValueError('invalid param argument for recalibration method; available ones are '
|
raise ValueError('invalid param argument for recalibration method; available ones are '
|
||||||
'"nbvs", "bcts", "ts", and "vs".')
|
'"nbvs", "bcts", "ts", and "vs".')
|
||||||
|
|
||||||
|
if not np.issubdtype(y.dtype, np.number):
|
||||||
|
y = np.searchsorted(data.classes_, y)
|
||||||
self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)
|
self.calibration_function = calibrator(P, np.eye(data.n_classes)[y], posterior_supplied=True)
|
||||||
|
|
||||||
if self.exact_train_prev:
|
if self.exact_train_prev:
|
||||||
|
@ -681,6 +683,11 @@ class EMQ(AggregativeSoftQuantifier):
|
||||||
"""
|
"""
|
||||||
Px = posterior_probabilities
|
Px = posterior_probabilities
|
||||||
Ptr = np.copy(tr_prev)
|
Ptr = np.copy(tr_prev)
|
||||||
|
|
||||||
|
if np.product(Ptr) == 0: # some entry is 0; we should smooth the values to avoid 0 division
|
||||||
|
Ptr += epsilon
|
||||||
|
Ptr /= Ptr.sum()
|
||||||
|
|
||||||
qs = np.copy(Ptr) # qs (the running estimate) is initialized as the training prevalence
|
qs = np.copy(Ptr) # qs (the running estimate) is initialized as the training prevalence
|
||||||
|
|
||||||
s, converged = 0, False
|
s, converged = 0, False
|
||||||
|
|
Loading…
Reference in New Issue