# QuaPy/quapy/model_selection.py

import itertools
import signal
from copy import deepcopy
from enum import Enum
from typing import Union, Callable
from functools import wraps
import numpy as np
from sklearn import clone
import quapy as qp
from quapy import evaluation
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
from quapy.data.base import LabelledCollection
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
from quapy.util import timeout
from time import time


class Status(Enum):
SUCCESS = 1
TIMEOUT = 2
INVALID = 3
ERROR = 4


def check_status(func):
    """Decorator for job functions: enforces the wrapped object's per-configuration `timeout` (if any)
    via SIGALRM, catches errors, and annotates the returned job descriptor with an exit :class:`Status`
    and the hyper-parameters explored."""
    @wraps(func)
    def wrapper(*args, **kwargs):
obj = args[0]
tinit = time()
job_descriptor = dict(args[1])
params = {**job_descriptor.get('cls-params', {}), **job_descriptor.get('q-params', {})}
if obj.timeout > 0:
def handler(signum, frame):
raise TimeoutError()
signal.signal(signal.SIGALRM, handler)
signal.alarm(obj.timeout)
try:
job_descriptor = func(*args, **kwargs)
ttime = time() - tinit
score = job_descriptor.get('score', None)
if score is not None:
obj._sout(f'hyperparams=[{params}]\t got {obj.error.__name__} = {score:.5f} [took {ttime:.4f}s]')
if obj.timeout > 0:
signal.alarm(0)
exit_status = Status.SUCCESS
except TimeoutError:
obj._sout(f'timeout ({obj.timeout}s) reached for config {params}')
exit_status = Status.TIMEOUT
except ValueError as e:
obj._sout(f'the combination of hyperparameters {params} is invalid')
obj._sout(f'\tException: {e}')
exit_status = Status.INVALID
except Exception as e:
obj._sout(f'something went wrong for config {params}; skipping:')
obj._sout(f'\tException: {e}')
exit_status = Status.ERROR
job_descriptor['status'] = exit_status
job_descriptor['params'] = params
return job_descriptor
return wrapper


class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
protocol for quantification.
:param model: the quantifier to optimize
:type model: BaseQuantifier
:param param_grid: a dictionary with keys the parameter names and values the list of values to explore
:param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in :class:`quapy.error.QUANTIFICATION_ERROR`)
:param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination; this requires the protocol to implement the
        :class:`quapy.protocol.OnLabelledCollectionProtocol` interface
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
    :param n_jobs: number of parallel jobs; if None (default), the value is taken from QuaPy's environment
        settings (`qp.environ['N_JOBS']`)
    :param verbose: set to True to get information through the stdout
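
    Example (a minimal usage sketch; the dataset, classifier, and hyper-parameter grid are illustrative only):

    >>> import quapy as qp
    >>> from quapy.method.aggregative import PACC
    >>> from sklearn.linear_model import LogisticRegression
    >>> qp.environ['SAMPLE_SIZE'] = 100
    >>> training, validation = qp.datasets.fetch_reviews('hp', tfidf=True).training.split_stratified()
    >>> quantifier = GridSearchQ(
    ...     model=PACC(LogisticRegression()),
    ...     param_grid={'classifier__C': [0.1, 1.0, 10.0]},
    ...     protocol=qp.protocol.APP(validation),
    ...     error='mae',
    ...     verbose=True
    ... ).fit(training)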
"""

    def __init__(self,
model: BaseQuantifier,
param_grid: dict,
protocol: AbstractProtocol,
error: Union[Callable, str] = qp.error.mae,
refit=True,
timeout=-1,
n_jobs=None,
verbose=False):
self.model = model
self.param_grid = param_grid
self.protocol = protocol
self.refit = refit
self.timeout = timeout
self.n_jobs = qp._get_njobs(n_jobs)
self.verbose = verbose
self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
if self.verbose:
print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

    def __check_error(self, error):
if error in qp.error.QUANTIFICATION_ERROR:
self.error = error
elif isinstance(error, str):
self.error = qp.error.from_name(error)
        elif callable(error):
self.error = error
else:
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    @check_status  # annotates the returned descriptor with 'status' and 'params', which are required downstream
    def _prepare_classifier(self, args):
cls_params = args['cls-params']
training = args['training']
model = deepcopy(self.model)
model.set_params(**cls_params)
predictions = model.classifier_fit_predict(training)
        return {'model': model, 'predictions': predictions, 'cls-params': cls_params}

    def _prepare_aggregation(self, args):
model = args['model']
predictions = args['predictions']
cls_params = args['cls-params']
q_params = args['q-params']
training = args['training']
params = {**cls_params, **q_params}

        def job(model):
tinit = time()
model = deepcopy(model)
# overrides default parameters with the parameters being explored at this iteration
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
            ttime = time() - tinit

            return {
                'model': model,
                'cls-params': cls_params,
                'q-params': q_params,
                'params': params,
                'score': score,
                'ttime': ttime
            }

        out, status = self._error_handler(job, model)
        if status == Status.SUCCESS:
            self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["ttime"]:.4f}s]')
        elif status == Status.INVALID:
            self._sout(f'the combination of hyperparameters {params} is invalid')
        elif status == Status.TIMEOUT:
            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
        if out is None:
            # return a minimal descriptor so that fit() can still report this failed configuration
            out = {'params': params, 'status': status, 'score': None}
        return out

    def _prepare_model(self, args):
        params, training = args
model = deepcopy(self.model)
# overrides default parameters with the parameters being explored at this iteration
model.set_params(**params)
model.fit(training)
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
return {'model': model, 'params': params, 'score': score}

    def _compute_scores_aggregative(self, training):
# break down the set of hyperparameters into two: classifier-specific, quantifier-specific
cls_configs, q_configs = group_params(self.param_grid)
# train all classifiers and get the predictions
partial_setups = qp.util.parallel(
self._prepare_classifier,
            ({'cls-params': params, 'training': training} for params in cls_configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs,
asarray=False,
)
# filter out classifier configurations that yield any error
for setup in partial_setups:
if setup['status'] != Status.SUCCESS:
                self._sout(f'-> classifier hyperparameters {setup["params"]} caused '
f'error {setup["status"]} and will be ignored')
        partial_setups = [setup for setup in partial_setups if setup['status'] == Status.SUCCESS]
if len(partial_setups) == 0:
raise ValueError('No valid configuration found for the classifier.')
# explore the quantifier-specific hyperparameters for each training configuration
        scores = qp.util.parallel(
            self._prepare_aggregation,
            ({'q-params': q_params, 'training': training, **setup}
             for setup, q_params in itertools.product(partial_setups, q_configs)),
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )
return scores

    def _compute_scores_nonaggregative(self, training):
configs = expand_grid(self.param_grid)
# pass a seed to parallel, so it is set in child processes
scores = qp.util.parallel(
self._prepare_model,
((params, training) for params in configs),
seed=qp.environ.get('_R_SEED', None),
n_jobs=self.n_jobs
)
return scores

    def _compute_scores(self, training):
if isinstance(self.model, AggregativeQuantifier):
return self._compute_scores_aggregative(training)
else:
return self._compute_scores_nonaggregative(training)

    def fit(self, training: LabelledCollection):
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
the error metric.
:param training: the training set on which to optimize the hyperparameters
:return: self
"""
if self.refit and not isinstance(self.protocol, OnLabelledCollectionProtocol):
raise RuntimeWarning(
f'"refit" was requested, but the protocol does not implement '
f'the {OnLabelledCollectionProtocol.__name__} interface'
)
tinit = time()
self._sout(f'starting model selection with n_jobs={self.n_jobs}')
results = self._compute_scores(training)
self.param_scores_ = {}
self.best_score_ = None
for job_result in results:
score = job_result.get('score', None)
params = job_result['params']
if score is not None:
if self.best_score_ is None or score < self.best_score_:
self.best_score_ = score
self.best_params_ = params
self.best_model_ = job_result['model']
self.param_scores_[str(params)] = score
else:
self.param_scores_[str(params)] = job_result['status']
        tend = time() - tinit

        if self.best_score_ is None:
            raise TimeoutError('no combination of hyperparameters seems to work')
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
f'[took {tend:.4f}s]')
if self.refit:
if isinstance(self.protocol, OnLabelledCollectionProtocol):
tinit = time()
                self._sout('refitting on the whole development set')
self.best_model_.fit(training + self.protocol.get_labelled_collection())
tend = time() - tinit
self.refit_time_ = tend
else:
                raise RuntimeWarning('the model cannot be refit on the whole dataset')
return self

    def quantify(self, instances):
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
2021-11-09 15:44:57 +01:00
:param instances: sample contanining the instances
2021-11-24 11:20:42 +01:00
:return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
by the model selection process.
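
        Example (an illustrative sketch; `grid` stands for a fitted :class:`GridSearchQ` instance and `X_test`
        for a collection of test instances):

        >>> estim_prevalence = grid.quantify(X_test)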
"""
assert hasattr(self, 'best_model_'), 'quantify called before fit'
return self.best_model().quantify(instances)

    def set_params(self, **parameters):
"""Sets the hyper-parameters to explore.
:param parameters: a dictionary with keys the parameter names and values the list of values to explore
"""
self.param_grid = parameters

    def get_params(self, deep=True):
"""Returns the dictionary of hyper-parameters to explore (`param_grid`)
:param deep: Unused
:return: the dictionary `param_grid`
"""
return self.param_grid

    def best_model(self):
"""
Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
of hyper-parameters that minimized the error function.
:return: a trained quantifier
"""
if hasattr(self, 'best_model_'):
return self.best_model_
raise ValueError('best_model called before fit')

    def _error_handler(self, func, *args, **kwargs):
        """Runs `func` with the given arguments within the configured timeout (if any) and maps the
        outcome to a (output, :class:`Status`) pair."""
try:
with timeout(self.timeout):
output = func(*args, **kwargs)
return output, Status.SUCCESS
except TimeoutError:
return None, Status.TIMEOUT
except ValueError:
return None, Status.INVALID
except Exception:
return None, Status.ERROR


def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
"""
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
but for quantification.
:param quantifier: a quantifier issuing class prevalence values
:param data: a labelled collection
:param nfolds: number of folds for k-fold cross validation generation
:param random_state: random seed for reproducibility
:return: a vector of class prevalence values
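
    Example (an illustrative sketch; the dataset and quantifier choices are arbitrary):

    >>> import quapy as qp
    >>> from quapy.method.aggregative import ACC
    >>> from sklearn.linear_model import LogisticRegression
    >>> data = qp.datasets.fetch_reviews('kindle', tfidf=True).training
    >>> estim_prevalence = cross_val_predict(ACC(LogisticRegression()), data, nfolds=5)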
"""
total_prev = np.zeros(shape=data.n_classes)
for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
quantifier.fit(train)
fold_prev = quantifier.quantify(test.X)
rel_size = 1. * len(test) / len(data)
total_prev += fold_prev*rel_size
return total_prev


def expand_grid(param_grid: dict):
"""
Expands a param_grid dictionary as a list of configurations.
    Example:

    >>> combinations = expand_grid({'A': [1, 10, 100], 'B': [True, False]})
    >>> print(combinations)
    [{'A': 1, 'B': True}, {'A': 1, 'B': False}, {'A': 10, 'B': True}, {'A': 10, 'B': False}, {'A': 100, 'B': True}, {'A': 100, 'B': False}]
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
to explore for that hyper-parameter
:return: a list of configurations, i.e., combinations of hyper-parameter assignments in the grid.
"""
params_keys = list(param_grid.keys())
params_values = list(param_grid.values())
configs = [{k: combs[i] for i, k in enumerate(params_keys)} for combs in itertools.product(*params_values)]
return configs


def group_params(param_grid: dict):
"""
    Partitions a param_grid dictionary as two lists of configurations, one for the classifier-specific
    hyper-parameters, and another for the quantifier-specific hyper-parameters.
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
to explore for that hyper-parameter
:return: two expanded grids of configurations, one for the classifier, another for the quantifier
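
    Example (an illustrative sketch; the hyper-parameter names are arbitrary):

    >>> cls_configs, q_configs = group_params({'classifier__C': [1, 10], 'nbins': [5, 10]})
    >>> print(cls_configs)
    [{'classifier__C': 1}, {'classifier__C': 10}]
    >>> print(q_configs)
    [{'nbins': 5}, {'nbins': 10}]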
"""
classifier_params, quantifier_params = {}, {}
for key, values in param_grid.items():
if key.startswith('classifier__') or key == 'val_split':
classifier_params[key] = values
else:
quantifier_params[key] = values
classifier_configs = expand_grid(classifier_params)
quantifier_configs = expand_grid(quantifier_params)
return classifier_configs, quantifier_configs