1
0
Fork 0
QuaPy/quapy/model_selection.py

232 lines
8.7 KiB
Python
Raw Normal View History

import itertools
2021-01-15 18:32:32 +01:00
import signal
from copy import deepcopy
from typing import Union, Callable
import numpy as np
from sklearn import clone
import quapy as qp
2022-06-15 16:54:42 +02:00
from quapy import evaluation
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
2021-03-19 17:34:09 +01:00
from quapy.data.base import LabelledCollection
2021-01-15 18:32:32 +01:00
from quapy.method.aggregative import BaseQuantifier
2022-05-25 19:14:33 +02:00
from time import time
class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in :class:`quapy.error.QUANTIFICATION_ERROR`)
    :param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination. Refitting requires `protocol` to implement
        :class:`quapy.protocol.OnLabelledCollectionProtocol`; otherwise :meth:`fit` raises a `RuntimeWarning`
        once the search has finished
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
        being ignored, a TimeoutError exception is raised. If -1 (default), or any other non-positive value, then
        no time bound is set.
    :param n_jobs: number of parallel workers used to explore the grid; the value is resolved through
        `qp._get_njobs`
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 protocol: AbstractProtocol,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 timeout=-1,
                 n_jobs=None,
                 verbose=False):
        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = qp._get_njobs(n_jobs)
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
        """Prints `msg` to stdout, prefixed with the selector and model class names, when `verbose` is set."""
        if self.verbose:
            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

    def __check_error(self, error):
        """Resolves `error` into a callable and stores it in `self.error`.

        :param error: a member of `qp.error.QUANTIFICATION_ERROR`, the name of an error function, or any callable
        :raises ValueError: if `error` is neither a string nor a callable
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif callable(error):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def fit(self, training: LabelledCollection):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
        the error metric.

        After the call, `best_score_`, `best_params_` and `best_model_` hold the winning configuration, and
        `param_scores_` maps the string form of every explored configuration to its score (or to the string
        'timeout' when no score could be obtained for it).

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        :raises TimeoutError: if no configuration produced a score
        :raises RuntimeWarning: if `refit` is set but the protocol does not implement
            `OnLabelledCollectionProtocol`
        """
        params_keys = list(self.param_grid.keys())
        params_values = list(self.param_grid.values())

        protocol = self.protocol

        self.param_scores_ = {}
        self.best_score_ = None

        tinit = time()

        # one dict per point of the cartesian product of all the value lists
        hyper = [dict(zip(params_keys, val)) for val in itertools.product(*params_values)]

        self._sout(f'starting model selection with {self.n_jobs =}')
        # pass a seed to parallel so it is set in child processes
        scores = qp.util.parallel(
            self._delayed_eval,
            ((params, training) for params in hyper),
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )

        for params, score, model in scores:
            if score is not None:
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = model
                self.param_scores_[str(params)] = score
            else:
                # a None score marks a configuration that timed out or failed
                self.param_scores_[str(params)] = 'timeout'

        tend = time()-tinit

        if self.best_score_ is None:
            raise TimeoutError('no combination of hyperparameters seem to work')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
                   f'[took {tend:.4f}s]')

        if self.refit:
            if isinstance(protocol, OnLabelledCollectionProtocol):
                self._sout(f'refitting on the whole development set')
                self.best_model_.fit(training + protocol.get_labelled_collection())
            else:
                raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
                                     f'implement the {OnLabelledCollectionProtocol.__name__} interface')

        return self

    def _delayed_eval(self, args):
        """Fits and evaluates one hyperparameter configuration (the unit of work run in parallel by :meth:`fit`).

        :param args: a tuple `(params, training)` with the configuration to test and the training collection
        :return: a tuple `(params, score, model)`; `score` is None if the configuration timed out or failed
        :raises ValueError: re-raised as is, since it signals an invalid combination of hyperparameters
        """
        params, training = args

        protocol = self.protocol
        error = self.error
        use_timeout = self.timeout > 0

        if use_timeout:
            def handler(signum, frame):
                raise TimeoutError()

            signal.signal(signal.SIGALRM, handler)

        # pre-bound so that the return statement is valid even if the very first step (deepcopy) fails
        model = None
        score = None

        tinit = time()

        if use_timeout:
            signal.alarm(self.timeout)

        try:
            model = deepcopy(self.model)
            # overrides default parameters with the parameters being explored at this iteration
            model.set_params(**params)
            model.fit(training)
            score = evaluation.evaluate(model, protocol=protocol, error_metric=error)

            ttime = time()-tinit
            # not every callable has a __name__ (e.g., functools.partial); a missing name must not
            # turn a successfully computed score into a failure
            error_name = getattr(error, '__name__', repr(error))
            self._sout(f'hyperparams={params}\t got {error_name} score {score:.5f} [took {ttime:.4f}s]')
        except TimeoutError:
            self._sout(f'timeout ({self.timeout}s) reached for config {params}')
            score = None
        except ValueError as e:
            self._sout(f'the combination of hyperparameters {params} is invalid')
            raise e
        except Exception as e:
            self._sout(f'something went wrong for config {params}; skipping:')
            self._sout(f'\tException: {e}')
            score = None
        finally:
            # always disarm the alarm; otherwise a configuration that failed early would leave it pending
            # and it would fire later, in code unrelated to this configuration
            if use_timeout:
                signal.alarm(0)

        return params, score, model

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')
2022-05-25 19:14:33 +02:00
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
    """
    Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
    but for quantification.

    The quantifier is fit on the training part of every fold and asked to quantify the corresponding test part;
    the returned estimate is the average of the per-fold estimates, each weighted by the fraction of `data`
    that its test part represents. Note that `quantifier` itself is (re)fit in place on every fold.

    :param quantifier: a quantifier issuing class prevalence values
    :param data: a labelled collection
    :param nfolds: number of folds for k-fold cross validation generation
    :param random_state: random seed for reproducibility
    :return: a vector of class prevalence values
    """
    total_prev = np.zeros(shape=data.n_classes)

    for train, test in data.kFCV(nfolds=nfolds, random_state=random_state):
        quantifier.fit(train)
        fold_prev = quantifier.quantify(test.X)
        # size is taken from the collection, not from test.X: the instances may be stored in a
        # container that does not support len() (e.g., a scipy sparse matrix)
        rel_size = len(test) / len(data)
        total_prev += fold_prev * rel_size

    return total_prev