forked from moreo/QuaPy
fixed optimization threshold methods (again)
This commit is contained in:
parent
c0d92a2083
commit
b68b58ad11
|
@ -1278,16 +1278,21 @@ class MS(ThresholdOptimization):
|
|||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
# keeps all candidates
|
||||
decision_scores, y = classif_predictions.Xy
|
||||
self.tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
|
||||
tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
|
||||
self.tprs = tprs_fprs_thresholds[:, 0]
|
||||
self.fprs = tprs_fprs_thresholds[:, 1]
|
||||
self.thresholds = tprs_fprs_thresholds[:, 2]
|
||||
return self
|
||||
|
||||
def aggregate(self, classif_predictions: np.ndarray):
|
||||
prevalences = []
|
||||
for tpr, fpr, threshold in self.tprs_fprs_thresholds:
|
||||
pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
|
||||
prevalences.append(pos_prev)
|
||||
median = np.median(prevalences)
|
||||
return F.as_binary_prevalence(median)
|
||||
prevalences = self.aggregate_with_threshold(classif_predictions, self.tprs, self.fprs, self.thresholds)
|
||||
return np.median(prevalences, axis=0)
|
||||
# prevalences = []
|
||||
# for tpr, fpr, threshold in self.tprs_fprs_thresholds:
|
||||
# pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
|
||||
# prevalences.append(pos_prev)
|
||||
# median = np.median(prevalences)
|
||||
# return F.as_binary_prevalence(median)
|
||||
|
||||
|
||||
class MS2(MS):
|
||||
|
|
Loading…
Reference in New Issue