diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 67b003a..a723dcc 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -729,6 +729,11 @@ class EMQ(AggregativeSoftQuantifier): posteriors = self.calibration_function(posteriors) return posteriors + def classifier_fit_predict(self, X, y): + classif_predictions = super().classifier_fit_predict(X, y) + self.train_prevalence = F.prevalence_from_labels(y, classes=self.classes_) + return classif_predictions + def aggregation_fit(self, classif_predictions): """ Trains the aggregation function of EMQ. This comes down to recalibrating the posterior probabilities @@ -756,9 +761,7 @@ class EMQ(AggregativeSoftQuantifier): y = np.searchsorted(self.classes_, y) self.calibration_function = calibrator(P, np.eye(n_classes)[y], posterior_supplied=True) - if self.exact_train_prev: - self.train_prevalence = F.prevalence_from_labels(y, self.classes_) - else: + if not self.exact_train_prev: train_posteriors = classif_predictions.X if self.recalib is not None: train_posteriors = self.calibration_function(train_posteriors)