diff --git a/quapy/model_selection.py b/quapy/model_selection.py
index 4a53bb6..f77bee9 100644
--- a/quapy/model_selection.py
+++ b/quapy/model_selection.py
@@ -1,17 +1,17 @@
 import itertools
 import signal
 from copy import deepcopy
-from typing import Union, Callable
+from time import time
+from typing import Callable, Union
 
 import numpy as np
 from sklearn import clone
 
 import quapy as qp
 from quapy import evaluation
-from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
 from quapy.data.base import LabelledCollection
 from quapy.method.aggregative import BaseQuantifier
-from time import time
+from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
 
 
 class GridSearchQ(BaseQuantifier):
@@ -34,16 +34,17 @@ class GridSearchQ(BaseQuantifier):
     :param verbose: set to True to get information through the stdout
     """
 
-    def __init__(self,
-                 model: BaseQuantifier,
-                 param_grid: dict,
-                 protocol: AbstractProtocol,
-                 error: Union[Callable, str] = qp.error.mae,
-                 refit=True,
-                 timeout=-1,
-                 n_jobs=None,
-                 verbose=False):
-
+    def __init__(
+        self,
+        model: BaseQuantifier,
+        param_grid: dict,
+        protocol: AbstractProtocol,
+        error: Union[Callable, str] = qp.error.mae,
+        refit=True,
+        timeout=-1,
+        n_jobs=None,
+        verbose=False,
+    ):
         self.model = model
         self.param_grid = param_grid
         self.protocol = protocol
@@ -52,25 +53,27 @@ class GridSearchQ(BaseQuantifier):
         self.n_jobs = qp._get_njobs(n_jobs)
         self.verbose = verbose
         self.__check_error(error)
-        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
+        assert isinstance(protocol, AbstractProtocol), "unknown protocol"
 
     def _sout(self, msg):
         if self.verbose:
-            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
+            print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
 
     def __check_error(self, error):
         if error in qp.error.QUANTIFICATION_ERROR:
             self.error = error
         elif isinstance(error, str):
             self.error = qp.error.from_name(error)
-        elif hasattr(error, '__call__'):
+        elif hasattr(error, "__call__"):
             self.error = error
         else:
-            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
-                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
+            raise ValueError(
+                f"unexpected error type; must either be a callable function or a str representing\n"
+                f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
+            )
 
     def fit(self, training: LabelledCollection):
-        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
+        """Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
         the error metric.
 
         :param training: the training set on which to optimize the hyperparameters
@@ -86,14 +89,17 @@ class GridSearchQ(BaseQuantifier):
 
         tinit = time()
 
-        hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
-        self._sout(f'starting model selection with {self.n_jobs =}')
-        #pass a seed to parallel so it is set in clild processes
+        hyper = [
+            dict({k: val[i] for i, k in enumerate(params_keys)})
+            for val in itertools.product(*params_values)
+        ]
+        self._sout(f"starting model selection with {self.n_jobs=}")
+        # pass a seed to parallel so it is set in child processes
         scores = qp.util.parallel(
             self._delayed_eval,
             ((params, training) for params in hyper),
-            seed=qp.environ.get('_R_SEED', None),
-            n_jobs=self.n_jobs
+            seed=qp.environ.get("_R_SEED", None),
+            n_jobs=self.n_jobs,
         )
 
         for params, score, model in scores:
@@ -104,23 +110,27 @@ class GridSearchQ(BaseQuantifier):
                 self.best_model_ = model
                 self.param_scores_[str(params)] = score
             else:
-                self.param_scores_[str(params)] = 'timeout'
+                self.param_scores_[str(params)] = "timeout"
 
-        tend = time()-tinit
+        tend = time() - tinit
 
         if self.best_score_ is None:
-            raise TimeoutError('no combination of hyperparameters seem to work')
+            raise TimeoutError("no combination of hyperparameters seems to work")
 
-        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
-                   f'[took {tend:.4f}s]')
+        self._sout(
+            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
+            f"[took {tend:.4f}s]"
+        )
 
         if self.refit:
             if isinstance(protocol, OnLabelledCollectionProtocol):
-                self._sout(f'refitting on the whole development set')
+                self._sout("refitting on the whole development set")
                 self.best_model_.fit(training + protocol.get_labelled_collection())
             else:
-                raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
-                                     f'implement the {OnLabelledCollectionProtocol.__name__} interface')
+                raise RuntimeWarning(
+                    f'"refit" was requested, but the protocol does not '
+                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
+                )
 
         return self
 
@@ -131,6 +141,7 @@ class GridSearchQ(BaseQuantifier):
         error = self.error
 
         if self.timeout > 0:
+
             def handler(signum, frame):
                 raise TimeoutError()
 
@@ -148,25 +159,26 @@ class GridSearchQ(BaseQuantifier):
             model.fit(training)
             score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
 
-            ttime = time()-tinit
-            self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
+            ttime = time() - tinit
+            self._sout(
+                f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
+            )
 
             if self.timeout > 0:
                 signal.alarm(0)
         except TimeoutError:
-            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
+            self._sout(f"timeout ({self.timeout}s) reached for config {params}")
             score = None
         except ValueError as e:
-            self._sout(f'the combination of hyperparameters {params} is invalid')
+            self._sout(f"the combination of hyperparameters {params} is invalid")
             raise e
         except Exception as e:
-            self._sout(f'something went wrong for config {params}; skipping:')
-            self._sout(f'\tException: {e}')
+            self._sout(f"something went wrong for config {params}; skipping:")
+            self._sout(f"\tException: {e}")
             score = None
 
         return params, score, model
 
-
    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
 
@@ -174,7 +186,7 @@ class GridSearchQ(BaseQuantifier):
         :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
             by the model selection process.
         """
-        assert hasattr(self, 'best_model_'), 'quantify called before fit'
+        assert hasattr(self, "best_model_"), "quantify called before fit"
         return self.best_model().quantify(instances)
 
     def set_params(self, **parameters):
@@ -199,14 +211,14 @@ class GridSearchQ(BaseQuantifier):
 
         :return: a trained quantifier
         """
-        if hasattr(self, 'best_model_'):
+        if hasattr(self, "best_model_"):
             return self.best_model_
-        raise ValueError('best_model called before fit')
+        raise ValueError("best_model called before fit")
 
 
-
-def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
+def cross_val_predict(
+    quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
+):
     """
     Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
     but for quantification.
@@ -223,9 +235,7 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
     for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
         quantifier.fit(train)
         fold_prev = quantifier.quantify(test.X)
-        rel_size = len(test.X)/len(data)
-        total_prev += fold_prev*rel_size
+        rel_size = len(test) / len(data)
+        total_prev += fold_prev * rel_size
 
     return total_prev
-
-
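
Usage sketch (not part of the patch): the snippet below shows one way the reformatted GridSearchQ and cross_val_predict are typically driven end to end. It is a minimal sketch only; the fetch_reviews loader, the PACC quantifier, and the sample-size, grid, and timeout values are illustrative choices, not anything this diff prescribes.

import numpy as np
from sklearn.linear_model import LogisticRegression

import quapy as qp
from quapy.method.aggregative import PACC
from quapy.model_selection import GridSearchQ, cross_val_predict
from quapy.protocol import APP

# APP draws validation samples at prevalences spread over a grid; it needs a sample size
qp.environ["SAMPLE_SIZE"] = 100

# any LabelledCollection works here; fetch_reviews is just a convenient built-in dataset
data = qp.datasets.fetch_reviews("hp", tfidf=True, min_df=5)
train, val = data.training.split_stratified(train_prop=0.6)

model = GridSearchQ(
    model=PACC(LogisticRegression()),
    param_grid={"classifier__C": np.logspace(-3, 3, 7)},  # illustrative grid
    protocol=APP(val),  # holds a LabelledCollection, so refit=True is legal
    error="mae",        # a str is resolved via qp.error.from_name in __check_error
    refit=True,         # retrains the winner on train + val, per the refit branch in fit
    timeout=60,         # seconds granted to each hyperparameter combination
    verbose=True,
).fit(train)

print(model.best_params_, model.best_score_)
prevalence = model.quantify(data.test.X)

# the k-fold helper touched at the end of the patch: per-fold prevalence estimates,
# averaged with weights proportional to fold size
avg_prev = cross_val_predict(PACC(LogisticRegression()), data.training, nfolds=3)

Note that refit=True raises a RuntimeWarning for protocols that do not expose their underlying LabelledCollection, which is exactly the check the refit branch in fit performs.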