import inspect
import itertools
import signal
from copy import deepcopy
from typing import Union, Callable

import numpy as np

import quapy as qp
from quapy.data.base import LabelledCollection
from quapy.evaluation import artificial_prevalence_prediction, natural_prevalence_prediction, gen_prevalence_prediction
from quapy.method.aggregative import BaseQuantifier


class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param sample_size: the size of the samples to extract from the validation set (ignored if protocol='gen')
    :param protocol: either 'app' for the artificial prevalence protocol, 'npp' for the natural prevalence
        protocol, or 'gen' for using a custom sampling generator function
    :param n_prevpoints: if specified, indicates the number of equally distant points to extract from the interval
        [0,1] in order to define the prevalences of the samples; e.g., if n_prevpoints=5, then the prevalences for
        each class will be explored in [0.00, 0.25, 0.50, 0.75, 1.00]. If not specified, then eval_budget is requested.
        Ignored if protocol!='app'.
    :param n_repetitions: the number of repetitions for each combination of prevalences. This parameter is ignored
        for the protocol='app' if eval_budget is set and is lower than the number of combinations that would be
        generated using the value assigned to n_prevpoints (for the current number of classes and n_repetitions).
        Ignored for protocol='npp' and protocol='gen' (use eval_budget for setting a maximum number of samples in
        those cases).
    :param eval_budget: if specified, sets a ceil on the number of evaluations to perform for each hyper-parameter
        combination. For example, if protocol='app', there are 3 classes, n_repetitions=1 and eval_budget=20, then
        n_prevpoints will be set to 5, since this will generate 15 different prevalences, i.e., [0, 0, 1],
        [0, 0.25, 0.75], [0, 0.5, 0.5] ... [1, 0, 0], and since setting it to 6 would generate more than
        20. When protocol='gen', indicates the maximum number of samples to generate, but less samples will be
        generated if the generator yields less samples.
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in qp.error.QUANTIFICATION_ERROR)
    :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination. Ignored if protocol='gen'
    :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
        a float in [0,1] indicating the proportion of labelled data to extract from the training set, or a callable
        returning a generator function each time it is invoked (only for protocol='gen').
    :param n_jobs: number of parallel jobs
    :param random_seed: set the seed of the random generator to replicate experiments. Ignored if protocol='gen'.
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
        being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 sample_size: Union[int, None],
                 protocol='app',
                 n_prevpoints: int = None,
                 n_repetitions: int = 1,
                 eval_budget: int = None,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 val_split=0.4,
                 n_jobs=1,
                 random_seed=42,
                 timeout=-1,
                 verbose=False):

        self.model = model
        self.param_grid = param_grid
        self.sample_size = sample_size
        self.protocol = protocol.lower()
        self.n_prevpoints = n_prevpoints
        self.n_repetitions = n_repetitions
        self.eval_budget = eval_budget
        self.refit = refit
        self.val_split = val_split
        self.n_jobs = n_jobs
        self.random_seed = random_seed
        self.timeout = timeout
        self.verbose = verbose
        # resolves `error` into a callable stored in self.error (raises ValueError if it cannot)
        self.__check_error(error)
        assert self.protocol in {'app', 'npp', 'gen'}, \
            'unknown protocol: valid ones are "app" or "npp" for the "artificial" or the "natural" prevalence ' \
            'protocols. Use protocol="gen" when passing a generator function thorough val_split that yields a ' \
            'sample (instances) and their prevalence (ndarray) at each iteration.'
        assert self.eval_budget is None or isinstance(self.eval_budget, int)
        if self.protocol in ('npp', 'gen'):
            if self.protocol == 'npp' and (self.eval_budget is None or self.eval_budget <= 0):
                raise ValueError('when protocol="npp" the parameter eval_budget should be '
                                 'indicated (and should be >0).')
            if self.n_repetitions != 1:
                print('[warning] n_repetitions has been set and will be ignored for the selected protocol')

    def _sout(self, msg):
        """Prints `msg` prefixed with the class name, but only when `verbose` is True."""
        if self.verbose:
            print(f'[{self.__class__.__name__}]: {msg}')

    def __check_training_validation(self, training, validation):
        """Normalizes the (training, validation) pair according to the type of `validation`.

        :param training: the training collection
        :param validation: a LabelledCollection (returned untouched), a float in (0,1) (a stratified split of
            `training` is performed, holding out that proportion), or, when protocol='gen', a callable that returns
            a generator when invoked
        :return: a tuple (training, validation)
        :raises ValueError: if `validation` is none of the accepted types
        """
        if isinstance(validation, LabelledCollection):
            return training, validation
        if isinstance(validation, float):
            assert 0. < validation < 1., 'validation proportion should be in (0,1)'
            training, validation = training.split_stratified(train_prop=1 - validation)
            return training, validation
        # `callable` is checked first so that a non-callable value reaches the ValueError below instead of
        # failing with a TypeError when trying to invoke it
        if self.protocol == 'gen' and callable(validation) and inspect.isgenerator(validation()):
            return training, validation
        raise ValueError(f'"validation" must either be a LabelledCollection or a float in (0,1) indicating the '
                         f'proportion of training documents to extract (type found: {type(validation)}). '
                         f'Optionally, "validation" can be a callable function returning a generator that yields '
                         f'the sample instances along with their true prevalence at each iteration by '
                         f'setting protocol="gen".')

    def __check_error(self, error):
        """Stores in `self.error` the error function denoted by `error`.

        :param error: a member of qp.error.QUANTIFICATION_ERROR, the name (str) of an error function, or any callable
        :raises ValueError: if `error` is none of the above
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif hasattr(error, '__call__'):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def __generate_predictions(self, model, val_split):
        """Produces the (true prevalences, estimated prevalences) of `model` following the selected protocol.

        :param model: a fitted quantifier
        :param val_split: the validation data (or the generator function if protocol='gen')
        :return: whatever the protocol-specific prediction function returns; `fit` unpacks it as a pair
            (true_prevalences, estim_prevalences)
        :raises ValueError: if the protocol is unknown
        """
        commons = {
            'n_repetitions': self.n_repetitions,
            'n_jobs': self.n_jobs,
            'random_seed': self.random_seed,
            'verbose': False
        }
        if self.protocol == 'app':
            return artificial_prevalence_prediction(
                model, val_split, self.sample_size,
                n_prevpoints=self.n_prevpoints,
                eval_budget=self.eval_budget,
                **commons
            )
        elif self.protocol == 'npp':
            # NOTE(review): eval_budget is validated for 'npp' in __init__ but is not forwarded here;
            # confirm against the signature of natural_prevalence_prediction
            return natural_prevalence_prediction(
                model, val_split, self.sample_size,
                **commons)
        elif self.protocol == 'gen':
            return gen_prevalence_prediction(model, gen_fn=val_split, eval_budget=self.eval_budget)
        else:
            raise ValueError('unknown protocol')

    def fit(self, training: LabelledCollection, val_split: Union[LabelledCollection, float, Callable] = None):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
            the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :param val_split: either a LabelledCollection on which to test the performance of the different settings, or
            a float in [0,1] indicating the proportion of labelled data to extract from the training set, or (only
            for protocol='gen') a callable returning a generator each time it is invoked. If None (default), the
            value passed to the constructor is used.
        :return: self
        :raises TimeoutError: if every configuration exceeded the timeout
        :raises ValueError: if the grid does not define any configuration to evaluate
        """
        if val_split is None:
            val_split = self.val_split
        training, val_split = self.__check_training_validation(training, val_split)
        if self.protocol != 'gen':
            assert isinstance(self.sample_size, int) and self.sample_size > 0, 'sample_size must be a positive integer'

        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        model = self.model

        use_timeout = self.timeout > 0
        if use_timeout:
            def handler(signum, frame):
                self._sout('timeout reached')
                raise TimeoutError()

            signal.signal(signal.SIGALRM, handler)

        self.param_scores_ = {}
        self.best_score_ = None
        some_timeouts = False
        for values in itertools.product(*params_values):
            params = dict(zip(params_keys, values))

            if use_timeout:
                signal.alarm(self.timeout)

            try:
                # overrides default parameters with the parameters being explored at this iteration
                model.set_params(**params)
                model.fit(training)
                true_prevalences, estim_prevalences = self.__generate_predictions(model, val_split)
                score = self.error(true_prevalences, estim_prevalences)

                self._sout(f'checking hyperparams={params} got {self.error.__name__} score {score:.5f}')
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = deepcopy(model)
                self.param_scores_[str(params)] = score
            except TimeoutError:
                print(f'timeout reached for config {params}')
                some_timeouts = True
            finally:
                # always disarm the alarm: if any other exception escapes, a still-pending alarm would
                # otherwise fire later and raise TimeoutError in unrelated code
                if use_timeout:
                    signal.alarm(0)

        if self.best_score_ is None:
            if some_timeouts:
                raise TimeoutError('all jobs took more than the timeout time to end')
            # nothing was evaluated at all (e.g., some parameter has an empty list of values)
            raise ValueError('param_grid does not define any hyper-parameter combination to explore')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f})')

        # under protocol='gen' val_split is a generator function, not a collection that could be added to
        # the training set, so refit is ignored in that case (as documented for the `refit` parameter)
        if self.refit and self.protocol != 'gen':
            self._sout(f'refitting on the whole development set')
            self.best_model_.fit(training + val_split)

        return self

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    @property
    def classes_(self):
        """
        Classes on which the quantifier has been trained on.

        :return: a ndarray of shape `(n_classes)` with the class identifiers
        """
        return self.best_model().classes_

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')