refactoring w/o labelled collection
This commit is contained in:
parent
075be93a23
commit
5738821d10
|
@ -729,6 +729,11 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
posteriors = self.calibration_function(posteriors)
|
||||
return posteriors
|
||||
|
||||
def classifier_fit_predict(self, X, y):
|
||||
classif_predictions = super().classifier_fit_predict(X, y)
|
||||
self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_)
|
||||
return classif_predictions
|
||||
|
||||
def aggregation_fit(self, classif_predictions):
|
||||
"""
|
||||
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
|
||||
|
@ -756,9 +761,7 @@ class EMQ(AggregativeSoftQuantifier):
|
|||
y = np.searchsorted(self.classes_, y)
|
||||
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
|
||||
|
||||
if self.exact_train_prev:
|
||||
self.train_prevalence = F.prevalence_from_labels(y, self.classes_)
|
||||
else:
|
||||
if not self.exact_train_prev:
|
||||
train_posteriors = classif_predictions.X
|
||||
if self.recalib is not None:
|
||||
train_posteriors = self.calibration_function(train_posteriors)
|
||||
|
|
Loading…
Reference in New Issue