From b68b58ad113bd9cb54c8aa73df2be7d35aaa4911 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Thu, 18 Jan 2024 18:26:40 +0100 Subject: [PATCH] fixed optimization threshold methods (again) --- quapy/method/aggregative.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 066f480..c6293c9 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -1278,16 +1278,21 @@ class MS(ThresholdOptimization): def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection): # keeps all candidates decision_scores, y = classif_predictions.Xy - self.tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y) + tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y) + self.tprs = tprs_fprs_thresholds[:, 0] + self.fprs = tprs_fprs_thresholds[:, 1] + self.thresholds = tprs_fprs_thresholds[:, 2] return self def aggregate(self, classif_predictions: np.ndarray): - prevalences = [] - for tpr, fpr, threshold in self.tprs_fprs_thresholds: - pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1] - prevalences.append(pos_prev) - median = np.median(prevalences) - return F.as_binary_prevalence(median) + prevalences = self.aggregate_with_threshold(classif_predictions, self.tprs, self.fprs, self.thresholds) + return np.median(prevalences, axis=0) + # prevalences = [] + # for tpr, fpr, threshold in self.tprs_fprs_thresholds: + # pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1] + # prevalences.append(pos_prev) + # median = np.median(prevalences) + # return F.as_binary_prevalence(median) class MS2(MS):