context timeout

This commit is contained in:
Alejandro Moreo Fernandez 2023-11-20 22:05:26 +01:00
parent f785a4eeef
commit 6663b4c91d
2 changed files with 78 additions and 10 deletions

View File

@ -13,6 +13,7 @@ from quapy import evaluation
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
from quapy.data.base import LabelledCollection
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
from quapy.util import timeout
from time import time
@ -127,7 +128,6 @@ class GridSearchQ(BaseQuantifier):
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
@check_status
def _prepare_classifier(self, args):
cls_params = args['cls-params']
training = args['training']
@ -136,9 +136,8 @@ class GridSearchQ(BaseQuantifier):
predictions = model.classifier_fit_predict(training)
return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
@check_status
def _prepare_aggregation(self, args):
# (partial_setup, q_params), training = args
model = args['model']
predictions = args['predictions']
cls_params = args['cls-params']
@ -147,15 +146,32 @@ class GridSearchQ(BaseQuantifier):
params = {**cls_params, **q_params}
model = deepcopy(model)
# overrides default parameters with the parameters being explored at this iteration
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
def job(model):
tinit = time()
model = deepcopy(model)
# overrides default parameters with the parameters being explored at this iteration
model.set_params(**q_params)
model.aggregation_fit(predictions, training)
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
ttime = time()-tinit
return {
'model': model,
'cls-params':cls_params,
'q-params': q_params,
'params': params,
'score': score,
'ttime':ttime
}
out, status = self._error_handler(job, args)
if status == Status.SUCCESS:
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]')
elif status == Status.INVALID:
self._sout(f'the combination of hyperparameters {params} is invalid')
elif status == Status.
return {'model': model, 'cls-params':cls_params, 'q-params': q_params, 'params': params, 'score': score}
@check_status
def _prepare_model(self, args):
params, training = args
model = deepcopy(self.model)
@ -309,6 +325,24 @@ class GridSearchQ(BaseQuantifier):
raise ValueError('best_model called before fit')
def _error_handler(self, func, *args, **kwargs):
try:
with timeout(self.timeout):
output = func(*args, **kwargs)
return output, Status.SUCCESS
except TimeoutError:
return None, Status.TIMEOUT
except ValueError:
return None, Status.INVALID
except Exception:
return None, Status.ERROR
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
"""
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_

View File

@ -10,6 +10,8 @@ import quapy as qp
import numpy as np
from joblib import Parallel, delayed
from time import time
import signal
def _get_parallel_slices(n_tasks, n_jobs):
@ -257,3 +259,35 @@ class EarlyStop:
if self.patience <= 0:
self.STOP = True
@contextlib.contextmanager
def timeout(seconds):
"""
Opens a context that will launch an exception if not closed after a given number of seconds
>>> def func(start_msg, end_msg):
>>> print(start_msg)
>>> sleep(2)
>>> print(end_msg)
>>>
>>> with timeout(1):
>>> func('begin function', 'end function')
>>> Out[]
>>> begin function
>>> TimeoutError
:param seconds: number of seconds, set to <=0 to ignore the timer
"""
if seconds > 0:
def handler(signum, frame):
raise TimeoutError()
signal.signal(signal.SIGALRM, handler)
signal.alarm(seconds)
yield
if seconds > 0:
signal.alarm(0)