forked from moreo/QuaPy
fixed optimization threshold methods (again)
This commit is contained in:
parent
c0d92a2083
commit
b68b58ad11
|
@ -1278,16 +1278,21 @@ class MS(ThresholdOptimization):
|
||||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
# keeps all candidates
|
# keeps all candidates
|
||||||
decision_scores, y = classif_predictions.Xy
|
decision_scores, y = classif_predictions.Xy
|
||||||
self.tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
|
tprs_fprs_thresholds = self._eval_candidate_thresholds(decision_scores, y)
|
||||||
|
self.tprs = tprs_fprs_thresholds[:, 0]
|
||||||
|
self.fprs = tprs_fprs_thresholds[:, 1]
|
||||||
|
self.thresholds = tprs_fprs_thresholds[:, 2]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, classif_predictions: np.ndarray):
|
def aggregate(self, classif_predictions: np.ndarray):
|
||||||
prevalences = []
|
prevalences = self.aggregate_with_threshold(classif_predictions, self.tprs, self.fprs, self.thresholds)
|
||||||
for tpr, fpr, threshold in self.tprs_fprs_thresholds:
|
return np.median(prevalences, axis=0)
|
||||||
pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
|
# prevalences = []
|
||||||
prevalences.append(pos_prev)
|
# for tpr, fpr, threshold in self.tprs_fprs_thresholds:
|
||||||
median = np.median(prevalences)
|
# pos_prev = self.aggregate_with_threshold(classif_predictions, tpr, fpr, threshold)[1]
|
||||||
return F.as_binary_prevalence(median)
|
# prevalences.append(pos_prev)
|
||||||
|
# median = np.median(prevalences)
|
||||||
|
# return F.as_binary_prevalence(median)
|
||||||
|
|
||||||
|
|
||||||
class MS2(MS):
|
class MS2(MS):
|
||||||
|
|
Loading…
Reference in New Issue