refit=True default value in GridSearchQ

This commit is contained in:
Alejandro Moreo Fernandez 2021-06-16 13:53:54 +02:00
parent 294e251450
commit 8239947746
3 changed files with 6 additions and 2 deletions

View File

@ -524,7 +524,7 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
... ...
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None): def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
BinaryQuantifier._check_binary(data, "Threshold Optimization") self._check_binary(data, "Threshold Optimization")
if val_split is None: if val_split is None:
val_split = self.val_split val_split = self.val_split
@ -643,6 +643,9 @@ class MS(ThresholdOptimization):
def __init__(self, learner: BaseEstimator, val_split=0.4): def __init__(self, learner: BaseEstimator, val_split=0.4):
super().__init__(learner, val_split) super().__init__(learner, val_split)
def _condition(self, tpr, fpr) -> float:
pass
def optimize_threshold(self, y, probabilities): def optimize_threshold(self, y, probabilities):
tprs = [] tprs = []
fprs = [] fprs = []

View File

@ -39,6 +39,7 @@ class BaseQuantifier(metaclass=ABCMeta):
class BinaryQuantifier(BaseQuantifier): class BinaryQuantifier(BaseQuantifier):
def _check_binary(self, data: LabelledCollection, quantifier_name): def _check_binary(self, data: LabelledCollection, quantifier_name):
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \ assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.' f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'

View File

@ -20,7 +20,7 @@ class GridSearchQ(BaseQuantifier):
n_repetitions: int = 1, n_repetitions: int = 1,
eval_budget: int = None, eval_budget: int = None,
error: Union[Callable, str] = qp.error.mae, error: Union[Callable, str] = qp.error.mae,
refit=False, refit=True,
val_split=0.4, val_split=0.4,
n_jobs=1, n_jobs=1,
random_seed=42, random_seed=42,