forked from moreo/QuaPy
merged
This commit is contained in:
commit
c56fe9c09c
|
@ -53,6 +53,7 @@ with qp.util.temp_seed(0):
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
error='mae', # the error to optimize is the MAE (a quantification-oriented loss)
|
error='mae', # the error to optimize is the MAE (a quantification-oriented loss)
|
||||||
refit=False, # retrain on the whole labelled set once done
|
refit=False, # retrain on the whole labelled set once done
|
||||||
|
raise_errors=False,
|
||||||
verbose=True # show information as the process goes on
|
verbose=True # show information as the process goes on
|
||||||
).fit(training)
|
).fit(training)
|
||||||
|
|
||||||
|
@ -66,5 +67,5 @@ model = model.best_model_
|
||||||
mae_score = qp.evaluation.evaluate(model, protocol=APP(test), error_metric='mae')
|
mae_score = qp.evaluation.evaluate(model, protocol=APP(test), error_metric='mae')
|
||||||
|
|
||||||
print(f'MAE={mae_score:.5f}')
|
print(f'MAE={mae_score:.5f}')
|
||||||
print(f'model selection took {tend-tinit}s')
|
print(f'model selection took {tend-tinit:.1f}s')
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sklearn.neighbors import KernelDensity
|
||||||
|
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.method.aggregative import AggregativeProbabilisticQuantifier, cross_generate_predictions
|
from quapy.method.aggregative import AggregativeSoftQuantifier
|
||||||
import quapy.functional as F
|
import quapy.functional as F
|
||||||
|
|
||||||
from sklearn.metrics.pairwise import rbf_kernel
|
from sklearn.metrics.pairwise import rbf_kernel
|
||||||
|
@ -33,7 +33,7 @@ class KDEBase:
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class KDEyML(AggregativeProbabilisticQuantifier, KDEBase):
|
class KDEyML(AggregativeSoftQuantifier, KDEBase):
|
||||||
|
|
||||||
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
||||||
self._check_bandwidth(bandwidth)
|
self._check_bandwidth(bandwidth)
|
||||||
|
@ -43,16 +43,8 @@ class KDEyML(AggregativeProbabilisticQuantifier, KDEBase):
|
||||||
self.n_jobs = n_jobs
|
self.n_jobs = n_jobs
|
||||||
self.random_state=random_state
|
self.random_state=random_state
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
if val_split is None:
|
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||||
val_split = self.val_split
|
|
||||||
|
|
||||||
self.classifier, y, posteriors, _, _ = cross_generate_predictions(
|
|
||||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mix_densities = self.get_mixture_components(posteriors, y, data.n_classes, self.bandwidth)
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def aggregate(self, posteriors: np.ndarray):
|
def aggregate(self, posteriors: np.ndarray):
|
||||||
|
@ -76,7 +68,7 @@ class KDEyML(AggregativeProbabilisticQuantifier, KDEBase):
|
||||||
return F.optim_minimize(neg_loglikelihood, n_classes)
|
return F.optim_minimize(neg_loglikelihood, n_classes)
|
||||||
|
|
||||||
|
|
||||||
class KDEyHD(AggregativeProbabilisticQuantifier, KDEBase):
|
class KDEyHD(AggregativeSoftQuantifier, KDEBase):
|
||||||
|
|
||||||
def __init__(self, classifier: BaseEstimator, val_split=10, divergence: str='HD',
|
def __init__(self, classifier: BaseEstimator, val_split=10, divergence: str='HD',
|
||||||
bandwidth=0.1, n_jobs=None, random_state=0, montecarlo_trials=10000):
|
bandwidth=0.1, n_jobs=None, random_state=0, montecarlo_trials=10000):
|
||||||
|
@ -90,15 +82,8 @@ class KDEyHD(AggregativeProbabilisticQuantifier, KDEBase):
|
||||||
self.random_state=random_state
|
self.random_state=random_state
|
||||||
self.montecarlo_trials = montecarlo_trials
|
self.montecarlo_trials = montecarlo_trials
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
if val_split is None:
|
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||||
val_split = self.val_split
|
|
||||||
|
|
||||||
self.classifier, y, posteriors, _, _ = cross_generate_predictions(
|
|
||||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
|
||||||
)
|
|
||||||
|
|
||||||
self.mix_densities = self.get_mixture_components(posteriors, y, data.n_classes, self.bandwidth)
|
|
||||||
|
|
||||||
N = self.montecarlo_trials
|
N = self.montecarlo_trials
|
||||||
rs = self.random_state
|
rs = self.random_state
|
||||||
|
@ -141,7 +126,7 @@ class KDEyHD(AggregativeProbabilisticQuantifier, KDEBase):
|
||||||
return F.optim_minimize(divergence, n_classes)
|
return F.optim_minimize(divergence, n_classes)
|
||||||
|
|
||||||
|
|
||||||
class KDEyCS(AggregativeProbabilisticQuantifier):
|
class KDEyCS(AggregativeSoftQuantifier):
|
||||||
|
|
||||||
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
||||||
KDEBase._check_bandwidth(bandwidth)
|
KDEBase._check_bandwidth(bandwidth)
|
||||||
|
@ -163,19 +148,14 @@ class KDEyCS(AggregativeProbabilisticQuantifier):
|
||||||
gram = norm_factor * rbf_kernel(X, Y, gamma=gamma)
|
gram = norm_factor * rbf_kernel(X, Y, gamma=gamma)
|
||||||
return gram.sum()
|
return gram.sum()
|
||||||
|
|
||||||
def fit(self, data: LabelledCollection, fit_classifier=True, val_split: Union[float, LabelledCollection] = None):
|
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||||
if val_split is None:
|
|
||||||
val_split = self.val_split
|
|
||||||
|
|
||||||
self.classifier, y, posteriors, _, _ = cross_generate_predictions(
|
P, y = classif_predictions.Xy
|
||||||
data, self.classifier, val_split, probabilistic=True, fit_classifier=fit_classifier, n_jobs=self.n_jobs
|
n = data.n_classes
|
||||||
)
|
|
||||||
|
|
||||||
assert all(sorted(np.unique(y)) == np.arange(data.n_classes)), \
|
assert all(sorted(np.unique(y)) == np.arange(n)), \
|
||||||
'label name gaps not allowed in current implementation'
|
'label name gaps not allowed in current implementation'
|
||||||
|
|
||||||
n = data.n_classes
|
|
||||||
P = posteriors
|
|
||||||
|
|
||||||
# counts_inv keeps track of the relative weight of each datapoint within its class
|
# counts_inv keeps track of the relative weight of each datapoint within its class
|
||||||
# (i.e., the weight in its KDE model)
|
# (i.e., the weight in its KDE model)
|
||||||
|
|
|
@ -23,54 +23,24 @@ class Status(Enum):
|
||||||
INVALID = 3
|
INVALID = 3
|
||||||
ERROR = 4
|
ERROR = 4
|
||||||
|
|
||||||
def check_status(func):
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
obj = args[0]
|
|
||||||
tinit = time()
|
|
||||||
|
|
||||||
job_descriptor = dict(args[1])
|
class ConfigStatus:
|
||||||
params = {**job_descriptor.get('cls-params', {}), **job_descriptor.get('q-params', {})}
|
def __init__(self, params, status, msg=''):
|
||||||
|
self.params = params
|
||||||
|
self.status = status
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
if obj.timeout > 0:
|
def __str__(self):
|
||||||
def handler(signum, frame):
|
return f':params:{self.params} :status:{self.status} ' + self.msg
|
||||||
raise TimeoutError()
|
|
||||||
|
|
||||||
signal.signal(signal.SIGALRM, handler)
|
def __repr__(self):
|
||||||
signal.alarm(obj.timeout)
|
return str(self)
|
||||||
|
|
||||||
try:
|
def success(self):
|
||||||
job_descriptor = func(*args, **kwargs)
|
return self.status == Status.SUCCESS
|
||||||
|
|
||||||
ttime = time() - tinit
|
def failed(self):
|
||||||
|
return self.status != Status.SUCCESS
|
||||||
score = job_descriptor.get('score', None)
|
|
||||||
if score is not None:
|
|
||||||
obj._sout(f'hyperparams=[{params}]\t got {obj.error.__name__} = {score:.5f} [took {ttime:.4f}s]')
|
|
||||||
|
|
||||||
if obj.timeout > 0:
|
|
||||||
signal.alarm(0)
|
|
||||||
|
|
||||||
exit_status = Status.SUCCESS
|
|
||||||
|
|
||||||
except TimeoutError:
|
|
||||||
obj._sout(f'timeout ({obj.timeout}s) reached for config {params}')
|
|
||||||
exit_status = Status.TIMEOUT
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
obj._sout(f'the combination of hyperparameters {params} is invalid')
|
|
||||||
obj._sout(f'\tException: {e}')
|
|
||||||
exit_status = Status.INVALID
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
obj._sout(f'something went wrong for config {params}; skipping:')
|
|
||||||
obj._sout(f'\tException: {e}')
|
|
||||||
exit_status = Status.ERROR
|
|
||||||
|
|
||||||
job_descriptor['status'] = exit_status
|
|
||||||
job_descriptor['params'] = params
|
|
||||||
return job_descriptor
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class GridSearchQ(BaseQuantifier):
|
class GridSearchQ(BaseQuantifier):
|
||||||
|
@ -85,11 +55,14 @@ class GridSearchQ(BaseQuantifier):
|
||||||
:param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
|
:param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
|
||||||
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
|
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
|
||||||
are those in :class:`quapy.error.QUANTIFICATION_ERROR`
|
are those in :class:`quapy.error.QUANTIFICATION_ERROR`
|
||||||
:param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
|
:param refit: whether to refit the model on the whole labelled collection (training+validation) with
|
||||||
the best chosen hyperparameter combination. Ignored if protocol='gen'
|
the best chosen hyperparameter combination. Ignored if protocol='gen'
|
||||||
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
|
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
|
||||||
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
|
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
|
||||||
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
|
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
|
||||||
|
:param raise_errors: boolean, if True then raises an exception when a param combination yields any error, if
|
||||||
|
otherwise is False (default), then the combination is marked with an error status, but the process goes on.
|
||||||
|
However, if no configuration yields a valid model, then a ValueError exception will be raised.
|
||||||
:param verbose: set to True to get information through the stdout
|
:param verbose: set to True to get information through the stdout
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -101,6 +74,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
refit=True,
|
refit=True,
|
||||||
timeout=-1,
|
timeout=-1,
|
||||||
n_jobs=None,
|
n_jobs=None,
|
||||||
|
raise_errors=False,
|
||||||
verbose=False):
|
verbose=False):
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -109,6 +83,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
self.refit = refit
|
self.refit = refit
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.n_jobs = qp._get_njobs(n_jobs)
|
self.n_jobs = qp._get_njobs(n_jobs)
|
||||||
|
self.raise_errors = raise_errors
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.__check_error(error)
|
self.__check_error(error)
|
||||||
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
|
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
|
||||||
|
@ -128,112 +103,97 @@ class GridSearchQ(BaseQuantifier):
|
||||||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||||
|
|
||||||
def _prepare_classifier(self, args):
|
def _prepare_classifier(self, cls_params):
|
||||||
cls_params = args['cls-params']
|
|
||||||
training = args['training']
|
|
||||||
model = deepcopy(self.model)
|
model = deepcopy(self.model)
|
||||||
model.set_params(**cls_params)
|
|
||||||
predictions = model.classifier_fit_predict(training)
|
def job(cls_params):
|
||||||
return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
|
model.set_params(**cls_params)
|
||||||
|
predictions = model.classifier_fit_predict(self._training)
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
predictions, status, took = self._error_handler(job, cls_params)
|
||||||
|
self._sout(f'[classifier fit] hyperparams={cls_params} status={status} [took {took:.3f}s]')
|
||||||
|
return model, predictions, status, took
|
||||||
|
|
||||||
def _prepare_aggregation(self, args):
|
def _prepare_aggregation(self, args):
|
||||||
|
model, predictions, cls_took, cls_params, q_params = args
|
||||||
model = args['model']
|
model = deepcopy(model)
|
||||||
predictions = args['predictions']
|
|
||||||
cls_params = args['cls-params']
|
|
||||||
q_params = args['q-params']
|
|
||||||
training = args['training']
|
|
||||||
|
|
||||||
params = {**cls_params, **q_params}
|
params = {**cls_params, **q_params}
|
||||||
|
|
||||||
def job(model):
|
def job(q_params):
|
||||||
tinit = time()
|
|
||||||
model = deepcopy(model)
|
|
||||||
# overrides default parameters with the parameters being explored at this iteration
|
|
||||||
model.set_params(**q_params)
|
model.set_params(**q_params)
|
||||||
model.aggregation_fit(predictions, training)
|
model.aggregation_fit(predictions, self._training)
|
||||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||||
ttime = time()-tinit
|
return score
|
||||||
|
|
||||||
return {
|
score, status, aggr_took = self._error_handler(job, q_params)
|
||||||
'model': model,
|
self._print_status(params, score, status, aggr_took)
|
||||||
'cls-params':cls_params,
|
return model, params, score, status, (cls_took+aggr_took)
|
||||||
'q-params': q_params,
|
|
||||||
'params': params,
|
|
||||||
'score': score,
|
|
||||||
'ttime':ttime
|
|
||||||
}
|
|
||||||
|
|
||||||
out, status = self._error_handler(job, args)
|
def _prepare_nonaggr_model(self, params):
|
||||||
if status == Status.SUCCESS:
|
|
||||||
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]')
|
|
||||||
elif status == Status.INVALID:
|
|
||||||
self._sout(f'the combination of hyperparameters {params} is invalid')
|
|
||||||
elif status == Status.
|
|
||||||
|
|
||||||
|
|
||||||
def _prepare_model(self, args):
|
|
||||||
params, training = args
|
|
||||||
model = deepcopy(self.model)
|
model = deepcopy(self.model)
|
||||||
# overrides default parameters with the parameters being explored at this iteration
|
|
||||||
model.set_params(**params)
|
|
||||||
model.fit(training)
|
|
||||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
|
||||||
return {'model': model, 'params': params, 'score': score}
|
|
||||||
|
|
||||||
|
def job(params):
|
||||||
|
model.set_params(**params)
|
||||||
|
model.fit(self._training)
|
||||||
|
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||||
|
return score
|
||||||
|
|
||||||
|
score, status, took = self._error_handler(job, params)
|
||||||
|
self._print_status(params, score, status, took)
|
||||||
|
return model, params, score, status, took
|
||||||
|
|
||||||
def _compute_scores_aggregative(self, training):
|
def _compute_scores_aggregative(self, training):
|
||||||
|
|
||||||
# break down the set of hyperparameters into two: classifier-specific, quantifier-specific
|
# break down the set of hyperparameters into two: classifier-specific, quantifier-specific
|
||||||
cls_configs, q_configs = group_params(self.param_grid)
|
cls_configs, q_configs = group_params(self.param_grid)
|
||||||
|
|
||||||
# train all classifiers and get the predictions
|
# train all classifiers and get the predictions
|
||||||
partial_setups = qp.util.parallel(
|
self._training = training
|
||||||
|
cls_outs = qp.util.parallel(
|
||||||
self._prepare_classifier,
|
self._prepare_classifier,
|
||||||
({'cls-params':params, 'training':training} for params in cls_configs),
|
cls_configs,
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
|
||||||
n_jobs=self.n_jobs,
|
|
||||||
asarray=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# filter out classifier configurations that yield any error
|
|
||||||
for setup in partial_setups:
|
|
||||||
if setup['status'] != Status.SUCCESS:
|
|
||||||
self._sout(f'-> classifier hyperparemters {setup["params"]} caused '
|
|
||||||
f'error {setup["status"]} and will be ignored')
|
|
||||||
|
|
||||||
partial_setups = [setup for setup in partial_setups if setup['status']==Status.SUCCESS]
|
|
||||||
|
|
||||||
if len(partial_setups) == 0:
|
|
||||||
raise ValueError('No valid configuration found for the classifier.')
|
|
||||||
|
|
||||||
# explore the quantifier-specific hyperparameters for each training configuration
|
|
||||||
scores = qp.util.parallel(
|
|
||||||
self._prepare_aggregation,
|
|
||||||
({'q-params': setup[1], 'training': training, **setup[0]} for setup in itertools.product(partial_setups, q_configs)),
|
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
n_jobs=self.n_jobs
|
n_jobs=self.n_jobs
|
||||||
)
|
)
|
||||||
|
|
||||||
return scores
|
# filter out classifier configurations that yielded any error
|
||||||
|
success_outs = []
|
||||||
|
for (model, predictions, status, took), cls_config in zip(cls_outs, cls_configs):
|
||||||
|
if status.success():
|
||||||
|
success_outs.append((model, predictions, took, cls_config))
|
||||||
|
else:
|
||||||
|
self.error_collector.append(status)
|
||||||
|
|
||||||
|
if len(success_outs) == 0:
|
||||||
|
raise ValueError('No valid configuration found for the classifier!')
|
||||||
|
|
||||||
|
# explore the quantifier-specific hyperparameters for each valid training configuration
|
||||||
|
aggr_configs = [(*out, q_config) for out, q_config in itertools.product(success_outs, q_configs)]
|
||||||
|
aggr_outs = qp.util.parallel(
|
||||||
|
self._prepare_aggregation,
|
||||||
|
aggr_configs,
|
||||||
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
|
n_jobs=self.n_jobs
|
||||||
|
)
|
||||||
|
|
||||||
|
return aggr_outs
|
||||||
|
|
||||||
def _compute_scores_nonaggregative(self, training):
|
def _compute_scores_nonaggregative(self, training):
|
||||||
configs = expand_grid(self.param_grid)
|
configs = expand_grid(self.param_grid)
|
||||||
|
self._training = training
|
||||||
# pass a seed to parallel, so it is set in child processes
|
|
||||||
scores = qp.util.parallel(
|
scores = qp.util.parallel(
|
||||||
self._prepare_model,
|
self._prepare_nonaggr_model,
|
||||||
((params, training) for params in configs),
|
configs,
|
||||||
seed=qp.environ.get('_R_SEED', None),
|
seed=qp.environ.get('_R_SEED', None),
|
||||||
n_jobs=self.n_jobs
|
n_jobs=self.n_jobs
|
||||||
)
|
)
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
def _compute_scores(self, training):
|
def _print_status(self, params, score, status, took):
|
||||||
if isinstance(self.model, AggregativeQuantifier):
|
if status.success():
|
||||||
return self._compute_scores_aggregative(training)
|
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {score:.5f} [took {took:.3f}s]')
|
||||||
else:
|
else:
|
||||||
return self._compute_scores_nonaggregative(training)
|
self._sout(f'error={status}')
|
||||||
|
|
||||||
def fit(self, training: LabelledCollection):
|
def fit(self, training: LabelledCollection):
|
||||||
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
||||||
|
@ -251,31 +211,41 @@ class GridSearchQ(BaseQuantifier):
|
||||||
|
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
|
self.error_collector = []
|
||||||
|
|
||||||
self._sout(f'starting model selection with n_jobs={self.n_jobs}')
|
self._sout(f'starting model selection with n_jobs={self.n_jobs}')
|
||||||
results = self._compute_scores(training)
|
if isinstance(self.model, AggregativeQuantifier):
|
||||||
|
results = self._compute_scores_aggregative(training)
|
||||||
|
else:
|
||||||
|
results = self._compute_scores_nonaggregative(training)
|
||||||
|
|
||||||
self.param_scores_ = {}
|
self.param_scores_ = {}
|
||||||
self.best_score_ = None
|
self.best_score_ = None
|
||||||
for job_result in results:
|
for model, params, score, status, took in results:
|
||||||
score = job_result.get('score', None)
|
if status.success():
|
||||||
params = job_result['params']
|
|
||||||
if score is not None:
|
|
||||||
if self.best_score_ is None or score < self.best_score_:
|
if self.best_score_ is None or score < self.best_score_:
|
||||||
self.best_score_ = score
|
self.best_score_ = score
|
||||||
self.best_params_ = params
|
self.best_params_ = params
|
||||||
self.best_model_ = job_result['model']
|
self.best_model_ = model
|
||||||
self.param_scores_[str(params)] = score
|
self.param_scores_[str(params)] = score
|
||||||
else:
|
else:
|
||||||
self.param_scores_[str(params)] = job_result['status']
|
self.param_scores_[str(params)] = status.status
|
||||||
|
self.error_collector.append(status)
|
||||||
|
|
||||||
tend = time()-tinit
|
tend = time()-tinit
|
||||||
|
|
||||||
if self.best_score_ is None:
|
if self.best_score_ is None:
|
||||||
raise TimeoutError('no combination of hyperparameters seem to work')
|
raise ValueError('no combination of hyperparameters seemed to work')
|
||||||
|
|
||||||
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
|
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
|
||||||
f'[took {tend:.4f}s]')
|
f'[took {tend:.4f}s]')
|
||||||
|
|
||||||
|
no_errors = len(self.error_collector)
|
||||||
|
if no_errors>0:
|
||||||
|
self._sout(f'warning: {no_errors} errors found')
|
||||||
|
for err in self.error_collector:
|
||||||
|
self._sout(f'\t{str(err)}')
|
||||||
|
|
||||||
if self.refit:
|
if self.refit:
|
||||||
if isinstance(self.protocol, OnLabelledCollectionProtocol):
|
if isinstance(self.protocol, OnLabelledCollectionProtocol):
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
@ -284,6 +254,7 @@ class GridSearchQ(BaseQuantifier):
|
||||||
tend = time() - tinit
|
tend = time() - tinit
|
||||||
self.refit_time_ = tend
|
self.refit_time_ = tend
|
||||||
else:
|
else:
|
||||||
|
# already checked
|
||||||
raise RuntimeWarning(f'the model cannot be refit on the whole dataset')
|
raise RuntimeWarning(f'the model cannot be refit on the whole dataset')
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
@ -324,23 +295,42 @@ class GridSearchQ(BaseQuantifier):
|
||||||
return self.best_model_
|
return self.best_model_
|
||||||
raise ValueError('best_model called before fit')
|
raise ValueError('best_model called before fit')
|
||||||
|
|
||||||
|
def _error_handler(self, func, params):
|
||||||
|
"""
|
||||||
|
Endorses one job with two returned values: the status, and the time of execution
|
||||||
|
|
||||||
def _error_handler(self, func, *args, **kwargs):
|
:param func: the function to be called
|
||||||
|
:param params: parameters of the function
|
||||||
|
:return: `tuple(out, status, time)` where `out` is the function output,
|
||||||
|
`status` is an enum value from `Status`, and `time` is the time it
|
||||||
|
took to complete the call
|
||||||
|
"""
|
||||||
|
|
||||||
|
output = None
|
||||||
|
|
||||||
|
def _handle(status, exception):
|
||||||
|
if self.raise_errors:
|
||||||
|
raise exception
|
||||||
|
else:
|
||||||
|
return ConfigStatus(params, status, str(e))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with timeout(self.timeout):
|
with timeout(self.timeout):
|
||||||
output = func(*args, **kwargs)
|
tinit = time()
|
||||||
return output, Status.SUCCESS
|
output = func(params)
|
||||||
|
status = ConfigStatus(params, Status.SUCCESS)
|
||||||
|
|
||||||
except TimeoutError:
|
except TimeoutError as e:
|
||||||
return None, Status.TIMEOUT
|
status = _handle(Status.TIMEOUT, str(e))
|
||||||
|
|
||||||
except ValueError:
|
except ValueError as e:
|
||||||
return None, Status.INVALID
|
status = _handle(Status.INVALID, str(e))
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
return None, Status.ERROR
|
status = _handle(Status.ERROR, str(e))
|
||||||
|
|
||||||
|
took = time() - tinit
|
||||||
|
return output, status, took
|
||||||
|
|
||||||
|
|
||||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||||
|
|
Loading…
Reference in New Issue