2020-12-23 11:14:35 +01:00
|
|
|
import itertools
|
2021-01-15 18:32:32 +01:00
|
|
|
import signal
|
|
|
|
from copy import deepcopy
|
2023-11-16 14:29:34 +01:00
|
|
|
from enum import Enum
|
2023-11-06 02:00:06 +01:00
|
|
|
from typing import Union, Callable
|
2023-11-16 19:56:30 +01:00
|
|
|
from functools import wraps
|
2022-12-12 17:32:30 +01:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
from sklearn import clone
|
|
|
|
|
2020-12-23 11:14:35 +01:00
|
|
|
import quapy as qp
|
2022-06-15 16:54:42 +02:00
|
|
|
from quapy import evaluation
|
2023-11-06 02:00:06 +01:00
|
|
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
2021-03-19 17:34:09 +01:00
|
|
|
from quapy.data.base import LabelledCollection
|
2023-11-16 14:29:34 +01:00
|
|
|
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
|
2023-11-20 22:05:26 +01:00
|
|
|
from quapy.util import timeout
|
2023-11-06 02:00:06 +01:00
|
|
|
from time import time
|
2020-12-23 11:14:35 +01:00
|
|
|
|
|
|
|
|
2023-11-16 14:29:34 +01:00
|
|
|
class Status(Enum):
    """Outcome of running a single hyperparameter configuration during model selection."""

    SUCCESS = 1  # the job completed without raising
    TIMEOUT = 2  # the job raised TimeoutError (time budget exceeded)
    INVALID = 3  # the job raised ValueError
    ERROR = 4    # the job raised any other Exception
|
|
|
|
|
2023-11-16 19:56:30 +01:00
|
|
|
|
2023-11-21 18:59:36 +01:00
|
|
|
class ConfigStatus:
    """
    Record describing how one hyperparameter configuration fared.

    :param params: the hyperparameter configuration that was tried
    :param status: a :class:`Status` value describing the outcome
    :param msg: optional message (a string), e.g., the text of the exception that made the configuration fail
    """

    def __init__(self, params, status, msg=''):
        self.params, self.status, self.msg = params, status, msg

    def __str__(self):
        header = f':params:{self.params} :status:{self.status} '
        return header + self.msg

    def __repr__(self):
        return str(self)

    def success(self):
        """Returns True if the configuration completed successfully."""
        return self.status == Status.SUCCESS

    def failed(self):
        """Returns True if the configuration did not complete successfully."""
        return not self.success()
