fixed optimization threshold methods (again)

This commit is contained in:
Alejandro Moreo Fernandez 2024-01-18 18:26:40 +01:00
parent c0d92a2083
commit b68b58ad11
1 changed files with 12 additions and 7 deletions

View File

@ -1278,16 +1278,21 @@ class MS(ThresholdOptimization):
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
# keeps all candidates
decision_scores, y = classif_predictions.Xy
self.tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
self.tprs = tprs_fprs_thresholds[:, 0]
self.fprs = tprs_fprs_thresholds[:, 1]
self.thresholds = tprs_fprs_thresholds[:, 2]
return self
def aggregate(self, classif_predictions: np.ndarray):
prevalences = []
for tpr, fpr, threshold in self.tprs_fprs_thresholds:
pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
prevalences.append(pos_prev)
median = np.median(prevalences)
return F.as_binary_prevalence(median)
prevalences = self.aggregate_with_threshold(classif_predictions, self.tprs, self.fprs, self.thresholds)
return np.median(prevalences, axis=0)
# prevalences = []
# for tpr, fpr, threshold in self.tprs_fprs_thresholds:
# pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
# prevalences.append(pos_prev)
# median = np.median(prevalences)
# return F.as_binary_prevalence(median)
class MS2(MS):