2020-12-23 11:14:35 +01:00
|
|
|
import itertools
|
2021-01-15 18:32:32 +01:00
|
|
|
import signal
|
|
|
|
from copy import deepcopy
|
|
|
|
from typing import Union, Callable
|
2022-05-25 19:14:33 +02:00
|
|
|
import evaluation
|
2020-12-23 11:14:35 +01:00
|
|
|
import quapy as qp
|
2022-05-25 19:14:33 +02:00
|
|
|
from protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
2021-03-19 17:34:09 +01:00
|
|
|
from quapy.data.base import LabelledCollection
|
2021-01-15 18:32:32 +01:00
|
|
|
from quapy.method.aggregative import BaseQuantifier
|
2022-05-25 19:14:33 +02:00
|
|
|
from time import time
|
2020-12-23 11:14:35 +01:00
|
|
|
|
|
|
|
|
2021-01-06 14:58:29 +01:00
|
|
|
class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param protocol: an instance of :class:`AbstractProtocol`, passed to `evaluation.evaluate` in order to assess
        each hyperparameter configuration
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in qp.error.QUANTIFICATION_ERROR)
    :param refit: whether or not to refit the best model on the whole development data (the training set plus the
        labelled collection returned by `protocol.get_labelled_collection()`) with the best chosen hyperparameter
        combination. A refit is only possible when `protocol` implements :class:`OnLabelledCollectionProtocol`;
        otherwise requesting it raises a `RuntimeWarning` exception at the end of :meth:`fit`
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
        being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set. The timer is
        implemented with `signal.SIGALRM`, so it is only available on POSIX systems.
    :param n_jobs: number of parallel jobs used to evaluate the configurations (forwarded to `qp.util.parallel`)
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 protocol: AbstractProtocol,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 timeout=-1,
                 n_jobs=1,
                 verbose=False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
        """Prints `msg` prefixed with the class name, only when `verbose` is set."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __check_error(self, error):
        """Resolves `error` into a callable and stores it in `self.error`.

        :param error: a function in `qp.error.QUANTIFICATION_ERROR`, the name of an error function, or any callable
        :raises ValueError: if `error` is neither a string nor a callable
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif hasattr(error, '__call__'):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def fit(self, training: LabelledCollection):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
        the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        :raises TimeoutError: if every configuration exceeded the timeout
        :raises RuntimeWarning: if `refit` was requested but the protocol does not implement
            :class:`OnLabelledCollectionProtocol`
        """
        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        protocol = self.protocol
        n_jobs = self.n_jobs

        self.param_scores_ = {}
        self.best_score_ = None

        # one dict per point of the cartesian product of the explored values
        hyper = [dict(zip(params_keys, values)) for values in itertools.product(*params_values)]
        scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), n_jobs=n_jobs)

        for params, score, model in scores:
            if score is not None:
                # lower is better: the score is an error value
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = model
                self.param_scores_[str(params)] = score
            else:
                self.param_scores_[str(params)] = 'timeout'

        if self.best_score_ is None:
            raise TimeoutError('all jobs took more than the timeout time to end')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f})')

        if self.refit:
            if isinstance(protocol, OnLabelledCollectionProtocol):
                self._sout(f'refitting on the whole development set')
                self.best_model_.fit(training + protocol.get_labelled_collection())
            else:
                raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
                                     f'implement the {OnLabelledCollectionProtocol.__name__} interface')

        return self

    def _delayed_eval(self, args):
        """Trains and evaluates a copy of the model for a single hyperparameter configuration.

        :param args: a tuple `(params, training)` with the configuration to set and the training collection
        :return: a tuple `(params, score, model)`; `score` is None if the configuration timed out, and `model` may
            also be None if the timeout struck before the model could be copied
        """
        params, training = args

        protocol = self.protocol
        error = self.error

        if self.timeout > 0:
            def handler(signum, frame):
                raise TimeoutError()

            signal.signal(signal.SIGALRM, handler)

        tinit = time()

        # bound up-front so that a timeout striking during the deepcopy cannot leave `model` undefined at return
        model = None

        if self.timeout > 0:
            signal.alarm(self.timeout)

        try:
            model = deepcopy(self.model)
            # overrides default parameters with the parameters being explored at this iteration
            model.set_params(**params)
            model.fit(training)
            score = evaluation.evaluate(model, protocol=protocol, error_metric=error)

            ttime = time()-tinit
            error_name = getattr(error, '__name__', str(error))
            self._sout(f'hyperparams={params}\t got {error_name} score {score:.5f} [took {ttime:.4f}s]')

            if self.timeout > 0:
                signal.alarm(0)
        except TimeoutError:
            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
            score = None
        finally:
            # make sure no alarm remains pending if set_params/fit/evaluate raised some other exception; otherwise
            # it would later fire at an arbitrary point and raise a spurious TimeoutError
            if self.timeout > 0:
                signal.alarm(0)

        return params, score, model

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')
|
2022-05-25 19:14:33 +02:00
|
|
|
|
|
|
|
|