2020-12-23 11:14:35 +01:00
|
|
|
import itertools
|
2021-01-15 18:32:32 +01:00
|
|
|
import signal
|
|
|
|
from copy import deepcopy
|
|
|
|
from typing import Union, Callable
|
|
|
|
|
2020-12-23 11:14:35 +01:00
|
|
|
import quapy as qp
|
2021-03-19 17:34:09 +01:00
|
|
|
from quapy.data.base import LabelledCollection
|
2021-10-26 18:41:10 +02:00
|
|
|
from quapy.evaluation import artificial_prevalence_prediction, natural_prevalence_prediction, gen_prevalence_prediction
|
2021-01-15 18:32:32 +01:00
|
|
|
from quapy.method.aggregative import BaseQuantifier
|
2021-10-26 18:41:10 +02:00
|
|
|
import inspect
|
2020-12-23 11:14:35 +01:00
|
|
|
|
|
|
|
|
2021-01-06 14:58:29 +01:00
|
|
|
class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param sample_size: the size of the samples to extract from the validation set (ignored if protocol='gen')
    :param protocol: either 'app' for the artificial prevalence protocol, 'npp' for the natural prevalence
        protocol, or 'gen' for using a custom sampling generator function
    :param n_prevpoints: if specified, indicates the number of equally distant points to extract from the interval
        [0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
        each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
        Ignored if protocol!='app'.
    :param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
        for the protocol='app' if eval_budget is set and is lower than the number of combinations that would be
        generated using the value assigned to n_prevpoints (for the current number of classes and n_repetitions).
        Ignored for protocol='npp' and protocol='gen' (use eval_budget for setting a maximum number of samples in
        those cases).
    :param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
        combination. For example, if protocol='app', there are 3 classes, n_repetitions=1 and eval_budget=20, then
        n_prevpoints will be set to 5, since this will generate 15 different prevalences, i.e., [0, 0, 1],
        [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0], and since setting it to 6 would generate more than
        20. When protocol='gen', indicates the maximum number of samples to generate, but less samples will be
        generated if the generator yields less samples.
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in qp.error.QUANTIFICATION_ERROR)
    :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination. Ignored if protocol='gen'
    :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
        a float in [0,1] indicating the proportion of labelled data to extract from the training set, or a callable
        returning a generator function each time it is invoked (only for protocol='gen').
    :param n_jobs: number of parallel jobs
    :param random_seed: set the seed of the random generator to replicate experiments. Ignored if protocol='gen'.
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
        being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 sample_size: Union[int, None],
                 protocol='app',
                 n_prevpoints: int = None,
                 n_repetitions: int = 1,
                 eval_budget: int = None,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 val_split=0.4,
                 n_jobs=1,
                 random_seed=42,
                 timeout=-1,
                 verbose=False):

        self.model = model
        self.param_grid = param_grid
        self.sample_size = sample_size
        self.protocol = protocol.lower()
        self.n_prevpoints = n_prevpoints
        self.n_repetitions = n_repetitions
        self.eval_budget = eval_budget
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose
        # resolves `error` (callable or name) into the callable stored in self.error
        self.__check_error(error)
        assert self.protocol in {'app', 'npp', 'gen'}, \
            'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
            'protocols. Use protocol="gen" when passing a generator function thorough val_split that yields a ' \
            'sample (instances) and their prevalence (ndarray) at each iteration.'
        assert self.eval_budget is None or isinstance(self.eval_budget, int)
        if self.protocol in ['npp', 'gen']:
            if self.protocol=='npp' and (self.eval_budget is None or self.eval_budget <= 0):
                raise ValueError(f'when protocol="npp" the parameter eval_budget should be '
                                 f'indicated (and should be >0).')
            if self.n_repetitions != 1:
                print('[warning] n_repetitions has been set and will be ignored for the selected protocol')

    def _sout(self, msg):
        """Prints `msg` to stdout, prefixed with the class name, only if `verbose` is set."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __check_training_validation(self, training, validation):
        """Validates the validation specification and returns the `(training, validation)` pair to use.

        :param training: the training collection; it is split (stratified) when `validation` is a float
        :param validation: a LabelledCollection, a float in (0,1), or (only for protocol='gen') a callable that
            returns a generator when invoked
        :return: a tuple `(training, validation)`
        :raises ValueError: if `validation` is none of the accepted kinds
        """
        if isinstance(validation, LabelledCollection):
            return training, validation
        elif isinstance(validation, float):
            assert 0. < validation < 1., 'validation proportion should be in (0,1)'
            training, validation = training.split_stratified(train_prop=1 - validation)
            return training, validation
        # `callable` is checked first so that a non-callable value ends up in the ValueError below instead of
        # failing with an unrelated TypeError when trying to invoke it
        elif self.protocol == 'gen' and callable(validation) and inspect.isgenerator(validation()):
            return training, validation
        else:
            raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the '
                             f'proportion of training documents to extract (type found: {type(validation)}). '
                             f'Optionally, "validation" can be a callable function returning a generator that yields '
                             f'the sample instances along with their true prevalence at each iteration by '
                             f'setting protocol="gen".')

    def __check_error(self, error):
        """Stores in `self.error` the error function indicated by `error`.

        :param error: a member of `qp.error.QUANTIFICATION_ERROR`, the name of an error function (resolved through
            `qp.error.from_name`), or any callable
        :raises ValueError: if `error` is neither a callable nor a string
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif callable(error):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def __generate_predictions(self, model, val_split):
        """Generates the true and estimated prevalences of `model` on `val_split` following the chosen protocol.

        :param model: the (already fitted) quantifier to evaluate
        :param val_split: the validation data (or the generator function if protocol='gen')
        :return: whatever the prediction function of the selected protocol returns; the caller unpacks it as
            `(true_prevalences, estim_prevalences)`
        :raises ValueError: if the protocol is unknown
        """
        commons = {
            'n_repetitions': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'random_seed': self.random_seed,
            'verbose': False
        }
        if self.protocol == 'app':
            return artificial_prevalence_prediction(
                model, val_split, self.sample_size,
                n_prevpoints=self.n_prevpoints,
                eval_budget=self.eval_budget,
                **commons
            )
        elif self.protocol == 'npp':
            return natural_prevalence_prediction(
                model, val_split, self.sample_size,
                **commons)
        elif self.protocol == 'gen':
            return gen_prevalence_prediction(model, gen_fn=val_split, eval_budget=self.eval_budget)
        else:
            raise ValueError('unknown protocol')

    def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float, Callable] = None):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
            the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
            a float in [0,1] indicating the proportion of labelled data to extract from the training set, or
            (only for protocol='gen') a callable returning a generator. If None (default), the `val_split`
            given at construction time is used.
        :return: self
        :raises TimeoutError: if every configuration exceeded `timeout`
        :raises ValueError: if the grid yields no configuration at all (e.g., a parameter with no values)
        """
        if val_split is None:
            val_split = self.val_split
        training, val_split = self.__check_training_validation(training, val_split)
        if self.protocol != 'gen':
            assert isinstance(self.sample_size, int) and self.sample_size > 0, 'sample_size must be a positive integer'

        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        model = self.model
        use_timeout = self.timeout > 0

        if use_timeout:
            def handler(signum, frame):
                self._sout('timeout reached')
                raise TimeoutError()

            signal.signal(signal.SIGALRM, handler)

        # not every callable has a __name__ (e.g., functools.partial objects), so fall back to its repr
        error_name = getattr(self.error, '__name__', repr(self.error))

        self.param_scores_ = {}
        self.best_score_ = None
        some_timeouts = False
        for values in itertools.product(*params_values):
            params = dict(zip(params_keys, values))

            if use_timeout:
                signal.alarm(self.timeout)

            try:
                # overrides default parameters with the parameters being explored at this iteration
                model.set_params(**params)
                model.fit(training)
                true_prevalences, estim_prevalences = self.__generate_predictions(model, val_split)
                score = self.error(true_prevalences, estim_prevalences)

                self._sout(f'checking hyperparams={params} got {error_name} score {score:.5f}')
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = deepcopy(model)
                self.param_scores_[str(params)] = score

                if use_timeout:
                    signal.alarm(0)
            except TimeoutError:
                print(f'timeout reached for config {params}')
                some_timeouts = True
            finally:
                # makes sure no alarm is left pending when the iteration ends abnormally (e.g., an exception
                # other than TimeoutError), since it would otherwise fire later in unrelated code
                if use_timeout:
                    signal.alarm(0)

        if self.best_score_ is None:
            if some_timeouts:
                raise TimeoutError('all jobs took more than the timeout time to end')
            # itertools.product yields nothing when any of the value lists is empty
            raise ValueError('no hyperparameter configuration was evaluated; '
                             'check that no entry of param_grid has an empty list of values')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f})')

        if self.refit:
            if self.protocol == 'gen':
                # there is no validation collection to add to the training set in this protocol
                # (val_split is a generator function), so refit is ignored as documented
                self._sout('refit is ignored for protocol="gen"')
            else:
                self._sout(f'refitting on the whole development set')
                self.best_model_.fit(training + val_split)

        return self

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    @property
    def classes_(self):
        """
        Classes on which the quantifier has been trained on.

        :return: a ndarray of shape `(n_classes)` with the class identifiers
        """
        return self.best_model().classes_

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')
|