forked from moreo/QuaPy
Merging aggregativefit into devel. The aggregative fit was created to generate a two-level quantification fit mirroring the inference phase. I.e., the fit now amounts to fitting a classifier plus fitting an aggregation function (just like the fit procedure, that amounts to invoking a classifier, and invoking an aggregation function). This is useful to nestle training phaes in model selection.
This commit is contained in:
commit
efe385318f
|
@ -130,3 +130,32 @@ dmypy.json
|
|||
.pyre/
|
||||
|
||||
*__pycache__*
|
||||
*.pdf
|
||||
*.zip
|
||||
*.png
|
||||
*.csv
|
||||
*.pkl
|
||||
*.dataframe
|
||||
|
||||
|
||||
# other projects
|
||||
LeQua2022
|
||||
MultiLabel
|
||||
NewMethods
|
||||
Ordinal
|
||||
Retrieval
|
||||
eDiscovery
|
||||
poster-cikm
|
||||
slides-cikm
|
||||
slides-short-cikm
|
||||
quick_experiment
|
||||
svm_perf_quantification/svm_struct
|
||||
svm_perf_quantification/svm_light
|
||||
TweetSentQuant
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
*.png
|
||||
|
|
|
@ -2,7 +2,7 @@ import quapy as qp
|
|||
from quapy.data import LabelledCollection
|
||||
from quapy.method.base import BinaryQuantifier
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.method.aggregative import AggregativeProbabilisticQuantifier
|
||||
from quapy.method.aggregative import AggregativeSoftQuantifier
|
||||
from quapy.protocol import APP
|
||||
import numpy as np
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
|
@ -15,7 +15,7 @@ from sklearn.linear_model import LogisticRegression
|
|||
# internal hyperparameter (let say, alpha) which is the decision threshold. Let's also assume the quantifier
|
||||
# is binary, for simplicity.
|
||||
|
||||
class MyQuantifier(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||
class MyQuantifier(AggregativeSoftQuantifier, BinaryQuantifier):
|
||||
def __init__(self, classifier, alpha=0.5):
|
||||
self.alpha = alpha
|
||||
# aggregative quantifiers have an internal self.classifier attribute
|
||||
|
|
|
@ -1,19 +1,26 @@
|
|||
import quapy as qp
|
||||
from quapy.protocol import APP
|
||||
from method.kdey import KDEyML
|
||||
from quapy.method.non_aggregative import DMx
|
||||
from quapy.protocol import APP, UPP
|
||||
from quapy.method.aggregative import DMy
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from examples.comparing_gridsearch import OLD_GridSearchQ
|
||||
import numpy as np
|
||||
from time import time
|
||||
|
||||
"""
|
||||
In this example, we show how to perform model selection on a DistributionMatching quantifier.
|
||||
"""
|
||||
|
||||
model = DMy(LogisticRegression())
|
||||
model = KDEyML(LogisticRegression())
|
||||
|
||||
qp.environ['SAMPLE_SIZE'] = 100
|
||||
qp.environ['N_JOBS'] = -1
|
||||
|
||||
training, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_test
|
||||
# training, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_test
|
||||
training, test = qp.datasets.fetch_UCIMulticlassDataset('letter').train_test
|
||||
|
||||
with qp.util.temp_seed(0):
|
||||
|
||||
# The model will be returned by the fit method of GridSearchQ.
|
||||
# Every combination of hyper-parameters will be evaluated by confronting the
|
||||
|
@ -23,7 +30,7 @@ training, test = qp.datasets.fetch_reviews('imdb', tfidf=True, min_df=5).train_t
|
|||
# values in the entire range of values from a grid (e.g., [0, 0.1, 0.2, ..., 1]).
|
||||
# We devote 30% of the dataset for this exploration.
|
||||
training, validation = training.split_stratified(train_prop=0.7)
|
||||
protocol = APP(validation)
|
||||
protocol = UPP(validation)
|
||||
|
||||
# We will explore a classification-dependent hyper-parameter (e.g., the 'C'
|
||||
# hyper-parameter of LogisticRegression) and a quantification-dependent hyper-parameter
|
||||
|
@ -33,25 +40,32 @@ protocol = APP(validation)
|
|||
# classifier.
|
||||
param_grid = {
|
||||
'classifier__C': np.logspace(-3,3,7),
|
||||
'nbins': [8, 16, 32, 64],
|
||||
'classifier__class_weight': ['balanced', None],
|
||||
'bandwidth': np.linspace(0.01, 0.2, 20),
|
||||
}
|
||||
|
||||
tinit = time()
|
||||
|
||||
# model = OLD_GridSearchQ(
|
||||
model = qp.model_selection.GridSearchQ(
|
||||
model=model,
|
||||
param_grid=param_grid,
|
||||
protocol=protocol,
|
||||
error='mae', # the error to optimize is the MAE (a quantification-oriented loss)
|
||||
refit=True, # retrain on the whole labelled set once done
|
||||
refit=False, # retrain on the whole labelled set once done
|
||||
# raise_errors=False,
|
||||
verbose=True # show information as the process goes on
|
||||
).fit(training)
|
||||
|
||||
tend = time()
|
||||
|
||||
print(f'model selection ended: best hyper-parameters={model.best_params_}')
|
||||
model = model.best_model_
|
||||
|
||||
# evaluation in terms of MAE
|
||||
# we use the same evaluation protocol (APP) on the test set
|
||||
mae_score = qp.evaluation.evaluate(model, protocol=APP(test), error_metric='mae')
|
||||
mae_score = qp.evaluation.evaluate(model, protocol=UPP(test), error_metric='mae')
|
||||
|
||||
print(f'MAE={mae_score:.5f}')
|
||||
|
||||
print(f'model selection took {tend-tinit:.1f}s')
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ def run(experiment):
|
|||
timeout=60*60,
|
||||
verbose=True
|
||||
)
|
||||
model_selection.fit(data.training)
|
||||
model_selection.fit(train)
|
||||
model = model_selection.best_model()
|
||||
best_params = model_selection.best_params_
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
Change Log 0.1.8
|
||||
----------------
|
||||
|
||||
- Fixed ThresholdOptimization methods (X, T50, MAX, MS and MS2). Thanks to Tobias Schumacher and colleagues for pointing
|
||||
this out in Appendix A of "Schumacher, T., Strohmaier, M., & Lemmerich, F. (2021). A comparative evaluation of
|
||||
quantification methods. arXiv:2103.03223v3 [cs.LG]"
|
||||
- Added HDx and DistributionMatchingX to non-aggregative quantifiers (see also the new example "comparing_HDy_HDx.py")
|
||||
- New UCI multiclass datasets added (thanks to Pablo González). The 5 UCI multiclass datasets are those corresponding
|
||||
to the following criteria:
|
||||
|
|
|
@ -24,7 +24,8 @@ class RecalibratedProbabilisticClassifier:
|
|||
class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabilisticClassifier):
|
||||
"""
|
||||
Applies a (re)calibration method from `abstention.calibration`, as defined in
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_:
|
||||
`Alexandari et al. paper <http://proceedings.mlr.press/v119/alexandari20a.html>`_.
|
||||
|
||||
|
||||
:param classifier: a scikit-learn probabilistic classifier
|
||||
:param calibrator: the calibration object (an instance of abstention.calibration.CalibratorFactory)
|
||||
|
@ -59,7 +60,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi
|
|||
elif isinstance(k, float):
|
||||
if not (0 < k < 1):
|
||||
raise ValueError('wrong value for val_split: the proportion of validation documents must be in (0,1)')
|
||||
return self.fit_cv(X, y)
|
||||
return self.fit_tr_val(X, y)
|
||||
|
||||
def fit_cv(self, X, y):
|
||||
"""
|
||||
|
@ -94,7 +95,7 @@ class RecalibratedProbabilisticClassifierBase(BaseEstimator, RecalibratedProbabi
|
|||
self.classifier.fit(Xtr, ytr)
|
||||
posteriors = self.classifier.predict_proba(Xva)
|
||||
nclasses = len(np.unique(yva))
|
||||
self.calibrator = self.calibrator(posteriors, np.eye(nclasses)[yva], posterior_supplied=True)
|
||||
self.calibration_function = self.calibrator(posteriors, np.eye(nclasses)[yva], posterior_supplied=True)
|
||||
return self
|
||||
|
||||
def predict(self, X):
|
||||
|
|
|
@ -66,6 +66,24 @@ def prevalence_from_probabilities(posteriors, binarize: bool = False):
|
|||
return prevalences
|
||||
|
||||
|
||||
def as_binary_prevalence(positive_prevalence: Union[float, np.ndarray], clip_if_necessary=False):
|
||||
"""
|
||||
Helper that, given a float representing the prevalence for the positive class, returns a np.ndarray of two
|
||||
values representing a binary distribution.
|
||||
|
||||
:param positive_prevalence: prevalence for the positive class
|
||||
:param clip_if_necessary: if True, clips the value in [0,1] in order to guarantee the resulting distribution
|
||||
is valid. If False, it then checks that the value is in the valid range, and raises an error if not.
|
||||
:return: np.ndarray of shape `(2,)`
|
||||
"""
|
||||
if clip_if_necessary:
|
||||
positive_prevalence = np.clip(positive_prevalence, 0, 1)
|
||||
else:
|
||||
assert 0 <= positive_prevalence <= 1, 'the value provided is not a valid prevalence for the positive class'
|
||||
return np.asarray([1-positive_prevalence, positive_prevalence]).T
|
||||
|
||||
|
||||
|
||||
def HellingerDistance(P, Q) -> float:
|
||||
"""
|
||||
Computes the Hellingher Distance (HD) between (discretized) distributions `P` and `Q`.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,214 @@
|
|||
from typing import Union
|
||||
import numpy as np
|
||||
from sklearn.base import BaseEstimator
|
||||
from sklearn.neighbors import KernelDensity
|
||||
|
||||
import quapy as qp
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.aggregative import AggregativeSoftQuantifier
|
||||
import quapy.functional as F
|
||||
|
||||
from sklearn.metrics.pairwise import rbf_kernel
|
||||
|
||||
|
||||
class KDEBase:
|
||||
|
||||
BANDWIDTH_METHOD = ['scott', 'silverman']
|
||||
|
||||
@classmethod
|
||||
def _check_bandwidth(cls, bandwidth):
|
||||
assert bandwidth in KDEBase.BANDWIDTH_METHOD or isinstance(bandwidth, float), \
|
||||
f'invalid bandwidth, valid ones are {KDEBase.BANDWIDTH_METHOD} or float values'
|
||||
if isinstance(bandwidth, float):
|
||||
assert 0 < bandwidth < 1, "the bandwith for KDEy should be in (0,1), since this method models the unit simplex"
|
||||
|
||||
def get_kde_function(self, X, bandwidth):
|
||||
return KernelDensity(bandwidth=bandwidth).fit(X)
|
||||
|
||||
def pdf(self, kde, X):
|
||||
return np.exp(kde.score_samples(X))
|
||||
|
||||
def get_mixture_components(self, X, y, n_classes, bandwidth):
|
||||
return [self.get_kde_function(X[y == cat], bandwidth) for cat in range(n_classes)]
|
||||
|
||||
|
||||
|
||||
class KDEyML(AggregativeSoftQuantifier, KDEBase):
|
||||
|
||||
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
||||
self._check_bandwidth(bandwidth)
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.bandwidth = bandwidth
|
||||
self.n_jobs = n_jobs
|
||||
self.random_state=random_state
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||
return self
|
||||
|
||||
def aggregate(self, posteriors: np.ndarray):
|
||||
"""
|
||||
Searches for the mixture model parameter (the sought prevalence values) that maximizes the likelihood
|
||||
of the data (i.e., that minimizes the negative log-likelihood)
|
||||
|
||||
:param posteriors: instances in the sample converted into posterior probabilities
|
||||
:return: a vector of class prevalence estimates
|
||||
"""
|
||||
np.random.RandomState(self.random_state)
|
||||
epsilon = 1e-10
|
||||
n_classes = len(self.mix_densities)
|
||||
test_densities = [self.pdf(kde_i, posteriors) for kde_i in self.mix_densities]
|
||||
|
||||
def neg_loglikelihood(prev):
|
||||
test_mixture_likelihood = sum(prev_i * dens_i for prev_i, dens_i in zip (prev, test_densities))
|
||||
test_loglikelihood = np.log(test_mixture_likelihood + epsilon)
|
||||
return -np.sum(test_loglikelihood)
|
||||
|
||||
return F.optim_minimize(neg_loglikelihood, n_classes)
|
||||
|
||||
|
||||
class KDEyHD(AggregativeSoftQuantifier, KDEBase):
|
||||
|
||||
def __init__(self, classifier: BaseEstimator, val_split=10, divergence: str='HD',
|
||||
bandwidth=0.1, n_jobs=None, random_state=0, montecarlo_trials=10000):
|
||||
|
||||
self._check_bandwidth(bandwidth)
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.divergence = divergence
|
||||
self.bandwidth = bandwidth
|
||||
self.n_jobs = n_jobs
|
||||
self.random_state=random_state
|
||||
self.montecarlo_trials = montecarlo_trials
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
self.mix_densities = self.get_mixture_components(*classif_predictions.Xy, data.n_classes, self.bandwidth)
|
||||
|
||||
N = self.montecarlo_trials
|
||||
rs = self.random_state
|
||||
n = data.n_classes
|
||||
self.reference_samples = np.vstack([kde_i.sample(N//n, random_state=rs) for kde_i in self.mix_densities])
|
||||
self.reference_classwise_densities = np.asarray([self.pdf(kde_j, self.reference_samples) for kde_j in self.mix_densities])
|
||||
self.reference_density = np.mean(self.reference_classwise_densities, axis=0) # equiv. to (uniform @ self.reference_classwise_densities)
|
||||
|
||||
return self
|
||||
|
||||
def aggregate(self, posteriors: np.ndarray):
|
||||
# we retain all n*N examples (sampled from a mixture with uniform parameter), and then
|
||||
# apply importance sampling (IS). In this version we compute D(p_alpha||q) with IS
|
||||
n_classes = len(self.mix_densities)
|
||||
|
||||
test_kde = self.get_kde_function(posteriors, self.bandwidth)
|
||||
test_densities = self.pdf(test_kde, self.reference_samples)
|
||||
|
||||
def f_squared_hellinger(u):
|
||||
return (np.sqrt(u)-1)**2
|
||||
|
||||
# todo: this will fail when self.divergence is a callable, and is not the right place to do it anyway
|
||||
if self.divergence.lower() == 'hd':
|
||||
f = f_squared_hellinger
|
||||
else:
|
||||
raise ValueError('only squared HD is currently implemented')
|
||||
|
||||
epsilon = 1e-10
|
||||
qs = test_densities + epsilon
|
||||
rs = self.reference_density + epsilon
|
||||
iw = qs/rs #importance weights
|
||||
p_class = self.reference_classwise_densities + epsilon
|
||||
fracs = p_class/qs
|
||||
|
||||
def divergence(prev):
|
||||
# ps / qs = (prev @ p_class) / qs = prev @ (p_class / qs) = prev @ fracs
|
||||
ps_div_qs = prev @ fracs
|
||||
return np.mean( f(ps_div_qs) * iw )
|
||||
|
||||
return F.optim_minimize(divergence, n_classes)
|
||||
|
||||
|
||||
class KDEyCS(AggregativeSoftQuantifier):
|
||||
|
||||
def __init__(self, classifier: BaseEstimator, val_split=10, bandwidth=0.1, n_jobs=None, random_state=0):
|
||||
KDEBase._check_bandwidth(bandwidth)
|
||||
self.classifier = classifier
|
||||
self.val_split = val_split
|
||||
self.bandwidth = bandwidth
|
||||
self.n_jobs = n_jobs
|
||||
self.random_state=random_state
|
||||
|
||||
def gram_matrix_mix_sum(self, X, Y=None):
|
||||
# this adapts the output of the rbf_kernel function (pairwise evaluations of Gaussian kernels k(x,y))
|
||||
# to contain pairwise evaluations of N(x|mu,Sigma1+Sigma2) with mu=y and Sigma1 and Sigma2 are
|
||||
# two "scalar matrices" (h^2)*I each, so Sigma1+Sigma2 has scalar 2(h^2) (h is the bandwidth)
|
||||
h = self.bandwidth
|
||||
variance = 2 * (h**2)
|
||||
nD = X.shape[1]
|
||||
gamma = 1/(2*variance)
|
||||
norm_factor = 1/np.sqrt(((2*np.pi)**nD) * (variance**(nD)))
|
||||
gram = norm_factor * rbf_kernel(X, Y, gamma=gamma)
|
||||
return gram.sum()
|
||||
|
||||
def aggregation_fit(self, classif_predictions: LabelledCollection, data: LabelledCollection):
|
||||
|
||||
P, y = classif_predictions.Xy
|
||||
n = data.n_classes
|
||||
|
||||
assert all(sorted(np.unique(y)) == np.arange(n)), \
|
||||
'label name gaps not allowed in current implementation'
|
||||
|
||||
|
||||
# counts_inv keeps track of the relative weight of each datapoint within its class
|
||||
# (i.e., the weight in its KDE model)
|
||||
counts_inv = 1 / (data.counts())
|
||||
|
||||
# tr_tr_sums corresponds to symbol \overline{B} in the paper
|
||||
tr_tr_sums = np.zeros(shape=(n,n), dtype=float)
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
if i > j:
|
||||
tr_tr_sums[i,j] = tr_tr_sums[j,i]
|
||||
else:
|
||||
block = self.gram_matrix_mix_sum(P[y == i], P[y == j] if i!=j else None)
|
||||
tr_tr_sums[i, j] = block
|
||||
|
||||
# keep track of these data structures for the test phase
|
||||
self.Ptr = P
|
||||
self.ytr = y
|
||||
self.tr_tr_sums = tr_tr_sums
|
||||
self.counts_inv = counts_inv
|
||||
|
||||
return self
|
||||
|
||||
|
||||
def aggregate(self, posteriors: np.ndarray):
|
||||
Ptr = self.Ptr
|
||||
Pte = posteriors
|
||||
y = self.ytr
|
||||
tr_tr_sums = self.tr_tr_sums
|
||||
|
||||
M, nD = Pte.shape
|
||||
Minv = (1/M) # t in the paper
|
||||
n = Ptr.shape[1]
|
||||
|
||||
|
||||
# becomes a constant that does not affect the optimization, no need to compute it
|
||||
# partC = 0.5*np.log(self.gram_matrix_mix_sum(Pte) * Kinv * Kinv)
|
||||
|
||||
# tr_te_sums corresponds to \overline{a}*(1/Li)*(1/M) in the paper (note the constants
|
||||
# are already aggregated to tr_te_sums, so these multiplications are not carried out
|
||||
# at each iteration of the optimization phase)
|
||||
tr_te_sums = np.zeros(shape=n, dtype=float)
|
||||
for i in range(n):
|
||||
tr_te_sums[i] = self.gram_matrix_mix_sum(Ptr[y==i], Pte)
|
||||
|
||||
def divergence(alpha):
|
||||
# called \overline{r} in the paper
|
||||
alpha_ratio = alpha * self.counts_inv
|
||||
|
||||
# recal that tr_te_sums already accounts for the constant terms (1/Li)*(1/M)
|
||||
partA = -np.log((alpha_ratio @ tr_te_sums) * Minv)
|
||||
partB = 0.5 * np.log(alpha_ratio @ tr_tr_sums @ alpha_ratio)
|
||||
return partA + partB #+ partC
|
||||
|
||||
return F.optim_minimize(divergence, n)
|
||||
|
|
@ -12,7 +12,7 @@ from quapy import functional as F
|
|||
from quapy.data import LabelledCollection
|
||||
from quapy.model_selection import GridSearchQ
|
||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ
|
||||
from quapy.method.aggregative import CC, ACC, PACC, HDy, EMQ, AggregativeQuantifier
|
||||
|
||||
try:
|
||||
from . import neural
|
||||
|
@ -26,6 +26,65 @@ else:
|
|||
QuaNet = "QuaNet is not available due to missing torch package"
|
||||
|
||||
|
||||
class MedianEstimator2(BinaryQuantifier):
|
||||
"""
|
||||
This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
|
||||
estimation returned by differently (hyper)parameterized base quantifiers.
|
||||
The median of unit-vectors is only guaranteed to be a unit-vector for n=2 dimensions,
|
||||
i.e., in cases of binary quantification.
|
||||
|
||||
:param base_quantifier: the base, binary quantifier
|
||||
:param random_state: a seed to be set before fitting any base quantifier (default None)
|
||||
:param param_grid: the grid or parameters towards which the median will be computed
|
||||
:param n_jobs: number of parllel workes
|
||||
"""
|
||||
def __init__(self, base_quantifier: BinaryQuantifier, param_grid: dict, random_state=None, n_jobs=None):
|
||||
self.base_quantifier = base_quantifier
|
||||
self.param_grid = param_grid
|
||||
self.random_state = random_state
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
|
||||
def get_params(self, deep=True):
|
||||
return self.base_quantifier.get_params(deep)
|
||||
|
||||
def set_params(self, **params):
|
||||
self.base_quantifier.set_params(**params)
|
||||
|
||||
def _delayed_fit(self, args):
|
||||
with qp.util.temp_seed(self.random_state):
|
||||
params, training = args
|
||||
model = deepcopy(self.base_quantifier)
|
||||
model.set_params(**params)
|
||||
model.fit(training)
|
||||
return model
|
||||
|
||||
def fit(self, training: LabelledCollection):
|
||||
self._check_binary(training, self.__class__.__name__)
|
||||
|
||||
configs = qp.model_selection.expand_grid(self.param_grid)
|
||||
self.models = qp.util.parallel(
|
||||
self._delayed_fit,
|
||||
((params, training) for params in configs),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
return self
|
||||
|
||||
def _delayed_predict(self, args):
|
||||
model, instances = args
|
||||
return model.quantify(instances)
|
||||
|
||||
def quantify(self, instances):
|
||||
prev_preds = qp.util.parallel(
|
||||
self._delayed_predict,
|
||||
((model, instances) for model in self.models),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
prev_preds = np.asarray(prev_preds)
|
||||
return np.median(prev_preds, axis=0)
|
||||
|
||||
|
||||
class MedianEstimator(BinaryQuantifier):
|
||||
"""
|
||||
This method is a meta-quantifier that returns, as the estimated class prevalence values, the median of the
|
||||
|
@ -58,16 +117,63 @@ class MedianEstimator(BinaryQuantifier):
|
|||
model.fit(training)
|
||||
return model
|
||||
|
||||
def _delayed_fit_classifier(self, args):
|
||||
with qp.util.temp_seed(self.random_state):
|
||||
print('enter job')
|
||||
cls_params, training = args
|
||||
model = deepcopy(self.base_quantifier)
|
||||
model.set_params(**cls_params)
|
||||
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
|
||||
print('exit job')
|
||||
return (model, predictions)
|
||||
|
||||
def _delayed_fit_aggregation(self, args):
|
||||
with qp.util.temp_seed(self.random_state):
|
||||
print('\tenter job')
|
||||
((model, predictions), q_params), training = args
|
||||
model = deepcopy(model)
|
||||
model.set_params(**q_params)
|
||||
model.aggregation_fit(predictions, training)
|
||||
print('\texit job')
|
||||
return model
|
||||
|
||||
|
||||
def fit(self, training: LabelledCollection):
|
||||
self._check_binary(training, self.__class__.__name__)
|
||||
params_keys = list(self.param_grid.keys())
|
||||
params_values = list(self.param_grid.values())
|
||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
||||
|
||||
if isinstance(self.base_quantifier, AggregativeQuantifier):
|
||||
cls_configs, q_configs = qp.model_selection.group_params(self.param_grid)
|
||||
|
||||
if len(cls_configs) > 1:
|
||||
models_preds = qp.util.parallel(
|
||||
self._delayed_fit_classifier,
|
||||
((params, training) for params in cls_configs),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs,
|
||||
asarray=False
|
||||
)
|
||||
else:
|
||||
print('only 1')
|
||||
model = self.base_quantifier
|
||||
model.set_params(**cls_configs[0])
|
||||
predictions = model.classifier_fit_predict(training, predict_on=model.val_split)
|
||||
models_preds = [(model, predictions)]
|
||||
|
||||
self.models = qp.util.parallel(
|
||||
self._delayed_fit_aggregation,
|
||||
((setup, training) for setup in itertools.product(models_preds, q_configs)),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs,
|
||||
asarray=False
|
||||
)
|
||||
else:
|
||||
configs = qp.model_selection.expand_grid(self.param_grid)
|
||||
self.models = qp.util.parallel(
|
||||
self._delayed_fit,
|
||||
((params, training) for params in hyper),
|
||||
((params, training) for params in configs),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
n_jobs=self.n_jobs,
|
||||
asarray=False
|
||||
)
|
||||
return self
|
||||
|
||||
|
@ -80,13 +186,13 @@ class MedianEstimator(BinaryQuantifier):
|
|||
self._delayed_predict,
|
||||
((model, instances) for model in self.models),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
n_jobs=self.n_jobs,
|
||||
asarray=False
|
||||
)
|
||||
prev_preds = np.asarray(prev_preds)
|
||||
return np.median(prev_preds, axis=0)
|
||||
|
||||
|
||||
|
||||
class Ensemble(BaseQuantifier):
|
||||
VALID_POLICIES = {'ave', 'ptr', 'ds'} | qp.error.QUANTIFICATION_ERROR_NAMES
|
||||
|
||||
|
|
|
@ -194,7 +194,7 @@ class QuaNetTrainer(BaseQuantifier):
|
|||
label_predictions = np.argmax(posteriors, axis=-1)
|
||||
prevs_estim = []
|
||||
for quantifier in self.quantifiers.values():
|
||||
predictions = posteriors if isinstance(quantifier, AggregativeProbabilisticQuantifier) else label_predictions
|
||||
predictions = posteriors if isinstance(quantifier, AggregativeSoftQuantifier) else label_predictions
|
||||
prevs_estim.extend(quantifier.aggregate(predictions))
|
||||
|
||||
# there is no real need for adding static estims like the TPR or FPR from training since those are constant
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Union, Callable
|
||||
import numpy as np
|
||||
|
||||
from functional import get_divergence
|
||||
from quapy.functional import get_divergence
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.method.base import BaseQuantifier, BinaryQuantifier
|
||||
import quapy.functional as F
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import itertools
|
||||
import signal
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Union, Callable
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
from sklearn import clone
|
||||
|
@ -10,10 +12,37 @@ import quapy as qp
|
|||
from quapy import evaluation
|
||||
from quapy.protocol import AbstractProtocol, OnLabelledCollectionProtocol
|
||||
from quapy.data.base import LabelledCollection
|
||||
from quapy.method.aggregative import BaseQuantifier
|
||||
from quapy.method.aggregative import BaseQuantifier, AggregativeQuantifier
|
||||
from quapy.util import timeout
|
||||
from time import time
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
SUCCESS = 1
|
||||
TIMEOUT = 2
|
||||
INVALID = 3
|
||||
ERROR = 4
|
||||
|
||||
|
||||
class ConfigStatus:
|
||||
def __init__(self, params, status, msg=''):
|
||||
self.params = params
|
||||
self.status = status
|
||||
self.msg = msg
|
||||
|
||||
def __str__(self):
|
||||
return f':params:{self.params} :status:{self.status} ' + self.msg
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
def success(self):
|
||||
return self.status == Status.SUCCESS
|
||||
|
||||
def failed(self):
|
||||
return self.status != Status.SUCCESS
|
||||
|
||||
|
||||
class GridSearchQ(BaseQuantifier):
|
||||
"""Grid Search optimization targeting a quantification-oriented metric.
|
||||
|
||||
|
@ -26,11 +55,14 @@ class GridSearchQ(BaseQuantifier):
|
|||
:param protocol: a sample generation protocol, an instance of :class:`quapy.protocol.AbstractProtocol`
|
||||
:param error: an error function (callable) or a string indicating the name of an error function (valid ones
|
||||
are those in :class:`quapy.error.QUANTIFICATION_ERROR`
|
||||
:param refit: whether or not to refit the model on the whole labelled collection (training+validation) with
|
||||
:param refit: whether to refit the model on the whole labelled collection (training+validation) with
|
||||
the best chosen hyperparameter combination. Ignored if protocol='gen'
|
||||
:param timeout: establishes a timer (in seconds) for each of the hyperparameters configurations being tested.
|
||||
Whenever a run takes longer than this timer, that configuration will be ignored. If all configurations end up
|
||||
being ignored, a TimeoutError exception is raised. If -1 (default) then no time bound is set.
|
||||
:param raise_errors: boolean, if True then raises an exception when a param combination yields any error, if
|
||||
otherwise is False (default), then the combination is marked with an error status, but the process goes on.
|
||||
However, if no configuration yields a valid model, then a ValueError exception will be raised.
|
||||
:param verbose: set to True to get information through the stdout
|
||||
"""
|
||||
|
||||
|
@ -42,6 +74,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
refit=True,
|
||||
timeout=-1,
|
||||
n_jobs=None,
|
||||
raise_errors=False,
|
||||
verbose=False):
|
||||
|
||||
self.model = model
|
||||
|
@ -50,6 +83,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
self.refit = refit
|
||||
self.timeout = timeout
|
||||
self.n_jobs = qp._get_njobs(n_jobs)
|
||||
self.raise_errors = raise_errors
|
||||
self.verbose = verbose
|
||||
self.__check_error(error)
|
||||
assert isinstance(protocol, AbstractProtocol), 'unknown protocol'
|
||||
|
@ -69,6 +103,98 @@ class GridSearchQ(BaseQuantifier):
|
|||
raise ValueError(f'unexpected error type; must either be a callable function or a str representing\n'
|
||||
f'the name of an error function in {qp.error.QUANTIFICATION_ERROR_NAMES}')
|
||||
|
||||
def _prepare_classifier(self, cls_params):
|
||||
model = deepcopy(self.model)
|
||||
|
||||
def job(cls_params):
|
||||
model.set_params(**cls_params)
|
||||
predictions = model.classifier_fit_predict(self._training)
|
||||
return predictions
|
||||
|
||||
predictions, status, took = self._error_handler(job, cls_params)
|
||||
self._sout(f'[classifier fit] hyperparams={cls_params} [took {took:.3f}s]')
|
||||
return model, predictions, status, took
|
||||
|
||||
def _prepare_aggregation(self, args):
|
||||
model, predictions, cls_took, cls_params, q_params = args
|
||||
model = deepcopy(model)
|
||||
params = {**cls_params, **q_params}
|
||||
|
||||
def job(q_params):
|
||||
model.set_params(**q_params)
|
||||
model.aggregation_fit(predictions, self._training)
|
||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||
return score
|
||||
|
||||
score, status, aggr_took = self._error_handler(job, q_params)
|
||||
self._print_status(params, score, status, aggr_took)
|
||||
return model, params, score, status, (cls_took+aggr_took)
|
||||
|
||||
def _prepare_nonaggr_model(self, params):
|
||||
model = deepcopy(self.model)
|
||||
|
||||
def job(params):
|
||||
model.set_params(**params)
|
||||
model.fit(self._training)
|
||||
score = evaluation.evaluate(model, protocol=self.protocol, error_metric=self.error)
|
||||
return score
|
||||
|
||||
score, status, took = self._error_handler(job, params)
|
||||
self._print_status(params, score, status, took)
|
||||
return model, params, score, status, took
|
||||
|
||||
def _compute_scores_aggregative(self, training):
|
||||
# break down the set of hyperparameters into two: classifier-specific, quantifier-specific
|
||||
cls_configs, q_configs = group_params(self.param_grid)
|
||||
|
||||
# train all classifiers and get the predictions
|
||||
self._training = training
|
||||
cls_outs = qp.util.parallel(
|
||||
self._prepare_classifier,
|
||||
cls_configs,
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
# filter out classifier configurations that yielded any error
|
||||
success_outs = []
|
||||
for (model, predictions, status, took), cls_config in zip(cls_outs, cls_configs):
|
||||
if status.success():
|
||||
success_outs.append((model, predictions, took, cls_config))
|
||||
else:
|
||||
self.error_collector.append(status)
|
||||
|
||||
if len(success_outs) == 0:
|
||||
raise ValueError('No valid configuration found for the classifier!')
|
||||
|
||||
# explore the quantifier-specific hyperparameters for each valid training configuration
|
||||
aggr_configs = [(*out, q_config) for out, q_config in itertools.product(success_outs, q_configs)]
|
||||
aggr_outs = qp.util.parallel(
|
||||
self._prepare_aggregation,
|
||||
aggr_configs,
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
|
||||
return aggr_outs
|
||||
|
||||
def _compute_scores_nonaggregative(self, training):
|
||||
configs = expand_grid(self.param_grid)
|
||||
self._training = training
|
||||
scores = qp.util.parallel(
|
||||
self._prepare_nonaggr_model,
|
||||
configs,
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
return scores
|
||||
|
||||
def _print_status(self, params, score, status, took):
|
||||
if status.success():
|
||||
self._sout(f'hyperparams=[{params}]\t got {self.error.__name__} = {score:.5f} [took {took:.3f}s]')
|
||||
else:
|
||||
self._sout(f'error={status}')
|
||||
|
||||
def fit(self, training: LabelledCollection):
|
||||
""" Learning routine. Fits methods with all combinations of hyperparameters and selects the one minimizing
|
||||
the error metric.
|
||||
|
@ -76,97 +202,63 @@ class GridSearchQ(BaseQuantifier):
|
|||
:param training: the training set on which to optimize the hyperparameters
|
||||
:return: self
|
||||
"""
|
||||
params_keys = list(self.param_grid.keys())
|
||||
params_values = list(self.param_grid.values())
|
||||
|
||||
protocol = self.protocol
|
||||
|
||||
self.param_scores_ = {}
|
||||
self.best_score_ = None
|
||||
if self.refit and not isinstance(self.protocol, OnLabelledCollectionProtocol):
|
||||
raise RuntimeWarning(
|
||||
f'"refit" was requested, but the protocol does not implement '
|
||||
f'the {OnLabelledCollectionProtocol.__name__} interface'
|
||||
)
|
||||
|
||||
tinit = time()
|
||||
|
||||
hyper = [dict({k: val[i] for i, k in enumerate(params_keys)}) for val in itertools.product(*params_values)]
|
||||
self._sout(f'starting model selection with {self.n_jobs =}')
|
||||
#pass a seed to parallel so it is set in clild processes
|
||||
scores = qp.util.parallel(
|
||||
self._delayed_eval,
|
||||
((params, training) for params in hyper),
|
||||
seed=qp.environ.get('_R_SEED', None),
|
||||
n_jobs=self.n_jobs
|
||||
)
|
||||
self.error_collector = []
|
||||
|
||||
for params, score, model in scores:
|
||||
if score is not None:
|
||||
self._sout(f'starting model selection with n_jobs={self.n_jobs}')
|
||||
if isinstance(self.model, AggregativeQuantifier):
|
||||
results = self._compute_scores_aggregative(training)
|
||||
else:
|
||||
results = self._compute_scores_nonaggregative(training)
|
||||
|
||||
self.param_scores_ = {}
|
||||
self.best_score_ = None
|
||||
for model, params, score, status, took in results:
|
||||
if status.success():
|
||||
if self.best_score_ is None or score < self.best_score_:
|
||||
self.best_score_ = score
|
||||
self.best_params_ = params
|
||||
self.best_model_ = model
|
||||
self.param_scores_[str(params)] = score
|
||||
else:
|
||||
self.param_scores_[str(params)] = 'timeout'
|
||||
self.param_scores_[str(params)] = status.status
|
||||
self.error_collector.append(status)
|
||||
|
||||
tend = time()-tinit
|
||||
|
||||
if self.best_score_ is None:
|
||||
raise TimeoutError('no combination of hyperparameters seem to work')
|
||||
raise ValueError('no combination of hyperparameters seemed to work')
|
||||
|
||||
self._sout(f'optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) '
|
||||
f'[took {tend:.4f}s]')
|
||||
|
||||
no_errors = len(self.error_collector)
|
||||
if no_errors>0:
|
||||
self._sout(f'warning: {no_errors} errors found')
|
||||
for err in self.error_collector:
|
||||
self._sout(f'\t{str(err)}')
|
||||
|
||||
if self.refit:
|
||||
if isinstance(protocol, OnLabelledCollectionProtocol):
|
||||
if isinstance(self.protocol, OnLabelledCollectionProtocol):
|
||||
tinit = time()
|
||||
self._sout(f'refitting on the whole development set')
|
||||
self.best_model_.fit(training + protocol.get_labelled_collection())
|
||||
self.best_model_.fit(training + self.protocol.get_labelled_collection())
|
||||
tend = time() - tinit
|
||||
self.refit_time_ = tend
|
||||
else:
|
||||
raise RuntimeWarning(f'"refit" was requested, but the protocol does not '
|
||||
f'implement the {OnLabelledCollectionProtocol.__name__} interface')
|
||||
# already checked
|
||||
raise RuntimeWarning(f'the model cannot be refit on the whole dataset')
|
||||
|
||||
return self
|
||||
|
||||
def _delayed_eval(self, args):
|
||||
params, training = args
|
||||
|
||||
protocol = self.protocol
|
||||
error = self.error
|
||||
|
||||
if self.timeout > 0:
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
|
||||
tinit = time()
|
||||
|
||||
if self.timeout > 0:
|
||||
signal.alarm(self.timeout)
|
||||
|
||||
try:
|
||||
model = deepcopy(self.model)
|
||||
# overrides default parameters with the parameters being explored at this iteration
|
||||
model.set_params(**params)
|
||||
model.fit(training)
|
||||
score = evaluation.evaluate(model, protocol=protocol, error_metric=error)
|
||||
|
||||
ttime = time()-tinit
|
||||
self._sout(f'hyperparams={params}\t got {error.__name__} score {score:.5f} [took {ttime:.4f}s]')
|
||||
|
||||
if self.timeout > 0:
|
||||
signal.alarm(0)
|
||||
except TimeoutError:
|
||||
self._sout(f'timeout ({self.timeout}s) reached for config {params}')
|
||||
score = None
|
||||
except ValueError as e:
|
||||
self._sout(f'the combination of hyperparameters {params} is invalid')
|
||||
raise e
|
||||
except Exception as e:
|
||||
self._sout(f'something went wrong for config {params}; skipping:')
|
||||
self._sout(f'\tException: {e}')
|
||||
score = None
|
||||
|
||||
return params, score, model
|
||||
|
||||
|
||||
def quantify(self, instances):
|
||||
"""Estimate class prevalence values using the best model found after calling the :meth:`fit` method.
|
||||
|
||||
|
@ -203,7 +295,42 @@ class GridSearchQ(BaseQuantifier):
|
|||
return self.best_model_
|
||||
raise ValueError('best_model called before fit')
|
||||
|
||||
def _error_handler(self, func, params):
|
||||
"""
|
||||
Endorses one job with two returned values: the status, and the time of execution
|
||||
|
||||
:param func: the function to be called
|
||||
:param params: parameters of the function
|
||||
:return: `tuple(out, status, time)` where `out` is the function output,
|
||||
`status` is an enum value from `Status`, and `time` is the time it
|
||||
took to complete the call
|
||||
"""
|
||||
|
||||
output = None
|
||||
|
||||
def _handle(status, exception):
|
||||
if self.raise_errors:
|
||||
raise exception
|
||||
else:
|
||||
return ConfigStatus(params, status, str(e))
|
||||
|
||||
try:
|
||||
with timeout(self.timeout):
|
||||
tinit = time()
|
||||
output = func(params)
|
||||
status = ConfigStatus(params, Status.SUCCESS)
|
||||
|
||||
except TimeoutError as e:
|
||||
status = _handle(Status.TIMEOUT, str(e))
|
||||
|
||||
except ValueError as e:
|
||||
status = _handle(Status.INVALID, str(e))
|
||||
|
||||
except Exception as e:
|
||||
status = _handle(Status.ERROR, str(e))
|
||||
|
||||
took = time() - tinit
|
||||
return output, status, took
|
||||
|
||||
|
||||
def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfolds=3, random_state=0):
|
||||
|
@ -229,3 +356,43 @@ def cross_val_predict(quantifier: BaseQuantifier, data: LabelledCollection, nfol
|
|||
return total_prev
|
||||
|
||||
|
||||
def expand_grid(param_grid: dict):
|
||||
"""
|
||||
Expands a param_grid dictionary as a list of configurations.
|
||||
Example:
|
||||
|
||||
>>> combinations = expand_grid({'A': [1, 10, 100], 'B': [True, False]})
|
||||
>>> print(combinations)
|
||||
>>> [{'A': 1, 'B': True}, {'A': 1, 'B': False}, {'A': 10, 'B': True}, {'A': 10, 'B': False}, {'A': 100, 'B': True}, {'A': 100, 'B': False}]
|
||||
|
||||
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
|
||||
to explore for that hyper-parameter
|
||||
:return: a list of configurations, i.e., combinations of hyper-parameter assignments in the grid.
|
||||
"""
|
||||
params_keys = list(param_grid.keys())
|
||||
params_values = list(param_grid.values())
|
||||
configs = [{k: combs[i] for i, k in enumerate(params_keys)} for combs in itertools.product(*params_values)]
|
||||
return configs
|
||||
|
||||
|
||||
def group_params(param_grid: dict):
|
||||
"""
|
||||
Partitions a param_grid dictionary as two lists of configurations, one for the classifier-specific
|
||||
hyper-parameters, and another for que quantifier-specific hyper-parameters
|
||||
|
||||
:param param_grid: dictionary with keys representing hyper-parameter names, and values representing the range
|
||||
to explore for that hyper-parameter
|
||||
:return: two expanded grids of configurations, one for the classifier, another for the quantifier
|
||||
"""
|
||||
classifier_params, quantifier_params = {}, {}
|
||||
for key, values in param_grid.items():
|
||||
if key.startswith('classifier__') or key == 'val_split':
|
||||
classifier_params[key] = values
|
||||
else:
|
||||
quantifier_params[key] = values
|
||||
|
||||
classifier_configs = expand_grid(classifier_params)
|
||||
quantifier_configs = expand_grid(quantifier_params)
|
||||
|
||||
return classifier_configs, quantifier_configs
|
||||
|
||||
|
|
|
@ -22,9 +22,9 @@ class HierarchyTestCase(unittest.TestCase):
|
|||
def test_probabilistic(self):
|
||||
lr = LogisticRegression()
|
||||
for m in [CC(lr), ACC(lr)]:
|
||||
self.assertEqual(isinstance(m, AggregativeProbabilisticQuantifier), False)
|
||||
self.assertEqual(isinstance(m, AggregativeSoftQuantifier), False)
|
||||
for m in [PCC(lr), PACC(lr)]:
|
||||
self.assertEqual(isinstance(m, AggregativeProbabilisticQuantifier), True)
|
||||
self.assertEqual(isinstance(m, AggregativeSoftQuantifier), True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -10,6 +10,8 @@ import quapy as qp
|
|||
|
||||
import numpy as np
|
||||
from joblib import Parallel, delayed
|
||||
from time import time
|
||||
import signal
|
||||
|
||||
|
||||
def _get_parallel_slices(n_tasks, n_jobs):
|
||||
|
@ -38,7 +40,7 @@ def map_parallel(func, args, n_jobs):
|
|||
return list(itertools.chain.from_iterable(results))
|
||||
|
||||
|
||||
def parallel(func, args, n_jobs, seed=None):
|
||||
def parallel(func, args, n_jobs, seed=None, asarray=True, backend='loky'):
|
||||
"""
|
||||
A wrapper of multiprocessing:
|
||||
|
||||
|
@ -58,9 +60,12 @@ def parallel(func, args, n_jobs, seed=None):
|
|||
stack.enter_context(qp.util.temp_seed(seed))
|
||||
return func(*args)
|
||||
|
||||
return Parallel(n_jobs=n_jobs)(
|
||||
out = Parallel(n_jobs=n_jobs, backend=backend)(
|
||||
delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
|
||||
)
|
||||
if asarray:
|
||||
out = np.asarray(out)
|
||||
return out
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -254,3 +259,35 @@ class EarlyStop:
|
|||
if self.patience <= 0:
|
||||
self.STOP = True
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def timeout(seconds):
|
||||
"""
|
||||
Opens a context that will launch an exception if not closed after a given number of seconds
|
||||
|
||||
>>> def func(start_msg, end_msg):
|
||||
>>> print(start_msg)
|
||||
>>> sleep(2)
|
||||
>>> print(end_msg)
|
||||
>>>
|
||||
>>> with timeout(1):
|
||||
>>> func('begin function', 'end function')
|
||||
>>> Out[]
|
||||
>>> begin function
|
||||
>>> TimeoutError
|
||||
|
||||
|
||||
:param seconds: number of seconds, set to <=0 to ignore the timer
|
||||
"""
|
||||
if seconds > 0:
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError()
|
||||
|
||||
signal.signal(signal.SIGALRM, handler)
|
||||
signal.alarm(seconds)
|
||||
|
||||
yield
|
||||
|
||||
if seconds > 0:
|
||||
signal.alarm(0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue