From 6663b4c91d01c0edb5735da637cb2416433d031c Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Mon, 20 Nov 2023 22:05:26 +0100 Subject: [PATCH] context timeout --- quapy/model_selection.py | 54 ++++++++++++++++++++++++++++++++-------- quapy/util.py | 34 +++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/quapy/model_selection.py b/quapy/model_selection.py index 9bd0985..6637d62 100644 --- a/quapy/model_selection.py +++ b/quapy/model_selection.py @@ -13,6 +13,7 @@ from quapy import evaluation from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol from quapy.data.base import LabelledCollection from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier +from quapy.util import timeout from time import time @@ -127,7 +128,6 @@ class GridSearchQ(BaseQuantifier): raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n' f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}') - @check_status def _prepare_classifier(self, args): cls_params = args['cls-params'] training = args['training'] @@ -136,9 +136,8 @@ class GridSearchQ(BaseQuantifier): predictions = model.classifier_fit_predict(training) return {'model': model, 'predictions': predictions, 'cls-params': cls_params} - @check_status def _prepare_aggregation(self, args): - # (partial_setup, q_params), training = args + model = args['model'] predictions = args['predictions'] cls_params = args['cls-params'] @@ -147,15 +146,32 @@ class GridSearchQ(BaseQuantifier): params = {**cls_params, **q_params} - model = deepcopy(model) - # overrides default parameters with the parameters being explored at this iteration - model.set_params(**q_params) - model.aggregation_fit(predictions, training) - score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error) + def job(model): + tinit = time() + model = deepcopy(model) + # overrides default parameters with the parameters being explored at this iteration + model.set_params(**q_params) + model.aggregation_fit(predictions, training) + score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error) + ttime = time()-tinit + + return { + 'model': model, + 'cls-params':cls_params, + 'q-params': q_params, + 'params': params, + 'score': score, + 'ttime':ttime + } + + out, status = self._error_handler(job, args) + if status == Status.SUCCESS: + self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {out["score"]:.5f} [took {out["time"]:.4f}s]') + elif status == Status.INVALID: + self._sout(f'the combination of hyperparameters {params} is invalid') + elif status == Status. - return {'model': model, 'cls-params':cls_params, 'q-params': q_params, 'params': params, 'score': score} - @check_status def _prepare_model(self, args): params, training = args model = deepcopy(self.model) @@ -309,6 +325,24 @@ class GridSearchQ(BaseQuantifier): raise ValueError('best_model called before fit') + def _error_handler(self, func, *args, **kwargs): + + try: + with timeout(self.timeout): + output = func(*args, **kwargs) + return output, Status.SUCCESS + + except TimeoutError: + return None, Status.TIMEOUT + + except ValueError: + return None, Status.INVALID + + except Exception: + return None, Status.ERROR + + + def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0): """ Akin to `scikit-learn's cross_val_predict `_ diff --git a/quapy/util.py b/quapy/util.py index 51c2a41..de5c131 100644 --- a/quapy/util.py +++ b/quapy/util.py @@ -10,6 +10,8 @@ import quapy as qp import numpy as np from joblib import Parallel, delayed +from time import time +import signal def _get_parallel_slices(n_tasks, n_jobs): @@ -257,3 +259,35 @@ class EarlyStop: if self.patience <= 0: self.STOP = True + +@contextlib.contextmanager +def timeout(seconds): + """ + Opens a context that will launch an exception if not closed after a given number of seconds + + >>> def func(start_msg, end_msg): + >>> print(start_msg) + >>> sleep(2) + >>> print(end_msg) + >>> + >>> with timeout(1): + >>> func('begin function', 'end function') + >>> Out[] + >>> begin function + >>> TimeoutError + + + :param seconds: number of seconds, set to <=0 to ignore the timer + """ + if seconds > 0: + def handler(signum, frame): + raise TimeoutError() + + signal.signal(signal.SIGALRM, handler) + signal.alarm(seconds) + + yield + + if seconds > 0: + signal.alarm(0) +