fix added for cross_val_predict
This commit is contained in:
parent
51c3d54aa5
commit
13fe531e12
|
@ -34,17 +34,16 @@ class GridSearchQ(BaseQuantifier):
|
|||
:param verbose: set to True to get information through the stdout
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseQuantifier,
|
||||
param_grid: dict,
|
||||
protocol: AbstractProtocol,
|
||||
error: Union[Callable, str] = qp.error.mae,
|
||||
refit=True,
|
||||
timeout=-1,
|
||||
n_jobs=None,
|
||||
verbose=False,
|
||||
):
|
||||
def __init__(self,
|
||||
model: BaseQuantifier,
|
||||
param_grid: dict,
|
||||
protocol: AbstractProtocol,
|
||||
error: Union[Callable, str] = qp.error.mae,
|
||||
refit=True,
|
||||
timeout=-1,
|
||||
n_jobs=None,
|
||||
verbose=False):
|
||||
|
||||
self.model = model
|
||||
self.param_grid = param_grid
|
||||
self.protocol = protocol
|
||||
|
@ -53,27 +52,25 @@ class GridSearchQ(BaseQuantifier):
|
|||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
self.verbose = verbose
|
||||
self.__check_error(error)
|
||||
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
|
||||
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
|
||||
|
||||
def _sout(self, msg):
|
||||
if self.verbose:
|
||||
print(f"[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}")
|
||||
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')
|
||||
|
||||
def __check_error(self, error):
|
||||
if error in qp.error.QUANTIFICATION_ERROR:
|
||||
self.error = error
|
||||
elif isinstance(error, str):
|
||||
self.error = qp.error.from_name(error)
|
||||
elif hasattr(error, "__call__"):
|
||||
elif hasattr(error, '__call__'):
|
||||
self.error = error
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unexpected error type; must either be a callable function or a str representing\n"
|
||||
f"the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}"
|
||||
)
|
||||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||
|
||||
def fit(self, training: LabelledCollection):
|
||||
"""Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
||||
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
||||
the error metric.
|
||||
|
||||
:param training: the training set on which to optimize the hyperparameters
|
||||
|
@ -89,17 +86,14 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
tinit = time()
|
||||
|
||||
hyper = [
|
||||
dict({k: val[i] for i, k in enumerate(params_keys)})
|
||||
for val in itertools.product(*params_values)
|
||||
]
|
||||
self._sout(f"starting model selection with {self.n_jobs =}")
|
||||
# pass a seed to parallel so it is set in clild processes
|
||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
||||
self._sout(f'starting model selection with {self.n_jobs =}')
|
||||
#pass a seed to parallel so it is set in clild processes
|
||||
scores = qp.util.parallel(
|
||||
self._delayed_eval,
|
||||
((params, training) for params in hyper),
|
||||
seed=qp.environ.get("_R_SEED", None),
|
||||
n_jobs=self.n_jobs,
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
for params, score, model in scores:
|
||||
|
@ -110,27 +104,23 @@ class GridSearchQ(BaseQuantifier):
|
|||
self.best_model_ = model
|
||||
self.param_scores_[str(params)] = score
|
||||
else:
|
||||
self.param_scores_[str(params)] = "timeout"
|
||||
self.param_scores_[str(params)] = 'timeout'
|
||||
|
||||
tend = time() - tinit
|
||||
tend = time()-tinit
|
||||
|
||||
if self.best_score_ is None:
|
||||
raise TimeoutError("no combination of hyperparameters seem to work")
|
||||
raise TimeoutError('no combination of hyperparameters seem to work')
|
||||
|
||||
self._sout(
|
||||
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
|
||||
f"[took {tend:.4f}s]"
|
||||
)
|
||||
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
|
||||
f'[took {tend:.4f}s]')
|
||||
|
||||
if self.refit:
|
||||
if isinstance(protocol, OnLabelledCollectionProtocol):
|
||||
self._sout(f"refitting on the whole development set")
|
||||
self._sout(f'refitting on the whole development set')
|
||||
self.best_model_.fit(training + protocol.get_labelled_collection())
|
||||
else:
|
||||
raise RuntimeWarning(
|
||||
f'"refit" was requested, but the protocol does not '
|
||||
f"implement the {OnLabelledCollectionProtocol.__name__} interface"
|
||||
)
|
||||
raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
|
||||
f'implement the {OnLabelledCollectionProtocol.__name__} interface')
|
||||
|
||||
return self
|
||||
|
||||
|
@ -141,7 +131,6 @@ class GridSearchQ(BaseQuantifier):
|
|||
error = self.error
|
||||
|
||||
if self.timeout > 0:
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
|
@ -159,26 +148,25 @@ class GridSearchQ(BaseQuantifier):
|
|||
model.fit(training)
|
||||
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
|
||||
|
||||
ttime = time() - tinit
|
||||
self._sout(
|
||||
f"hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]"
|
||||
)
|
||||
ttime = time()-tinit
|
||||
self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
|
||||
|
||||
if self.timeout > 0:
|
||||
signal.alarm(0)
|
||||
except TimeoutError:
|
||||
self._sout(f"timeout ({self.timeout}s) reached for config {params}")
|
||||
self._sout(f'timeout ({self.timeout}s) reached for config {params}')
|
||||
score = None
|
||||
except ValueError as e:
|
||||
self._sout(f"the combination of hyperparameters {params} is invalid")
|
||||
self._sout(f'the combination of hyperparameters {params} is invalid')
|
||||
raise e
|
||||
except Exception as e:
|
||||
self._sout(f"something went wrong for config {params}; skipping:")
|
||||
self._sout(f"\tException: {e}")
|
||||
self._sout(f'something went wrong for config {params}; skipping:')
|
||||
self._sout(f'\tException: {e}')
|
||||
score = None
|
||||
|
||||
return params, score, model
|
||||
|
||||
|
||||
def quantify(self, instances):
|
||||
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
|
||||
|
||||
|
@ -186,7 +174,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
|
||||
by the model selection process.
|
||||
"""
|
||||
assert hasattr(self, "best_model_"), "quantify called before fit"
|
||||
assert hasattr(self, 'best_model_'), 'quantify called before fit'
|
||||
return self.best_model().quantify(instances)
|
||||
|
||||
def set_params(self, **parameters):
|
||||
|
@ -211,14 +199,14 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
:return: a trained quantifier
|
||||
"""
|
||||
if hasattr(self, "best_model_"):
|
||||
if hasattr(self, 'best_model_'):
|
||||
return self.best_model_
|
||||
raise ValueError("best_model called before fit")
|
||||
raise ValueError('best_model called before fit')
|
||||
|
||||
|
||||
def cross_val_predict(
|
||||
quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0
|
||||
):
|
||||
|
||||
|
||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||
"""
|
||||
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
||||
but for quantification.
|
||||
|
@ -235,7 +223,9 @@ def cross_val_predict(
|
|||
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
|
||||
quantifier.fit(train)
|
||||
fold_prev = quantifier.quantify(test.X)
|
||||
rel_size = 1.0 * len(test) / len(data)
|
||||
total_prev += fold_prev * rel_size
|
||||
rel_size = 1. * len(test) / len(data)
|
||||
total_prev += fold_prev*rel_size
|
||||
|
||||
return total_prev
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue