From 82399477462415218bd553a1d87737be22398211 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Wed, 16 Jun 2021 13:53:54 +0200 Subject: [PATCH] refit=True default value in GridSearchQ --- quapy/method/aggregative.py | 5 ++++- quapy/method/base.py | 1 + quapy/model_selection.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 151dd2e..5e9268b 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -524,7 +524,7 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier): ... def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None): - BinaryQuantifier._check_binary(data, "Threshold Optimization") + self._check_binary(data, "Threshold Optimization") if val_split is None: val_split = self.val_split @@ -643,6 +643,9 @@ class MS(ThresholdOptimization): def __init__(self, learner: BaseEstimator, val_split=0.4): super().__init__(learner, val_split) + def _condition(self, tpr, fpr) -> float: + pass + def optimize_threshold(self, y, probabilities): tprs = [] fprs = [] diff --git a/quapy/method/base.py b/quapy/method/base.py index 0c2729f..64fdff4 100644 --- a/quapy/method/base.py +++ b/quapy/method/base.py @@ -39,6 +39,7 @@ class BaseQuantifier(metaclass=ABCMeta): class BinaryQuantifier(BaseQuantifier): + def _check_binary(self, data: LabelledCollection, quantifier_name): assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \ f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.' diff --git a/quapy/model_selection.py b/quapy/model_selection.py index 5fd21a7..1080db0 100644 --- a/quapy/model_selection.py +++ b/quapy/model_selection.py @@ -20,7 +20,7 @@ class GridSearchQ(BaseQuantifier): n_repetitions: int = 1, eval_budget: int = None, error: Union[Callable, str] = qp.error.mae, - refit=False, + refit=True, val_split=0.4, n_jobs=1, random_seed=42,