# QuaPy/quapy/model_selection.py
#
# 231 lines
# 11 KiB
# Python

import itertools
import signal
from copy import deepcopy
from typing import Union, Callable
import quapy as qp
from quapy.data.base import LabelledCollection
from quapy.evaluation import artificial_prevalence_prediction, natural_prevalence_prediction, gen_prevalence_prediction
from quapy.method.aggregative import BaseQuantifier
import inspect
class GridSearchQ(BaseQuantifier):
    """
    Grid-search optimization of the hyperparameters of a quantifier.

    Every combination of values in `param_grid` is set on the wrapped model, which is then fit on the training
    data and scored (with `error`) on samples produced by the chosen evaluation protocol. The combination with
    the lowest score is retained in `best_params_`, `best_score_` and `best_model_`.
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 sample_size: int,
                 protocol='app',
                 n_prevpoints: int = None,
                 n_repetitions: int = 1,
                 eval_budget: int = None,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 val_split=0.4,
                 n_jobs=1,
                 random_seed=42,
                 timeout=-1,
                 verbose=False):
        """
        Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
        protocol for quantification.

        :param model: the quantifier to optimize
        :param param_grid: a dictionary with keys the parameter names and values the list of values to explore for
            that particular parameter
        :param sample_size: the size of the samples to extract from the validation set
        :param protocol: either 'app' for the artificial prevalence protocol, or 'npp' for the natural prevalence
            protocol, or 'gen' when val_split is a function returning a generator of (sample, prevalence) pairs
        :param n_prevpoints: if specified, indicates the number of equally distant point to extract from the interval
            [0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
            each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is
            requested. Ignored if protocol='npp'.
        :param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
            if eval_budget is set and is lower than the number of combinations that would be generated using the value
            assigned to n_prevpoints (for the current number of classes and n_repetitions)
        :param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
            combination. For example, if there are 3 classes, n_repetitions=1 and eval_budget=20, then n_prevpoints
            will be set to 5, since this will generate 15 different prevalences:
            [0, 0, 1], [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0]
            Ignored if protocol='npp'.
        :param error: an error function (callable) or a string indicating the name of an error function (valid ones
            are those in qp.error.QUANTIFICATION_ERROR)
        :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
            the best chosen hyperparameter combination
        :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
            a float in [0,1] indicating the proportion of labelled data to extract from the training set
        :param n_jobs: number of parallel jobs
        :param random_seed: set the seed of the random generator to replicate experiments
        :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
            Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations
            end up being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
        :param verbose: set to True to get information through the stdout
        """
        self.model = model
        self.param_grid = param_grid
        self.sample_size = sample_size
        self.protocol = protocol.lower()
        self.n_prevpoints = n_prevpoints
        self.n_repetitions = n_repetitions
        self.eval_budget = eval_budget
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose
        self.__check_error(error)
        assert self.protocol in {'app', 'npp', 'gen'}, \
            'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
            'protocols. Use protocol="gen" when passing a generator function thorough val_split that yields a ' \
            'sample (instances) and their prevalence (ndarray) at each iteration.'
        if self.protocol == 'npp':
            # npp is driven by n_repetitions; fall back to eval_budget when only that one was given
            if self.n_repetitions is None or self.n_repetitions == 1:
                if self.eval_budget is not None:
                    print(f'[warning] when protocol=="npp" the parameter n_repetitions should be indicated '
                          f'(and not eval_budget). Setting n_repetitions={self.eval_budget}...')
                    self.n_repetitions = self.eval_budget
                else:
                    raise ValueError(f'when protocol=="npp" the parameter n_repetitions should be indicated '
                                     f'(and should be >1).')
            if self.n_prevpoints is not None:
                print('[warning] n_prevpoints has been set along with the npp protocol, and will be ignored')

    def sout(self, msg):
        """
        Prints `msg` to stdout, prefixed with the class name, only when `verbose` is set.

        :param msg: the message to print
        """
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __check_training_validation(self, training, validation):
        """
        Resolves the (training, validation) pair to be used during the search.

        :param training: the training LabelledCollection
        :param validation: a LabelledCollection (returned as is), a float in (0,1) (a stratified split of `training`
            is performed), or, when protocol="gen", a callable returning a generator
        :return: a tuple (training, validation)
        :raises ValueError: if `validation` is none of the accepted types
        """
        if isinstance(validation, LabelledCollection):
            return training, validation
        elif isinstance(validation, float):
            assert 0. < validation < 1., 'validation proportion should be in (0,1)'
            training, validation = training.split_stratified(train_prop=1 - validation)
            return training, validation
        # callable() guard: a non-callable (e.g., an int) must end up in the ValueError below, not in a TypeError
        elif self.protocol == 'gen' and callable(validation) and inspect.isgenerator(validation()):
            return training, validation
        else:
            raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the '
                             f'proportion of training documents to extract (type found: {type(validation)}). '
                             f'Optionally, "validation" can be a callable function returning a generator that yields '
                             f'the sample instances along with their true prevalence at each iteration by '
                             f'setting protocol="gen".')

    def __check_error(self, error):
        """
        Validates `error` and stores the corresponding callable in `self.error`.

        :param error: a member of qp.error.QUANTIFICATION_ERROR, the name (str) of an error function, or a callable
        :raises ValueError: if `error` is neither a str nor a callable
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif hasattr(error, '__call__'):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def __generate_predictions(self, model, val_split):
        """
        Runs the configured evaluation protocol for `model` on `val_split`.

        :param model: the (already fit) quantifier to evaluate
        :param val_split: the validation data (or the generator function when protocol="gen")
        :return: whatever the underlying prediction function returns; `fit` unpacks it as
            (true_prevalences, estim_prevalences)
        :raises ValueError: if the protocol is unknown
        """
        commons = {
            'n_repetitions': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'random_seed': self.random_seed,
            'verbose': False
        }
        if self.protocol == 'app':
            return artificial_prevalence_prediction(
                model, val_split, self.sample_size,
                n_prevpoints=self.n_prevpoints,
                eval_budget=self.eval_budget,
                **commons
            )
        elif self.protocol == 'npp':
            return natural_prevalence_prediction(
                model, val_split, self.sample_size,
                **commons)
        elif self.protocol == 'gen':
            return gen_prevalence_prediction(model, gen_fn=val_split, eval_budget=self.eval_budget)
        else:
            raise ValueError('unknown protocol')

    def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float, Callable] = None):
        """
        Explores every combination of hyperparameters in `param_grid` and keeps the best one.

        :param training: the training set on which to optimize the hyperparameters
        :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
            a float in [0,1] indicating the proportion of labelled data to extract from the training set, or (with
            protocol="gen") a callable returning a generator. Defaults to the value given at construction time.
        :return: self
        :raises TimeoutError: if every configuration exceeded the timeout
        :raises ValueError: if no configuration could be evaluated (e.g., a parameter with an empty list of values)
        """
        if val_split is None:
            val_split = self.val_split
        training, val_split = self.__check_training_validation(training, val_split)
        if self.protocol != 'gen':
            assert isinstance(self.sample_size, int) and self.sample_size > 0, 'sample_size must be a positive integer'

        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        model = self.model
        n_jobs = self.n_jobs
        # not every callable exposes __name__ (e.g., functools.partial)
        error_name = getattr(self.error, '__name__', repr(self.error))

        use_timeout = self.timeout > 0
        previous_handler = None
        if use_timeout:
            def handler(signum, frame):
                self.sout('timeout reached')
                raise TimeoutError()

            previous_handler = signal.signal(signal.SIGALRM, handler)

        self.sout(f'starting optimization with n_jobs={n_jobs}')
        self.param_scores_ = {}
        self.best_score_ = None
        some_timeouts = False
        try:
            for values in itertools.product(*params_values):
                params = dict(zip(params_keys, values))

                if use_timeout:
                    signal.alarm(self.timeout)

                try:
                    # overrides default parameters with the parameters being explored at this iteration
                    model.set_params(**params)
                    model.fit(training)
                    true_prevalences, estim_prevalences = self.__generate_predictions(model, val_split)
                    score = self.error(true_prevalences, estim_prevalences)

                    self.sout(f'checking hyperparams={params} got {error_name} score {score:.5f}')
                    if self.best_score_ is None or score < self.best_score_:
                        self.best_score_ = score
                        self.best_params_ = params
                        self.best_model_ = deepcopy(model)
                    self.param_scores_[str(params)] = score

                    # cancelled inside the try so that an alarm firing right before this line is still caught
                    if use_timeout:
                        signal.alarm(0)
                except TimeoutError:
                    print(f'timeout reached for config {params}')
                    some_timeouts = True
                finally:
                    # also cancel the alarm when any other exception escapes, so it cannot fire later on
                    if use_timeout:
                        signal.alarm(0)
        finally:
            # do not leave our handler installed after the search; None means the previous handler was not
            # installed from Python and cannot be passed back to signal.signal
            if use_timeout and previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

        if self.best_score_ is None:
            if some_timeouts:
                raise TimeoutError('all jobs took more than the timeout time to end')
            raise ValueError('no hyperparameter configuration could be evaluated; '
                             'check that param_grid does not contain empty lists of values')

        self.sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f})')

        if self.refit:
            if isinstance(val_split, LabelledCollection):
                self.sout(f'refitting on the whole development set')
                self.best_model_.fit(training + val_split)
            else:
                # protocol="gen": val_split is a generator function, so there is no validation collection to add;
                # best_model_ is already fit on the training set
                self.sout('refit skipped: no validation collection available; keeping the model fit on training')

        return self

    def quantify(self, instances):
        """
        Delegates the quantification of `instances` to the best model found by `fit`.

        :param instances: the instances to quantify
        :return: the output of the best model's `quantify`
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    @property
    def classes_(self):
        """The `classes_` attribute of the best model found by `fit`."""
        return self.best_model().classes_

    def set_params(self, **parameters):
        """
        Replaces the grid of hyperparameters to explore.

        :param parameters: the new grid, stored as is in `param_grid`
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """
        Returns the grid of hyperparameters to explore.

        :param deep: unused
        :return: the `param_grid` dictionary
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found by `fit`.

        :return: the quantifier stored in `best_model_`
        :raises ValueError: if called before `fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')