diff --git a/quacc/method/base.py b/quacc/method/base.py index dcc19f8..b31c039 100644 --- a/quacc/method/base.py +++ b/quacc/method/base.py @@ -17,11 +17,10 @@ class BaseAccuracyEstimator(BaseQuantifier): self, classifier: BaseEstimator, quantifier: BaseQuantifier, - collapse_false=False, ): self.__check_classifier(classifier) self.quantifier = quantifier - self.extpol = ExtensionPolicy(collapse_false=collapse_false) + self.extpol = ExtensionPolicy() def __check_classifier(self, classifier): if not hasattr(classifier, "predict_proba"): @@ -50,23 +49,17 @@ class BaseAccuracyEstimator(BaseQuantifier): def estimate(self, instances, ext=False) -> np.ndarray: ... - @property - def collapse_false(self): - return self.extpol.collapse_false - class ConfidenceBasedAccuracyEstimator(BaseAccuracyEstimator): def __init__( self, classifier: BaseEstimator, quantifier: BaseQuantifier, - collapse_false=False, confidence=None, ): super().__init__( classifier=classifier, quantifier=quantifier, - collapse_false=collapse_false, ) self.__check_confidence(confidence) self.calibrator = None @@ -137,8 +130,8 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator): classifier=classifier, quantifier=quantifier, confidence=confidence, - collapse_false=collapse_false, ) + self.extpol = ExtensionPolicy(collapse_false=collapse_false) self.e_train = None def _get_pred_ext(self, pred_proba: np.ndarray): @@ -176,6 +169,10 @@ class MultiClassAccuracyEstimator(ConfidenceBasedAccuracyEstimator): estim_prev = np.insert(estim_prev, _cls, [0.0], axis=0) return estim_prev + @property + def collapse_false(self): + return self.extpol.collapse_false + class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator): def __init__( @@ -183,7 +180,6 @@ class BinaryQuantifierAccuracyEstimator(ConfidenceBasedAccuracyEstimator): classifier: BaseEstimator, quantifier: BaseAccuracyEstimator, confidence: str = None, - collapse_false=False, ): super().__init__( classifier=classifier,