fix added for cross_val_predict

This commit is contained in:
Lorenzo Volpi 2023-11-06 01:58:36 +01:00
parent 51c3d54aa5
commit 13fe531e12
1 changed files with 46 additions and 56 deletions

View File

@ -34,8 +34,7 @@ class GridSearchQ(BaseQuantifier):
:param verbose: set to True to get information through the stdout
"""
def __init__(
self,
def __init__(self,
model: BaseQuantifier,
param_grid: dict,
protocol: AbstractProtocol,
@ -43,8 +42,8 @@ class GridSearchQ(BaseQuantifier):
refit=True,
timeout=-1,
n_jobs=None,
verbose=False,
):
verbose=False):
self.model = model
self.param_grid = param_grid
self.protocol = protocol
@ -53,27 +52,25 @@ class GridSearchQ(BaseQuantifier):
self.n_jobs = qp._get_njobs(n_jobs)
self.verbose = verbose
self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
def _sout(self, msg):
if self.verbose:
print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
def __check_error(self, error):
if error in qp.error.QUANTIFICATION_ERROR:
self.error = error
elif isinstance(error, str):
self.error = qp.error.from_name(error)
elif hasattr(error, "__call__"):
elif hasattr(error, '__call__'):
self.error = error
else:
raise ValueError(
f"unexpected error type; must either be a callable function or a str representing\n"
f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
)
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
def fit(self, training: LabelledCollection):
"""Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
the error metric.
:param training: the training set on which to optimize the hyperparameters
@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier):
tinit = time()
hyper = [
dict({k: val[i] for i, k in enumerate(params_keys)})
for val in itertools.product(*params_values)
]
self._sout(f"starting model selection with {self.n_jobs =}")
# pass a seed to parallel so it is set in clild processes
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
self._sout(f'starting model selection with {self.n_jobs =}')
#pass a seed to parallel so it is set in clild processes
scores = qp.util.parallel(
self._delayed_eval,
((params, training) for params in hyper),
seed=qp.environ.get("_R_SEED", None),
n_jobs=self.n_jobs,
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
for params, score, model in scores:
@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier):
self.best_model_ = model
self.param_scores_[str(params)] = score
else:
self.param_scores_[str(params)] = "timeout"
self.param_scores_[str(params)] = 'timeout'
tend = time() - tinit
tend = time()-tinit
if self.best_score_ is None:
raise TimeoutError("no combination of hyperparameters seem to work")
raise TimeoutError('no combination of hyperparameters seem to work')
self._sout(
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
f"[took {tend:.4f}s]"
)
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
f'[took {tend:.4f}s]')
if self.refit:
if isinstance(protocol, OnLabelledCollectionProtocol):
self._sout(f"refitting on the whole development set")
self._sout(f'refitting on the whole development set')
self.best_model_.fit(training + protocol.get_labelled_collection())
else:
raise RuntimeWarning(
f'"refit" was requested, but the protocol does not '
f"implement the {OnLabelledCollectionProtocol.__name__} interface"
)
raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
f'implement the {OnLabelledCollectionProtocol.__name__} interface')
return self
@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier):
error = self.error
if self.timeout > 0:
def handler(signum, frame):
raise TimeoutError()
@ -159,26 +148,25 @@ class GridSearchQ(BaseQuantifier):
model.fit(training)
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
ttime = time() - tinit
self._sout(
f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
)
ttime = time()-tinit
self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
if self.timeout > 0:
signal.alarm(0)
except TimeoutError:
self._sout(f"timeout ({self.timeout}s) reached for config {params}")
self._sout(f'timeout ({self.timeout}s) reached for config {params}')
score = None
except ValueError as e:
self._sout(f"the combination of hyperparameters {params} is invalid")
self._sout(f'the combination of hyperparameters {params} is invalid')
raise e
except Exception as e:
self._sout(f"something went wrong for config {params}; skipping:")
self._sout(f"\tException: {e}")
self._sout(f'something went wrong for config {params}; skipping:')
self._sout(f'\tException: {e}')
score = None
return params, score, model
def quantify(self, instances):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier):
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
by the model selection process.
"""
assert hasattr(self, "best_model_"), "quantify called before fit"
assert hasattr(self, 'best_model_'), 'quantify called before fit'
return self.best_model().quantify(instances)
def set_params(self, **parameters):
@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier):
:return: a trained quantifier
"""
if hasattr(self, "best_model_"):
if hasattr(self, 'best_model_'):
return self.best_model_
raise ValueError("best_model called before fit")
raise ValueError('best_model called before fit')
def cross_val_predict(
quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
):
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
"""
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
but for quantification.
@ -235,7 +223,9 @@ def cross_val_predict(
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train)
fold_prev = quantifier.quantify(test.X)
rel_size = 1.0 * len(test) / len(data)
total_prev += fold_prev * rel_size
rel_size = 1. * len(test) / len(data)
total_prev += fold_prev*rel_size
return total_prev