forked from moreo/QuaPy
context timeout
This commit is contained in:
parent
f785a4eeef
commit
6663b4c91d
|
@ -13,6 +13,7 @@ from quapy import evaluation
|
|||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||
from quapy.data.base import LabelledCollection
|
||||
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
|
||||
from quapy.util import timeout
|
||||
from time import time
|
||||
|
||||
|
||||
|
@ -127,7 +128,6 @@ class GridSearchQ(BaseQuantifier):
|
|||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||
|
||||
@check_status
|
||||
def _prepare_classifier(self, args):
|
||||
cls_params = args['cls-params']
|
||||
training = args['training']
|
||||
|
@ -136,9 +136,8 @@ class GridSearchQ(BaseQuantifier):
|
|||
predictions = model.classifier_fit_predict(training)
|
||||
return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
|
||||
|
||||
@check_status
|
||||
def _prepare_aggregation(self, args):
|
||||
# (partial_setup, q_params), training = args
|
||||
|
||||
model = args['model']
|
||||
predictions = args['predictions']
|
||||
cls_params = args['cls-params']
|
||||
|
@ -147,15 +146,32 @@ class GridSearchQ(BaseQuantifier):
|
|||
|
||||
params = {**cls_params, **q_params}
|
||||
|
||||
model = deepcopy(model)
|
||||
# overrides default parameters with the parameters being explored at this iteration
|
||||
model.set_params(**q_params)
|
||||
model.aggregation_fit(predictions, training)
|
||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||
def job(model):
|
||||
tinit = time()
|
||||
model = deepcopy(model)
|
||||
# overrides default parameters with the parameters being explored at this iteration
|
||||
model.set_params(**q_params)
|
||||
model.aggregation_fit(predictions, training)
|
||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||
ttime = time()-tinit
|
||||
|
||||
return {
|
||||
'model': model,
|
||||
'cls-params':cls_params,
|
||||
'q-params': q_params,
|
||||
'params': params,
|
||||
'score': score,
|
||||
'ttime':ttime
|
||||
}
|
||||
|
||||
out, status = self._error_handler(job, args)
|
||||
if status == Status.SUCCESS:
|
||||
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]')
|
||||
elif status == Status.INVALID:
|
||||
self._sout(f'the combination of hyperparameters {params} is invalid')
|
||||
elif status == Status.
|
||||
|
||||
return {'model': model, 'cls-params':cls_params, 'q-params': q_params, 'params': params, 'score': score}
|
||||
|
||||
@check_status
|
||||
def _prepare_model(self, args):
|
||||
params, training = args
|
||||
model = deepcopy(self.model)
|
||||
|
@ -309,6 +325,24 @@ class GridSearchQ(BaseQuantifier):
|
|||
raise ValueError('best_model called before fit')
|
||||
|
||||
|
||||
def _error_handler(self, func, *args, **kwargs):
|
||||
|
||||
try:
|
||||
with timeout(self.timeout):
|
||||
output = func(*args, **kwargs)
|
||||
return output, Status.SUCCESS
|
||||
|
||||
except TimeoutError:
|
||||
return None, Status.TIMEOUT
|
||||
|
||||
except ValueError:
|
||||
return None, Status.INVALID
|
||||
|
||||
except Exception:
|
||||
return None, Status.ERROR
|
||||
|
||||
|
||||
|
||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||
"""
|
||||
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
||||
|
|
|
@ -10,6 +10,8 @@ import quapy as qp
|
|||
|
||||
import numpy as np
|
||||
from joblib import Parallel, delayed
|
||||
from time import time
|
||||
import signal
|
||||
|
||||
|
||||
def _get_parallel_slices(n_tasks, n_jobs):
|
||||
|
@ -257,3 +259,35 @@ class EarlyStop:
|
|||
if self.patience <= 0:
|
||||
self.STOP = True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def timeout(seconds):
|
||||
"""
|
||||
Opens a context that will launch an exception if not closed after a given number of seconds
|
||||
|
||||
>>> def func(start_msg, end_msg):
|
||||
>>> print(start_msg)
|
||||
>>> sleep(2)
|
||||
>>> print(end_msg)
|
||||
>>>
|
||||
>>> with timeout(1):
|
||||
>>> func('begin function', 'end function')
|
||||
>>> Out[]
|
||||
>>> begin function
|
||||
>>> TimeoutError
|
||||
|
||||
|
||||
:param seconds: number of seconds, set to <=0 to ignore the timer
|
||||
"""
|
||||
if seconds > 0:
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(seconds)
|
||||
|
||||
yield
|
||||
|
||||
if seconds > 0:
|
||||
signal.alarm(0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue