forked from moreo/QuaPy
refit=True default value in GridSearchQ
This commit is contained in:
parent
294e251450
commit
8239947746
|
@ -524,7 +524,7 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
|
||||||
...
|
...
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||||
BinaryQuantifier._check_binary(data, "Threshold Optimization")
|
self._check_binary(data, "Threshold Optimization")
|
||||||
|
|
||||||
if val_split is None:
|
if val_split is None:
|
||||||
val_split = self.val_split
|
val_split = self.val_split
|
||||||
|
@ -643,6 +643,9 @@ class MS(ThresholdOptimization):
|
||||||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||||
super().__init__(learner, val_split)
|
super().__init__(learner, val_split)
|
||||||
|
|
||||||
|
def _condition(self, tpr, fpr) -> float:
|
||||||
|
pass
|
||||||
|
|
||||||
def optimize_threshold(self, y, probabilities):
|
def optimize_threshold(self, y, probabilities):
|
||||||
tprs = []
|
tprs = []
|
||||||
fprs = []
|
fprs = []
|
||||||
|
|
|
@ -39,6 +39,7 @@ class BaseQuantifier(metaclass=ABCMeta):
|
||||||
|
|
||||||
|
|
||||||
class BinaryQuantifier(BaseQuantifier):
|
class BinaryQuantifier(BaseQuantifier):
|
||||||
|
|
||||||
def _check_binary(self, data: LabelledCollection, quantifier_name):
|
def _check_binary(self, data: LabelledCollection, quantifier_name):
|
||||||
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
||||||
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
||||||
|
|
|
@ -20,7 +20,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
n_repetitions: int = 1,
|
n_repetitions: int = 1,
|
||||||
eval_budget: int = None,
|
eval_budget: int = None,
|
||||||
error: Union[Callable, str] = qp.error.mae,
|
error: Union[Callable, str] = qp.error.mae,
|
||||||
refit=False,
|
refit=True,
|
||||||
val_split=0.4,
|
val_split=0.4,
|
||||||
n_jobs=1,
|
n_jobs=1,
|
||||||
random_seed=42,
|
random_seed=42,
|
||||||
|
|
Loading…
Reference in New Issue