refactoring w/o labelled collection

This commit is contained in:
Alejandro Moreo Fernandez 2025-04-20 22:09:18 +02:00
parent 075be93a23
commit 5738821d10
1 changed files with 6 additions and 3 deletions

View File

@ -729,6 +729,11 @@ class EMQ(AggregativeSoftQuantifier):
posteriors = self.calibration_function(posteriors)
return posteriors
def classifier_fit_predict(self, X, y):
classif_predictions = super().classifier_fit_predict(X, y)
self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_)
return classif_predictions
def aggregation_fit(self, classif_predictions):
"""
Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities
@ -756,9 +761,7 @@ class EMQ(AggregativeSoftQuantifier):
y = np.searchsorted(self.classes_, y)
self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True)
if self.exact_train_prev:
self.train_prevalence = F.prevalence_from_labels(y, self.classes_)
else:
if not self.exact_train_prev:
train_posteriors = classif_predictions.X
if self.recalib is not None:
train_posteriors = self.calibration_function(train_posteriors)