forked from moreo/QuaPy
refit=True default value in GridSearchQ
This commit is contained in:
parent
294e251450
commit
8239947746
|
@ -524,7 +524,7 @@ class ThresholdOptimization(AggregativeQuantifier, BinaryQuantifier):
|
|||
...
|
||||
|
||||
def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, int, LabelledCollection] = None):
|
||||
BinaryQuantifier._check_binary(data, "Threshold Optimization")
|
||||
self._check_binary(data, "Threshold Optimization")
|
||||
|
||||
if val_split is None:
|
||||
val_split = self.val_split
|
||||
|
@ -643,6 +643,9 @@ class MS(ThresholdOptimization):
|
|||
def __init__(self, learner: BaseEstimator, val_split=0.4):
|
||||
super().__init__(learner, val_split)
|
||||
|
||||
def _condition(self, tpr, fpr) -> float:
|
||||
pass
|
||||
|
||||
def optimize_threshold(self, y, probabilities):
|
||||
tprs = []
|
||||
fprs = []
|
||||
|
|
|
@ -39,6 +39,7 @@ class BaseQuantifier(metaclass=ABCMeta):
|
|||
|
||||
|
||||
class BinaryQuantifier(BaseQuantifier):
|
||||
|
||||
def _check_binary(self, data: LabelledCollection, quantifier_name):
|
||||
assert data.binary, f'{quantifier_name} works only on problems of binary classification. ' \
|
||||
f'Use the class OneVsAll to enable {quantifier_name} work on single-label data.'
|
||||
|
|
|
@ -20,7 +20,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
n_repetitions: int = 1,
|
||||
eval_budget: int = None,
|
||||
error: Union[Callable, str] = qp.error.mae,
|
||||
refit=False,
|
||||
refit=True,
|
||||
val_split=0.4,
|
||||
n_jobs=1,
|
||||
random_seed=42,
|
||||
|
|
Loading…
Reference in New Issue