|
2023-11-16 19:56:30 +01:00
|
|
|
|
|
|
|
|
2021-01-06 14:58:29 +01:00
|
|
|
class GridSearchQ(BaseQuantifier):
    """Grid Search optimization targeting a quantification-oriented metric.

    Optimizes the hyperparameters of a quantification method, based on an evaluation method and on an evaluation
    protocol for quantification.

    :param model: the quantifier to optimize
    :type model: BaseQuantifier
    :param param_grid: a dictionary with keys the parameter names and values the list of values to explore
    :param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
    :param error: an error function (callable) or a string indicating the name of an error function (valid ones
        are those in :class:`quapy.error.QUANTIFICATION_ERROR`)
    :param refit: whether to refit the model on the whole labelled collection (training+validation) with
        the best chosen hyperparameter combination. This requires the protocol to implement the
        :class:`quapy.protocol.OnLabelledCollectionProtocol` interface; otherwise :meth:`fit` raises a
        RuntimeWarning.
    :param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
        Whenever a run takes longer than this timer, that configuration will be ignored (or, if `raise_errors` is
        True, the TimeoutError is propagated). If all configurations end up being ignored, a ValueError exception
        is raised. If -1 (default) then no time bound is set.
    :param n_jobs: number of parallel jobs used to explore the grid; resolved through `qp._get_njobs`
        (presumably None means "use the environment default" -- TODO confirm)
    :param raise_errors: boolean, if True then raises an exception when a param combination yields any error, if
        otherwise is False (default), then the combination is marked with an error status, but the process goes on.
        However, if no configuration yields a valid model, then a ValueError exception will be raised.
    :param verbose: set to True to get information through the stdout
    """

    def __init__(self,
                 model: BaseQuantifier,
                 param_grid: dict,
                 protocol: AbstractProtocol,
                 error: Union[Callable, str] = qp.error.mae,
                 refit=True,
                 timeout=-1,
                 n_jobs=None,
                 raise_errors=False,
                 verbose=False):

        self.model = model
        self.param_grid = param_grid
        self.protocol = protocol
        self.refit = refit
        self.timeout = timeout
        self.n_jobs = qp._get_njobs(n_jobs)
        self.raise_errors = raise_errors
        self.verbose = verbose
        self.__check_error(error)
        assert isinstance(protocol, AbstractProtocol), 'unknown protocol'

    def _sout(self, msg):
        # prints only in verbose mode; the prefix names both this class and the wrapped quantifier's class
        if self.verbose:
            print(f'[{self.__class__.__name__}:{self.model.__class__.__name__}]: {msg}')

    def __check_error(self, error):
        """
        Resolves `error` into a callable and stores it in `self.error`.

        :param error: a function in `qp.error.QUANTIFICATION_ERROR`, the name of an error function, or any callable
        :raises ValueError: if `error` is neither a string nor a callable
        """
        if error in qp.error.QUANTIFICATION_ERROR:
            self.error = error
        elif isinstance(error, str):
            self.error = qp.error.from_name(error)
        elif hasattr(error, '__call__'):
            self.error = error
        else:
            raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
                             f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')

    def _prepare_classifier(self, cls_params):
        """
        Fits the classifier of a fresh copy of `self.model`, using the classifier-specific hyperparameters
        `cls_params`, and obtains its predictions on `self._training` via `classifier_fit_predict`.

        :param cls_params: dictionary of classifier-specific hyperparameters
        :return: tuple `(model, predictions, status, took)`; `predictions` is None if the job failed
        """
        model = deepcopy(self.model)

        def job(cls_params):
            model.set_params(**cls_params)
            predictions = model.classifier_fit_predict(self._training)
            return predictions

        predictions, status, took = self._error_handler(job, cls_params)
        self._sout(f'[classifier fit] hyperparams={cls_params} [took {took:.3f}s]')
        return model, predictions, status, took

    def _prepare_aggregation(self, args):
        """
        Fits the aggregation function of a copy of an already classifier-fitted model and evaluates it on the
        protocol.

        :param args: tuple `(model, predictions, cls_took, cls_params, q_params)` as assembled in
            :meth:`_compute_scores_aggregative`
        :return: tuple `(model, params, score, status, took)` where `params` merges `cls_params` and `q_params`
            and `took` accounts for both the classifier and the aggregation time
        """
        model, predictions, cls_took, cls_params, q_params = args
        # copy, since the same classifier-fitted model is shared by several quantifier-specific configurations
        model = deepcopy(model)
        params = {**cls_params, **q_params}

        def job(q_params):
            model.set_params(**q_params)
            model.aggregation_fit(predictions, self._training)
            score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
            return score

        score, status, aggr_took = self._error_handler(job, q_params)
        self._print_status(params, score, status, aggr_took)
        return model, params, score, status, (cls_took+aggr_took)

    def _prepare_nonaggr_model(self, params):
        """
        Fits a fresh copy of `self.model` end-to-end with the configuration `params` and evaluates it on the
        protocol.

        :param params: dictionary of hyperparameters
        :return: tuple `(model, params, score, status, took)`
        """
        model = deepcopy(self.model)

        def job(params):
            model.set_params(**params)
            model.fit(self._training)
            score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
            return score

        score, status, took = self._error_handler(job, params)
        self._print_status(params, score, status, took)
        return model, params, score, status, took

    def _break_down_fit(self):
        """
        Decides whether to break down the fit phase in two (classifier-fit followed by aggregation-fit).
        In order to do so, some conditions should be met: a) the quantifier is of type aggregative,
        b) the set of hyperparameters can be split into two disjoint non-empty groups.

        :return: True if the conditions are met, False otherwise
        """
        if not isinstance(self.model, AggregativeQuantifier):
            return False
        cls_configs, q_configs = group_params(self.param_grid)
        if (len(cls_configs) == 1) or (len(q_configs)==1):
            return False
        return True

    def _compute_scores_aggregative(self, training):
        """
        Explores the grid in two stages: first fits one classifier per classifier-specific configuration, then,
        for every successfully fitted classifier, fits and evaluates each quantifier-specific configuration.

        :param training: the training set
        :return: list of tuples `(model, params, score, status, took)`
        :raises ValueError: if no classifier configuration could be fitted
        """
        # break down the set of hyperparameters into two: classifier-specific, quantifier-specific
        cls_configs, q_configs = group_params(self.param_grid)

        # train all classifiers and get the predictions
        self._training = training
        cls_outs = qp.util.parallel(
            self._prepare_classifier,
            cls_configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )

        # filter out classifier configurations that yielded any error
        success_outs = []
        for (model, predictions, status, took), cls_config in zip(cls_outs, cls_configs):
            if status.success():
                success_outs.append((model, predictions, took, cls_config))
            else:
                self.error_collector.append(status)

        if len(success_outs) == 0:
            raise ValueError('No valid configuration found for the classifier!')

        # explore the quantifier-specific hyperparameters for each valid training configuration
        aggr_configs = [(*out, q_config) for out, q_config in itertools.product(success_outs, q_configs)]
        aggr_outs = qp.util.parallel(
            self._prepare_aggregation,
            aggr_configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )

        return aggr_outs

    def _compute_scores_nonaggregative(self, training):
        """
        Fits and evaluates one model per configuration of the fully expanded grid.

        :param training: the training set
        :return: list of tuples `(model, params, score, status, took)`
        """
        configs = expand_grid(self.param_grid)
        self._training = training
        scores = qp.util.parallel(
            self._prepare_nonaggr_model,
            configs,
            seed=qp.environ.get('_R_SEED', None),
            n_jobs=self.n_jobs
        )
        return scores

    def _print_status(self, params, score, status, took):
        # logs the score of a successful configuration, or the status of a failed one
        if status.success():
            self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {score:.5f} [took {took:.3f}s]')
        else:
            self._sout(f'error={status}')

    def fit(self, training: LabelledCollection):
        """ Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
        the error metric.

        :param training: the training set on which to optimize the hyperparameters
        :return: self
        :raises RuntimeWarning: if `refit` is requested but the protocol does not implement
            :class:`quapy.protocol.OnLabelledCollectionProtocol`
        :raises ValueError: if no combination of hyperparameters could be successfully evaluated
        """
        if self.refit and not isinstance(self.protocol, OnLabelledCollectionProtocol):
            raise RuntimeWarning(
                f'"refit" was requested, but the protocol does not implement '
                f'the {OnLabelledCollectionProtocol.__name__} interface'
            )

        tinit = time()

        self.error_collector = []

        self._sout(f'starting model selection with n_jobs={self.n_jobs}')
        if self._break_down_fit():
            results = self._compute_scores_aggregative(training)
        else:
            results = self._compute_scores_nonaggregative(training)

        self.param_scores_ = {}
        self.best_score_ = None
        for model, params, score, status, took in results:
            if status.success():
                # lower is better: the metric is an error function
                if self.best_score_ is None or score < self.best_score_:
                    self.best_score_ = score
                    self.best_params_ = params
                    self.best_model_ = model
                self.param_scores_[str(params)] = score
            else:
                self.param_scores_[str(params)] = status.status
                self.error_collector.append(status)

        tend = time()-tinit

        if self.best_score_ is None:
            raise ValueError('no combination of hyperparameters seemed to work')

        self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
                   f'[took {tend:.4f}s]')

        no_errors = len(self.error_collector)
        if no_errors>0:
            self._sout(f'warning: {no_errors} errors found')
            for err in self.error_collector:
                self._sout(f'\t{str(err)}')

        if self.refit:
            if isinstance(self.protocol, OnLabelledCollectionProtocol):
                tinit = time()
                self._sout(f'refitting on the whole development set')
                self.best_model_.fit(training + self.protocol.get_labelled_collection())
                tend = time() - tinit
                self.refit_time_ = tend
            else:
                # already checked
                raise RuntimeWarning(f'the model cannot be refit on the whole dataset')

        return self

    def quantify(self, instances):
        """Estimate class prevalence values using the best model found after calling the :meth:`fit` method.

        :param instances: sample containing the instances
        :return: a ndarray of shape `(n_classes)` with class prevalence estimates as according to the best model found
            by the model selection process.
        """
        assert hasattr(self, 'best_model_'), 'quantify called before fit'
        return self.best_model().quantify(instances)

    def set_params(self, **parameters):
        """Sets the hyper-parameters to explore.

        :param parameters: a dictionary with keys the parameter names and values the list of values to explore
        """
        self.param_grid = parameters

    def get_params(self, deep=True):
        """Returns the dictionary of hyper-parameters to explore (`param_grid`)

        :param deep: Unused
        :return: the dictionary `param_grid`
        """
        return self.param_grid

    def best_model(self):
        """
        Returns the best model found after calling the :meth:`fit` method, i.e., the one trained on the combination
        of hyper-parameters that minimized the error function.

        :return: a trained quantifier
        :raises ValueError: if called before :meth:`fit`
        """
        if hasattr(self, 'best_model_'):
            return self.best_model_
        raise ValueError('best_model called before fit')

    def _error_handler(self, func, params):
        """
        Endorses one job with two returned values: the status, and the time of execution

        :param func: the function to be called
        :param params: parameters of the function
        :return: `tuple(out, status, time)` where `out` is the function output (None if the call failed),
            `status` is a :class:`ConfigStatus` wrapping a `Status` value, and `time` is the time it
            took to complete the call
        :raises Exception: if `raise_errors` is True, the original exception raised by the job (including the
            TimeoutError raised by the timer) is propagated unchanged
        """
        output = None

        def _handle(status, exception):
            # `exception` must be the exception object itself: raising its string form would produce a TypeError
            if self.raise_errors:
                raise exception
            else:
                return ConfigStatus(params, status, str(exception))

        # started before entering the timer so that `tinit` is always bound when computing `took`
        tinit = time()
        try:
            with timeout(self.timeout):
                output = func(params)
                status = ConfigStatus(params, Status.SUCCESS)

        except TimeoutError as e:
            status = _handle(Status.TIMEOUT, e)

        except ValueError as e:
            status = _handle(Status.INVALID, e)

        except Exception as e:
            status = _handle(Status.ERROR, e)

        took = time() - tinit
        return output, status, took
|
2023-11-20 22:05:26 +01:00
|
|
|
|
|
|
|
|
2023-11-06 01:58:36 +01:00
|
|
|
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
    """
    Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
    but for quantification.

    For every fold of a k-fold split of `data`, the quantifier is trained on the training part and asked for the
    class prevalence of the held-out part; the per-fold estimates are averaged, each weighted by the relative size
    of its held-out part.

    Note that `quantifier` is fitted in place, once per fold; after the call it holds the model trained on the last
    fold.

    :param quantifier: a quantifier issuing class prevalence values
    :param data: a labelled collection
    :param nfolds: number of folds for k-fold cross validation generation
    :param random_state: random seed for reproducibility
    :return: a vector of class prevalence values
    """
    n_total = len(data)
    weighted_prev = np.zeros(shape=data.n_classes)

    for train_fold, test_fold in data.kFCV(nfolds=nfolds, random_state=random_state):
        quantifier.fit(train_fold)
        estimate = quantifier.quantify(test_fold.Xtr)
        weighted_prev += estimate * (len(test_fold) / n_total)

    return weighted_prev
|
2023-11-06 01:58:36 +01:00
|
|
|
|
|
|
|
|
2023-11-15 10:55:13 +01:00
|
|
|
def expand_grid(param_grid: dict):
    """
    Expands a param_grid dictionary as a list of configurations (the Cartesian product of the value ranges).
    Example:

    >>> combinations = expand_grid({'A': [1, 10, 100], 'B': [True, False]})
    >>> print(combinations)
    [{'A': 1, 'B': True}, {'A': 1, 'B': False}, {'A': 10, 'B': True}, {'A': 10, 'B': False}, {'A': 100, 'B': True}, {'A': 100, 'B': False}]

    :param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
        to explore for that hyper-parameter
    :return: a list of configurations, i.e., combinations of hyper-parameter assignments in the grid. An empty
        grid expands to a single empty configuration `[{}]`.
    """
    names = list(param_grid.keys())
    ranges = list(param_grid.values())
    return [dict(zip(names, combination)) for combination in itertools.product(*ranges)]
|
|
|
|
|
|
|
|
|
|
|
|
def group_params(param_grid: dict):
    """
    Partitions a param_grid dictionary as two lists of configurations, one for the classifier-specific
    hyper-parameters, and another for the quantifier-specific hyper-parameters. A hyper-parameter is deemed
    classifier-specific if its name starts with `classifier__` or is exactly `val_split`; every other one is
    deemed quantifier-specific.

    :param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
        to explore for that hyper-parameter
    :return: two expanded grids of configurations, one for the classifier, another for the quantifier
    """
    def _targets_classifier(name):
        return name.startswith('classifier__') or name == 'val_split'

    cls_grid = {name: values for name, values in param_grid.items() if _targets_classifier(name)}
    q_grid = {name: values for name, values in param_grid.items() if not _targets_classifier(name)}

    return expand_grid(cls_grid), expand_grid(q_grid)
|
|
|
|
|