forked from moreo/QuaPy
context timeout
This commit is contained in:
parent
f785a4eeef
commit
6663b4c91d
|
@ -13,6 +13,7 @@ from quapy import evaluation
|
||||||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||||
from quapy.data.base import LabelledCollection
|
from quapy.data.base import LabelledCollection
|
||||||
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
|
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
|
||||||
|
from quapy.util import timeout
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,7 +128,6 @@ class GridSearchQ(BaseQuantifier):
|
||||||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||||
|
|
||||||
@check_status
|
|
||||||
def _prepare_classifier(self, args):
|
def _prepare_classifier(self, args):
|
||||||
cls_params = args['cls-params']
|
cls_params = args['cls-params']
|
||||||
training = args['training']
|
training = args['training']
|
||||||
|
@ -136,9 +136,8 @@ class GridSearchQ(BaseQuantifier):
|
||||||
predictions = model.classifier_fit_predict(training)
|
predictions = model.classifier_fit_predict(training)
|
||||||
return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
|
return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
|
||||||
|
|
||||||
@check_status
|
|
||||||
def _prepare_aggregation(self, args):
|
def _prepare_aggregation(self, args):
|
||||||
# (partial_setup, q_params), training = args
|
|
||||||
model = args['model']
|
model = args['model']
|
||||||
predictions = args['predictions']
|
predictions = args['predictions']
|
||||||
cls_params = args['cls-params']
|
cls_params = args['cls-params']
|
||||||
|
@ -147,15 +146,32 @@ class GridSearchQ(BaseQuantifier):
|
||||||
|
|
||||||
params = {**cls_params, **q_params}
|
params = {**cls_params, **q_params}
|
||||||
|
|
||||||
model = deepcopy(model)
|
def job(model):
|
||||||
# overrides default parameters with the parameters being explored at this iteration
|
tinit = time()
|
||||||
model.set_params(**q_params)
|
model = deepcopy(model)
|
||||||
model.aggregation_fit(predictions, training)
|
# overrides default parameters with the parameters being explored at this iteration
|
||||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
model.set_params(**q_params)
|
||||||
|
model.aggregation_fit(predictions, training)
|
||||||
|
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||||
|
ttime = time()-tinit
|
||||||
|
|
||||||
|
return {
|
||||||
|
'model': model,
|
||||||
|
'cls-params':cls_params,
|
||||||
|
'q-params': q_params,
|
||||||
|
'params': params,
|
||||||
|
'score': score,
|
||||||
|
'ttime':ttime
|
||||||
|
}
|
||||||
|
|
||||||
|
out, status = self._error_handler(job, args)
|
||||||
|
if status == Status.SUCCESS:
|
||||||
|
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]')
|
||||||
|
elif status == Status.INVALID:
|
||||||
|
self._sout(f'the combination of hyperparameters {params} is invalid')
|
||||||
|
elif status == Status.
|
||||||
|
|
||||||
return {'model': model, 'cls-params':cls_params, 'q-params': q_params, 'params': params, 'score': score}
|
|
||||||
|
|
||||||
@check_status
|
|
||||||
def _prepare_model(self, args):
|
def _prepare_model(self, args):
|
||||||
params, training = args
|
params, training = args
|
||||||
model = deepcopy(self.model)
|
model = deepcopy(self.model)
|
||||||
|
@ -309,6 +325,24 @@ class GridSearchQ(BaseQuantifier):
|
||||||
raise ValueError('best_model called before fit')
|
raise ValueError('best_model called before fit')
|
||||||
|
|
||||||
|
|
||||||
|
def _error_handler(self, func, *args, **kwargs):
|
||||||
|
|
||||||
|
try:
|
||||||
|
with timeout(self.timeout):
|
||||||
|
output = func(*args, **kwargs)
|
||||||
|
return output, Status.SUCCESS
|
||||||
|
|
||||||
|
except TimeoutError:
|
||||||
|
return None, Status.TIMEOUT
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return None, Status.INVALID
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return None, Status.ERROR
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||||
"""
|
"""
|
||||||
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_
|
||||||
|
|
|
@ -10,6 +10,8 @@ import quapy as qp
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from joblib import Parallel, delayed
|
from joblib import Parallel, delayed
|
||||||
|
from time import time
|
||||||
|
import signal
|
||||||
|
|
||||||
|
|
||||||
def _get_parallel_slices(n_tasks, n_jobs):
|
def _get_parallel_slices(n_tasks, n_jobs):
|
||||||
|
@ -257,3 +259,35 @@ class EarlyStop:
|
||||||
if self.patience <= 0:
|
if self.patience <= 0:
|
||||||
self.STOP = True
|
self.STOP = True
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def timeout(seconds):
|
||||||
|
"""
|
||||||
|
Opens a context that will launch an exception if not closed after a given number of seconds
|
||||||
|
|
||||||
|
>>> def func(start_msg, end_msg):
|
||||||
|
>>> print(start_msg)
|
||||||
|
>>> sleep(2)
|
||||||
|
>>> print(end_msg)
|
||||||
|
>>>
|
||||||
|
>>> with timeout(1):
|
||||||
|
>>> func('begin function', 'end function')
|
||||||
|
>>> Out[]
|
||||||
|
>>> begin function
|
||||||
|
>>> TimeoutError
|
||||||
|
|
||||||
|
|
||||||
|
:param seconds: number of seconds, set to <=0 to ignore the timer
|
||||||
|
"""
|
||||||
|
if seconds > 0:
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError()
|
||||||
|
|
||||||
|
signal.signal(signal.SIGALRM, handler)
|
||||||
|
signal.alarm(seconds)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
if seconds > 0:
|
||||||
|
signal.alarm(0)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue