From 5738821d10a927b9c73e209f7ce5ddb0870f7e47 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Sun, 20 Apr 2025 22:09:18 +0200 Subject: [PATCH] refactoring w/o labelled collection --- quapy/method/aggregative.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 67b003a..a723dcc 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -729,6 +729,11 @@ class EMQ(AggregativeSoftQuantifier): posteriors = self.calibration_function(posteriors) return posteriors + def classifier_fit_predict(self, X, y): + classif_predictions = super().classifier_fit_predict(X, y) + self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_) + return classif_predictions + def aggregation_fit(self, classif_predictions): """ Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities @@ -756,9 +761,7 @@ class EMQ(AggregativeSoftQuantifier): y = np.searchsorted(self.classes_, y) self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True) - if self.exact_train_prev: - self.train_prevalence = F.prevalence_from_labels(y, self.classes_) - else: + if not self.exact_train_prev: train_posteriors = classif_predictions.X if self.recalib is not None: train_posteriors = self.calibration_function(train_posteriors)