fix added for cross_val_predict

This commit is contained in:
Lorenzo Volpi 2023-11-06 01:58:36 +01:00
parent 51c3d54aa5
commit 13fe531e12
1 changed files with 46 additions and 56 deletions

View File

@ -34,8 +34,7 @@ class GridSearchQ(BaseQuantifier):
:param verbose: set to True to get information through the stdout :param verbose: set to True to get information through the stdout
""" """
def __init__( def __init__(self,
self,
model: BaseQuantifier, model: BaseQuantifier,
param_grid: dict, param_grid: dict,
protocol: AbstractProtocol, protocol: AbstractProtocol,
@ -43,8 +42,8 @@ class GridSearchQ(BaseQuantifier):
refit=True, refit=True,
timeout=-1, timeout=-1,
n_jobs=None, n_jobs=None,
verbose=False, verbose=False):
):
self.model = model self.model = model
self.param_grid = param_grid self.param_grid = param_grid
self.protocol = protocol self.protocol = protocol
@ -53,27 +52,25 @@ class GridSearchQ(BaseQuantifier):
self.n_jobs = qp._get_njobs(n_jobs) self.n_jobs = qp._get_njobs(n_jobs)
self.verbose = verbose self.verbose = verbose
self.__check_error(error) self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), "unknown protocol" assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
def _sout(self, msg): def _sout(self, msg):
if self.verbose: if self.verbose:
print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}") print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
def __check_error(self, error): def __check_error(self, error):
if error in qp.error.QUANTIFICATION_ERROR: if error in qp.error.QUANTIFICATION_ERROR:
self.error = error self.error = error
elif isinstance(error, str): elif isinstance(error, str):
self.error = qp.error.from_name(error) self.error = qp.error.from_name(error)
elif hasattr(error, "__call__"): elif hasattr(error, '__call__'):
self.error = error self.error = error
else: else:
raise ValueError( raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f"unexpected error type; must either be a callable function or a str representing\n" f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
)
def fit(self, training: LabelledCollection): def fit(self, training: LabelledCollection):
"""Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
the error metric. the error metric.
:param training: the training set on which to optimize the hyperparameters :param training: the training set on which to optimize the hyperparameters
@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier):
tinit = time() tinit = time()
hyper = [ hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
dict({k: val[i] for i, k in enumerate(params_keys)}) self._sout(f'starting model selection with {self.n_jobs =}')
for val in itertools.product(*params_values) #pass a seed to parallel so it is set in clild processes
]
self._sout(f"starting model selection with {self.n_jobs =}")
# pass a seed to parallel so it is set in clild processes
scores = qp.util.parallel( scores = qp.util.parallel(
self._delayed_eval, self._delayed_eval,
((params, training) for params in hyper), ((params, training) for params in hyper),
seed=qp.environ.get("_R_SEED", None), seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs, n_jobs=self.n_jobs
) )
for params, score, model in scores: for params, score, model in scores:
@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier):
self.best_model_ = model self.best_model_ = model
self.param_scores_[str(params)] = score self.param_scores_[str(params)] = score
else: else:
self.param_scores_[str(params)] = "timeout" self.param_scores_[str(params)] = 'timeout'
tend = time() - tinit tend = time()-tinit
if self.best_score_ is None: if self.best_score_ is None:
raise TimeoutError("no combination of hyperparameters seem to work") raise TimeoutError('no combination of hyperparameters seem to work')
self._sout( self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " f'[took {tend:.4f}s]')
f"[took {tend:.4f}s]"
)
if self.refit: if self.refit:
if isinstance(protocol, OnLabelledCollectionProtocol): if isinstance(protocol, OnLabelledCollectionProtocol):
self._sout(f"refitting on the whole development set") self._sout(f'refitting on the whole development set')
self.best_model_.fit(training + protocol.get_labelled_collection()) self.best_model_.fit(training + protocol.get_labelled_collection())
else: else:
raise RuntimeWarning( raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
f'"refit" was requested, but the protocol does not ' f'implement the {OnLabelledCollectionProtocol.__name__} interface')
f"implement the {OnLabelledCollectionProtocol.__name__} interface"
)
return self return self
@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier):
error = self.error error = self.error
if self.timeout > 0: if self.timeout > 0:
def handler(signum, frame): def handler(signum, frame):
raise TimeoutError() raise TimeoutError()
@ -159,26 +148,25 @@ class GridSearchQ(BaseQuantifier):
model.fit(training) model.fit(training)
score = evaluation.evaluate(model, protocol=protocol, error_metric=error) score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
ttime = time() - tinit ttime = time()-tinit
self._sout( self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
)
if self.timeout > 0: if self.timeout > 0:
signal.alarm(0) signal.alarm(0)
except TimeoutError: except TimeoutError:
self._sout(f"timeout ({self.timeout}s) reached for config {params}") self._sout(f'timeout ({self.timeout}s) reached for config {params}')
score = None score = None
except ValueError as e: except ValueError as e:
self._sout(f"the combination of hyperparameters {params} is invalid") self._sout(f'the combination of hyperparameters {params} is invalid')
raise e raise e
except Exception as e: except Exception as e:
self._sout(f"something went wrong for config {params}; skipping:") self._sout(f'something went wrong for config {params}; skipping:')
self._sout(f"\tException: {e}") self._sout(f'\tException: {e}')
score = None score = None
return params, score, model return params, score, model
def quantify(self, instances): def quantify(self, instances):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method. """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier):
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
by the model selection process. by the model selection process.
""" """
assert hasattr(self, "best_model_"), "quantify called before fit" assert hasattr(self, 'best_model_'), 'quantify called before fit'
return self.best_model().quantify(instances) return self.best_model().quantify(instances)
def set_params(self, **parameters): def set_params(self, **parameters):
@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier):
:return: a trained quantifier :return: a trained quantifier
""" """
if hasattr(self, "best_model_"): if hasattr(self, 'best_model_'):
return self.best_model_ return self.best_model_
raise ValueError("best_model called before fit") raise ValueError('best_model called before fit')
def cross_val_predict(
quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
): def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
""" """
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_ Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
but for quantification. but for quantification.
@ -235,7 +223,9 @@ def cross_val_predict(
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state): for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train) quantifier.fit(train)
fold_prev = quantifier.quantify(test.X) fold_prev = quantifier.quantify(test.X)
rel_size = 1.0 * len(test) / len(data) rel_size = 1. * len(test) / len(data)
total_prev += fold_prev * rel_size total_prev += fold_prev*rel_size
return total_prev return total_prev