fix added for cross_val_predict

Lorenzo Volpi 2023-11-06 01:58:36 +01:00
parent 51c3d54aa5
commit 13fe531e12
1 changed file with 46 additions and 56 deletions

@@ -34,8 +34,7 @@ class GridSearchQ(BaseQuantifier):
     :param verbose: set to True to get information through the stdout
     """

-    def __init__(
-        self,
+    def __init__(self,
         model: BaseQuantifier,
         param_grid: dict,
         protocol: AbstractProtocol,
@@ -43,8 +42,8 @@ class GridSearchQ(BaseQuantifier):
         refit=True,
         timeout=-1,
         n_jobs=None,
-        verbose=False,
-    ):
+        verbose=False):
         self.model = model
         self.param_grid = param_grid
         self.protocol = protocol
@@ -53,24 +52,22 @@ class GridSearchQ(BaseQuantifier):
         self.n_jobs = qp._get_njobs(n_jobs)
         self.verbose = verbose
         self.__check_error(error)
-        assert isinstance(protocol, AbstractProtocol), "unknown protocol"
+        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

     def _sout(self, msg):
         if self.verbose:
-            print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
+            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

     def __check_error(self, error):
         if error in qp.error.QUANTIFICATION_ERROR:
             self.error = error
         elif isinstance(error, str):
             self.error = qp.error.from_name(error)
-        elif hasattr(error, "__call__"):
+        elif hasattr(error, '__call__'):
             self.error = error
         else:
-            raise ValueError(
-                f"unexpected error type; must either be a callable function or a str representing\n"
-                f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
-            )
+            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
+                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

     def fit(self, training: LabelledCollection):
         """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
@@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier):
         tinit = time()

-        hyper = [
-            dict({k: val[i] for i, k in enumerate(params_keys)})
-            for val in itertools.product(*params_values)
-        ]
-        self._sout(f"starting model selection with {self.n_jobs =}")
+        hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
+        self._sout(f'starting model selection with {self.n_jobs =}')

         # pass a seed to parallel so it is set in child processes
         scores = qp.util.parallel(
             self._delayed_eval,
             ((params, training) for params in hyper),
-            seed=qp.environ.get("_R_SEED", None),
-            n_jobs=self.n_jobs,
+            seed=qp.environ.get('_R_SEED', None),
+            n_jobs=self.n_jobs
         )

         for params, score, model in scores:
@@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier):
                     self.best_model_ = model
                 self.param_scores_[str(params)] = score
             else:
-                self.param_scores_[str(params)] = "timeout"
+                self.param_scores_[str(params)] = 'timeout'

         tend = time()-tinit

         if self.best_score_ is None:
-            raise TimeoutError("no combination of hyperparameters seem to work")
+            raise TimeoutError('no combination of hyperparameters seem to work')

-        self._sout(
-            f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
-            f"[took {tend:.4f}s]"
-        )
+        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
+                   f'[took {tend:.4f}s]')

         if self.refit:
             if isinstance(protocol, OnLabelledCollectionProtocol):
-                self._sout(f"refitting on the whole development set")
+                self._sout(f'refitting on the whole development set')
                 self.best_model_.fit(training + protocol.get_labelled_collection())
             else:
-                raise RuntimeWarning(
-                    f'"refit" was requested, but the protocol does not '
-                    f"implement the {OnLabelledCollectionProtocol.__name__} interface"
-                )
+                raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
+                                     f'implement the {OnLabelledCollectionProtocol.__name__} interface')

         return self
@@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier):
         error = self.error

         if self.timeout > 0:
             def handler(signum, frame):
                 raise TimeoutError()
@@ -160,25 +149,24 @@ class GridSearchQ(BaseQuantifier):
             score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
             ttime = time()-tinit
-            self._sout(
-                f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
-            )
+            self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')

             if self.timeout > 0:
                 signal.alarm(0)
         except TimeoutError:
-            self._sout(f"timeout ({self.timeout}s) reached for config {params}")
+            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
             score = None
         except ValueError as e:
-            self._sout(f"the combination of hyperparameters {params} is invalid")
+            self._sout(f'the combination of hyperparameters {params} is invalid')
             raise e
         except Exception as e:
-            self._sout(f"something went wrong for config {params}; skipping:")
-            self._sout(f"\tException: {e}")
+            self._sout(f'something went wrong for config {params}; skipping:')
+            self._sout(f'\tException: {e}')
             score = None

         return params, score, model

     def quantify(self, instances):
         """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
@@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier):
         :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
             by the model selection process.
         """
-        assert hasattr(self, "best_model_"), "quantify called before fit"
+        assert hasattr(self, 'best_model_'), 'quantify called before fit'
         return self.best_model().quantify(instances)

     def set_params(self, **parameters):
@@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier):
         :return: a trained quantifier
         """
-        if hasattr(self, "best_model_"):
+        if hasattr(self, 'best_model_'):
             return self.best_model_
-        raise ValueError("best_model called before fit")
+        raise ValueError('best_model called before fit')


-def cross_val_predict(
-    quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
-):
+def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
     """
     Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
     but for quantification.
@@ -235,7 +223,9 @@ def cross_val_predict(
     for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
         quantifier.fit(train)
         fold_prev = quantifier.quantify(test.X)
-        rel_size = 1.0 * len(test) / len(data)
+        rel_size = 1. * len(test) / len(data)
         total_prev += fold_prev*rel_size

     return total_prev
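
For reference, a minimal usage sketch (not part of the commit) of the patched cross_val_predict. It relies only on the QuaPy names visible in the diff or on well-known QuaPy/scikit-learn API (LabelledCollection, an aggregative ACC quantifier, qp.error.mae); the synthetic data and the choice of learner are illustrative assumptions, not anything prescribed by this change.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

import quapy as qp
from quapy.data import LabelledCollection
from quapy.method.aggregative import ACC
from quapy.model_selection import cross_val_predict

# synthetic binary data wrapped in a LabelledCollection (illustrative only)
X, y = make_classification(n_samples=1000, n_features=20, random_state=0)
data = LabelledCollection(X, y)

# an ACC quantifier built on top of a logistic-regression classifier
quantifier = ACC(LogisticRegression())

# k-fold prevalence estimate over the whole collection: as in the function
# above, each fold's prediction is weighted by the fold's relative size
estim_prev = cross_val_predict(quantifier, data, nfolds=3, random_state=0)

print('true prevalence:     ', data.prevalence())
print('estimated prevalence:', estim_prev)
print('MAE:', qp.error.mae(data.prevalence(), estim_prev))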