diff --git a/quapy/model_selection.py b/quapy/model_selection.py index f77bee9..081f2f4 100644 --- a/quapy/model_selection.py +++ b/quapy/model_selection.py @@ -34,17 +34,16 @@ class GridSearchQ(BaseQuantifier): :param verbose: set to True to get information through the stdout """ - def __init__( - self, - model: BaseQuantifier, - param_grid: dict, - protocol: AbstractProtocol, - error: Union[Callable, str] = qp.error.mae, - refit=True, - timeout=-1, - n_jobs=None, - verbose=False, - ): + def __init__(self, + model: BaseQuantifier, + param_grid: dict, + protocol: AbstractProtocol, + error: Union[Callable, str] = qp.error.mae, + refit=True, + timeout=-1, + n_jobs=None, + verbose=False): + self.model = model self.param_grid = param_grid self.protocol = protocol @@ -53,27 +52,25 @@ class GridSearchQ(BaseQuantifier): self.n_jobs = qp._get_njobs(n_jobs) self.verbose = verbose self.__check_error(error) - assert isinstance(protocol, AbstractProtocol), "unknown protocol" + assert isinstance(protocol, AbstractProtocol), 'unknown protocol' def _sout(self, msg): if self.verbose: - print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}") + print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}') def __check_error(self, error): if error in qp.error.QUANTIFICATION_ERROR: self.error = error elif isinstance(error, str): self.error = qp.error.from_name(error) - elif hasattr(error, "__call__"): + elif hasattr(error, '__call__'): self.error = error else: - raise ValueError( - f"unexpected error type; must either be a callable function or a str representing\n" - f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}" - ) + raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n' + f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}') def fit(self, training: LabelledCollection): - """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing + """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing the error metric. :param training: the training set on which to optimize the hyperparameters @@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier): tinit = time() - hyper = [ - dict({k: val[i] for i, k in enumerate(params_keys)}) - for val in itertools.product(*params_values) - ] - self._sout(f"starting model selection with {self.n_jobs =}") - # pass a seed to parallel so it is set in clild processes + hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)] + self._sout(f'starting model selection with {self.n_jobs =}') + #pass a seed to parallel so it is set in clild processes scores = qp.util.parallel( self._delayed_eval, ((params, training) for params in hyper), - seed=qp.environ.get("_R_SEED", None), - n_jobs=self.n_jobs, + seed=qp.environ.get('_R_SEED', None), + n_jobs=self.n_jobs ) for params, score, model in scores: @@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier): self.best_model_ = model self.param_scores_[str(params)] = score else: - self.param_scores_[str(params)] = "timeout" + self.param_scores_[str(params)] = 'timeout' - tend = time() - tinit + tend = time()-tinit if self.best_score_ is None: - raise TimeoutError("no combination of hyperparameters seem to work") + raise TimeoutError('no combination of hyperparameters seem to work') - self._sout( - f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " - f"[took {tend:.4f}s]" - ) + self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) ' + f'[took {tend:.4f}s]') if self.refit: if isinstance(protocol, OnLabelledCollectionProtocol): - self._sout(f"refitting on the whole development set") + self._sout(f'refitting on the whole development set') self.best_model_.fit(training + protocol.get_labelled_collection()) else: - raise RuntimeWarning( - f'"refit" was requested, but the protocol does not ' - f"implement the {OnLabelledCollectionProtocol.__name__} interface" - ) + raise RuntimeWarning(f'"refit" was requested, but the protocol does not ' + f'implement the {OnLabelledCollectionProtocol.__name__} interface') return self @@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier): error = self.error if self.timeout > 0: - def handler(signum, frame): raise TimeoutError() @@ -159,26 +148,25 @@ class GridSearchQ(BaseQuantifier): model.fit(training) score = evaluation.evaluate(model, protocol=protocol, error_metric=error) - ttime = time() - tinit - self._sout( - f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]" - ) + ttime = time()-tinit + self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]') if self.timeout > 0: signal.alarm(0) except TimeoutError: - self._sout(f"timeout ({self.timeout}s) reached for config {params}") + self._sout(f'timeout ({self.timeout}s) reached for config {params}') score = None except ValueError as e: - self._sout(f"the combination of hyperparameters {params} is invalid") + self._sout(f'the combination of hyperparameters {params} is invalid') raise e except Exception as e: - self._sout(f"something went wrong for config {params}; skipping:") - self._sout(f"\tException: {e}") + self._sout(f'something went wrong for config {params}; skipping:') + self._sout(f'\tException: {e}') score = None return params, score, model + def quantify(self, instances): """Estimate class prevalence values using the best model found after calling the :meth:`fit` method. @@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier): :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found by the model selection process. """ - assert hasattr(self, "best_model_"), "quantify called before fit" + assert hasattr(self, 'best_model_'), 'quantify called before fit' return self.best_model().quantify(instances) def set_params(self, **parameters): @@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier): :return: a trained quantifier """ - if hasattr(self, "best_model_"): + if hasattr(self, 'best_model_'): return self.best_model_ - raise ValueError("best_model called before fit") + raise ValueError('best_model called before fit') -def cross_val_predict( - quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0 -): + + +def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0): """ Akin to `scikit-learn's cross_val_predict `_ but for quantification. @@ -235,7 +223,9 @@ def cross_val_predict( for train, test in data.kFCV(nfolds=nfolds, random_state=random_state): quantifier.fit(train) fold_prev = quantifier.quantify(test.X) - rel_size = 1.0 * len(test) / len(data) - total_prev += fold_prev * rel_size + rel_size = 1. * len(test) / len(data) + total_prev += fold_prev*rel_size return total_prev + +