QuaPy/quapy/model_selection.py

399 lines
15 KiB
Python
Raw Normal View History

import itertools
2021-01-15 18:32:32 +01:00
import signal
from copy import deepcopy
from enum import Enum
2023-11-06 02:00:06 +01:00
from typing import Union, Callable
2023-11-16 19:56:30 +01:00
from functools import wraps
import numpy as np
from sklearn import clone
import quapy as qp
2022-06-15 16:54:42 +02:00
from quapy import evaluation
2023-11-06 02:00:06 +01:00
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
2021-03-19 17:34:09 +01:00
from quapy.data.base import LabelledCollection
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
2023-11-20 22:05:26 +01:00
from quapy.util import timeout
2023-11-06 02:00:06 +01:00
from time import time
class Status(Enum):
    """Possible outcomes of evaluating one hyperparameter configuration."""

    # the configuration was evaluated without raising
    SUCCESS = 1
    # the evaluation raised a TimeoutError
    TIMEOUT = 2
    # the evaluation raised a ValueError
    INVALID = 3
    # the evaluation raised any other exception
    ERROR = 4
2023-11-16 19:56:30 +01:00
2023-11-21 18:59:36 +01:00
class ConfigStatus:
    """Record describing how the evaluation of one configuration ended.

    :param params: the hyperparameter configuration this record refers to
    :param status: the outcome; :meth:`success` compares it against `Status.SUCCESS`
    :param msg: free text appended to the string representation (empty by default)
    """

    def __init__(self, params, status, msg=''):
        self.params, self.status, self.msg = params, status, msg

    def __str__(self):
        # the prefix always ends with a blank, even when `msg` is empty
        prefix = f':params:{self.params} :status:{self.status} '
        return prefix + self.msg

    def __repr__(self):
        # the debugging representation is the same text as the printable one
        return self.__str__()

    def success(self):
        """Return True if the stored status equals `Status.SUCCESS`."""
        return self.status == Status.SUCCESS

    def failed(self):
        """Return True if the stored status is anything other than `Status.SUCCESS`."""
        return not self.success()
