1
0
Fork 0

context timeout

This commit is contained in:
Alejandro Moreo Fernandez 2023-11-20 22:05:26 +01:00
parent f785a4eeef
commit 6663b4c91d
2 changed files with 78 additions and 10 deletions

View File

@ -13,6 +13,7 @@ from quapy import evaluation
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
from quapy.data.base import LabelledCollection from quapy.data.base import LabelledCollection
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
from quapy.util import timeout
from time import time from time import time
@ -127,7 +128,6 @@ class GridSearchQ(BaseQuantifier):
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n' raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}') f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
@check_status
def _prepare_classifier(self, args): def _prepare_classifier(self, args):
cls_params = args['cls-params'] cls_params = args['cls-params']
training = args['training'] training = args['training']
@ -136,9 +136,8 @@ class GridSearchQ(BaseQuantifier):
predictions = model.classifier_fit_predict(training) predictions = model.classifier_fit_predict(training)
return {'model': model, 'predictions': predictions, 'cls-params': cls_params} return {'model': model, 'predictions': predictions, 'cls-params': cls_params}
@check_status
def _prepare_aggregation(self, args): def _prepare_aggregation(self, args):
# (partial_setup, q_params), training = args
model = args['model'] model = args['model']
predictions = args['predictions'] predictions = args['predictions']
cls_params = args['cls-params'] cls_params = args['cls-params']
@ -147,15 +146,32 @@ class GridSearchQ(BaseQuantifier):
params = {**cls_params, **q_params} params = {**cls_params, **q_params}
model = deepcopy(model) def job(model):
# overrides default parameters with the parameters being explored at this iteration tinit = time()
model.set_params(**q_params) model = deepcopy(model)
model.aggregation_fit(predictions, training) # overrides default parameters with the parameters being explored at this iteration
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error) model.set_params(**q_params)
model.aggregation_fit(predictions, training)
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
ttime = time()-tinit
return {
'model': model,
'cls-params':cls_params,
'q-params': q_params,
'params': params,
'score': score,
'ttime':ttime
}
out, status = self._error_handler(job, args)
if status == Status.SUCCESS:
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]')
elif status == Status.INVALID:
self._sout(f'the combination of hyperparameters {params} is invalid')
elif status == Status.
return {'model': model, 'cls-params':cls_params, 'q-params': q_params, 'params': params, 'score': score}
@check_status
def _prepare_model(self, args): def _prepare_model(self, args):
params, training = args params, training = args
model = deepcopy(self.model) model = deepcopy(self.model)
@ -309,6 +325,24 @@ class GridSearchQ(BaseQuantifier):
raise ValueError('best_model called before fit') raise ValueError('best_model called before fit')
def _error_handler(self, func, *args, **kwargs):
try:
with timeout(self.timeout):
output = func(*args, **kwargs)
return output, Status.SUCCESS
except TimeoutError:
return None, Status.TIMEOUT
except ValueError:
return None, Status.INVALID
except Exception:
return None, Status.ERROR
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0): def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
""" """
Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_ Akin to `scikit-learn's cross_val_predict <https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html>`_

View File

@ -10,6 +10,8 @@ import quapy as qp
import numpy as np import numpy as np
from joblib import Parallel, delayed from joblib import Parallel, delayed
from time import time
import signal
def _get_parallel_slices(n_tasks, n_jobs): def _get_parallel_slices(n_tasks, n_jobs):
@ -257,3 +259,35 @@ class EarlyStop:
if self.patience <= 0: if self.patience <= 0:
self.STOP = True self.STOP = True
@contextlib.contextmanager
def timeout(seconds):
"""
Opens a context that will launch an exception if not closed after a given number of seconds
>>> def func(start_msg, end_msg):
>>> print(start_msg)
>>> sleep(2)
>>> print(end_msg)
>>>
>>> with timeout(1):
>>> func('begin function', 'end function')
>>> Out[]
>>> begin function
>>> TimeoutError
:param seconds: number of seconds, set to <=0 to ignore the timer
"""
if seconds > 0:
def handler(signum, frame):
raise TimeoutError()
signal.signal(signal.SIGALRM, handler)
signal.alarm(seconds)
yield
if seconds > 0:
signal.alarm(0)