2023-11-16 19:56:30 +01:00
class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in :class:`quapy.error.QUANTIFICATION_ERROR`)
    :param refit: whether to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination. Requires a protocol implementing
        :class:`quapy.protocol.OnLabelledCollectionProtocol`; otherwise :meth:`fit` raises a RuntimeWarning.
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
        being ignored, a ValueError exception is raised. If -1 (default) then no time bound is set.
    :param n_jobs: number of parallel workers requested; the value is resolved through `qp._get_njobs` and
        handed to `qp.util.parallel`
    :param raise_errors: boolean, if True then raises an exception when a param combination yields any error, if
        otherwise is False (default), then the combination is marked with an error status, but the process goes on.
        However, if no configuration yields a valid model, then a ValueError exception will be raised.
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 protocol: AbstractProtocol,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 timeout=-1,
                 n_jobs=None,
                 raise_errors=False,
                 verbose=False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = qp._get_njobs(n_jobs)
        self.raise_errors = raise_errors
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
        """Print `msg` prefixed with this class name and the model class name, only if verbose."""
        if self.verbose:
            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

    def __check_error(self, error):
        """Validate `error` and store the resolved error function in `self.error`.

        :param error: a member of `qp.error.QUANTIFICATION_ERROR`, the name of an error function, or a callable
        :raises ValueError: if `error` is none of the above
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif callable(error):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def _prepare_classifier(self, cls_params):
        """Fit the classifier of a copy of the model for one classifier-specific configuration.

        :param cls_params: dictionary of classifier-specific hyperparameters
        :return: `tuple(model, predictions, status, took)`; `predictions` is None if the job failed
        """
        model = deepcopy(self.model)

        def job(cls_params):
            model.set_params(**cls_params)
            predictions = model.classifier_fit_predict(self._training)
            return predictions

        predictions, status, took = self._error_handler(job, cls_params)
        self._sout(f'[classifier fit] hyperparams={cls_params} [took {took:.3f}s]')
        return model, predictions, status, took

    def _prepare_aggregation(self, args):
        """Fit the aggregation function for one quantifier-specific configuration and evaluate it.

        :param args: `tuple(model, predictions, cls_took, cls_params, q_params)`, i.e., a successful output of
            :meth:`_prepare_classifier` (reordered) extended with one quantifier-specific configuration
        :return: `tuple(model, params, score, status, took)` where `params` merges both configurations and
            `took` adds the classifier time to the aggregation time; `score` is None if the job failed
        """
        model, predictions, cls_took, cls_params, q_params = args
        # work on a copy: the same classifier output is paired with several quantifier configurations
        model = deepcopy(model)
        params = {**cls_params, **q_params}

        def job(q_params):
            model.set_params(**q_params)
            model.aggregation_fit(predictions, self._training)
            score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
            return score

        score, status, aggr_took = self._error_handler(job, q_params)
        self._print_status(params, score, status, aggr_took)
        return model, params, score, status, (cls_took + aggr_took)

    def _prepare_nonaggr_model(self, params):
        """Fit a copy of the (non-aggregative) model for one configuration and evaluate it.

        :param params: dictionary of hyperparameters
        :return: `tuple(model, params, score, status, took)`; `score` is None if the job failed
        """
        model = deepcopy(self.model)

        def job(params):
            model.set_params(**params)
            model.fit(self._training)
            score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
            return score

        score, status, took = self._error_handler(job, params)
        self._print_status(params, score, status, took)
        return model, params, score, status, took

    def _compute_scores_aggregative(self, training):
        """Explore the grid in two stages: classifier configurations first, aggregation configurations next.

        :param training: the training set
        :return: the outputs of :meth:`_prepare_aggregation`, one per explored combination
        :raises ValueError: if no classifier configuration could be fit
        """
        # break down the set of hyperparameters into two: classifier-specific, quantifier-specific
        cls_configs, q_configs = group_params(self.param_grid)

        # train all classifiers and get the predictions
        self._training = training
        cls_outs = qp.util.parallel(
            self._prepare_classifier,
            cls_configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )

        # filter out classifier configurations that yielded any error
        success_outs = []
        for (model, predictions, status, took), cls_config in zip(cls_outs, cls_configs):
            if status.success():
                success_outs.append((model, predictions, took, cls_config))
            else:
                self.error_collector.append(status)

        if not success_outs:
            raise ValueError('No valid configuration found for the classifier!')

        # explore the quantifier-specific hyperparameters for each valid training configuration
        aggr_configs = [(*out, q_config) for out, q_config in itertools.product(success_outs, q_configs)]
        aggr_outs = qp.util.parallel(
            self._prepare_aggregation,
            aggr_configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )

        return aggr_outs

    def _compute_scores_nonaggregative(self, training):
        """Explore every configuration of the expanded grid by fitting the whole model each time.

        :param training: the training set
        :return: the outputs of :meth:`_prepare_nonaggr_model`, one per configuration
        """
        configs = expand_grid(self.param_grid)

        self._training = training
        scores = qp.util.parallel(
            self._prepare_nonaggr_model,
            configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )
        return scores

    def _print_status(self, params, score, status, took):
        """Report (if verbose) the score obtained by a configuration, or the error status otherwise."""
        if status.success():
            # callables such as functools.partial objects carry no __name__; fall back to the type name
            error_name = getattr(self.error, '__name__', type(self.error).__name__)
            self._sout(f'hyperparams=[{params}]\t got {error_name} = {score:.5f} [took {took:.3f}s]')
        else:
            self._sout(f'error={status}')

    def fit(self, training: LabelledCollection):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
        the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        :raises RuntimeWarning: if `refit` was requested but the protocol does not implement
            `OnLabelledCollectionProtocol`
        :raises ValueError: if no configuration could be evaluated successfully
        """
        if self.refit and not isinstance(self.protocol, OnLabelledCollectionProtocol):
            raise RuntimeWarning(
                f'"refit" was requested, but the protocol does not implement '
                f'the {OnLabelledCollectionProtocol.__name__} interface'
            )

        tinit = time()

        self.error_collector = []

        self._sout(f'starting model selection with n_jobs={self.n_jobs}')
        if isinstance(self.model, AggregativeQuantifier):
            results = self._compute_scores_aggregative(training)
        else:
            results = self._compute_scores_nonaggregative(training)

        self.param_scores_ = {}
        self.best_score_ = None
        for model, params, score, status, took in results:
            if status.success():
                # lower is better: the error metric is being minimized
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = model
                self.param_scores_[str(params)] = score
            else:
                self.param_scores_[str(params)] = status.status
                self.error_collector.append(status)

        tend = time() - tinit

        if self.best_score_ is None:
            raise ValueError('no combination of hyperparameters seemed to work')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
                   f'[took {tend:.4f}s]')

        n_errors = len(self.error_collector)
        if n_errors > 0:
            self._sout(f'warning: {n_errors} errors found')
            for err in self.error_collector:
                self._sout(f'\t{str(err)}')

        if self.refit:
            # the guard at the top of this method guarantees the protocol exposes its labelled collection
            tinit = time()
            self._sout(f'refitting on the whole development set')
            self.best_model_.fit(training + self.protocol.get_labelled_collection())
            self.refit_time_ = time() - tinit

        return self

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')

    def _error_handler(self, func, params):
        """
        Endorses one job with two returned values: the status, and the time of execution

        :param func: the function to be called
        :param params: parameters of the function
        :return: `tuple(out, status, time)` where `out` is the function output (None if the call failed),
            `status` is a :class:`ConfigStatus` wrapping an enum value from `Status`, and `time` is the
            time it took to complete the call
        :raises Exception: re-raises whatever `func` raised when `self.raise_errors` is True
        """
        output = None

        def _handle(status, exception):
            # `exception` is the caught exception object itself, so that it can be
            # re-raised as is (raising its string form would yield a TypeError)
            if self.raise_errors:
                raise exception
            return ConfigStatus(params, status, str(exception))

        # start the clock before entering the context manager, so that `took` is
        # always defined even if the failure happens while setting up the timeout
        tinit = time()
        try:
            with timeout(self.timeout):
                output = func(params)
            status = ConfigStatus(params, Status.SUCCESS)

        except TimeoutError as e:
            status = _handle(Status.TIMEOUT, e)

        except ValueError as e:
            status = _handle(Status.INVALID, e)

        except Exception as e:
            status = _handle(Status.ERROR, e)

        took = time() - tinit
        return output, status, took
2023-11-20 22:05:26 +01:00
2023-11-06 01:58:36 +01:00
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
    """
    Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
    but for quantification.

    The quantifier is fit on the training part of every fold (in place) and asked to quantify the
    corresponding test part; the per-fold estimates are averaged, each one weighted by the fraction of
    `data` covered by its test part.

    :param quantifier: a quantifier issuing class prevalence values
    :param data: a labelled collection
    :param nfolds: number of folds for k-fold cross validation generation
    :param random_state: random seed for reproducibility
    :return: a vector of class prevalence values
    """
    weighted_prev = np.zeros(shape=data.n_classes)
    folds = data.kFCV(nfolds=nfolds, random_state=random_state)
    for train_part, test_part in folds:
        quantifier.fit(train_part)
        # weight of this fold: share of the whole collection held by its test part
        fold_weight = len(test_part) / len(data)
        weighted_prev += quantifier.quantify(test_part.X) * fold_weight
    return weighted_prev
2023-11-06 01:58:36 +01:00
2023-11-15 10:55:13 +01:00
def expand_grid(param_grid: dict):
    """
    Expands a param_grid dictionary as a list of configurations.

    Example:

    >>> combinations = expand_grid({'A': [1, 10, 100], 'B': [True, False]})
    >>> print(combinations)
    [{'A': 1, 'B': True}, {'A': 1, 'B': False}, {'A': 10, 'B': True}, {'A': 10, 'B': False}, {'A': 100, 'B': True}, {'A': 100, 'B': False}]

    :param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
        to explore for that hyper-parameter
    :return: a list of configurations, i.e., combinations of hyper-parameter assignments in the grid.
    """
    names = list(param_grid)
    # the cartesian product yields one tuple of values per configuration, aligned with `names`
    value_combinations = itertools.product(*param_grid.values())
    return [dict(zip(names, values)) for values in value_combinations]
def group_params(param_grid: dict):
    """
    Partitions a param_grid dictionary as two lists of configurations, one for the classifier-specific
    hyper-parameters, and another for the quantifier-specific hyper-parameters.

    A hyper-parameter is considered classifier-specific if its name starts with `classifier__` or if it is
    `val_split`; any other hyper-parameter is considered quantifier-specific.

    :param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
        to explore for that hyper-parameter
    :return: two expanded grids of configurations, one for the classifier, another for the quantifier
    """
    def _is_classifier_specific(name):
        return name.startswith('classifier__') or name == 'val_split'

    classifier_grid = {name: values for name, values in param_grid.items() if _is_classifier_specific(name)}
    quantifier_grid = {name: values for name, values in param_grid.items() if not _is_classifier_specific(name)}
    return expand_grid(classifier_grid), expand_grid(quantifier_grid)