forked from moreo/QuaPy
changing gridsearchQ to ensure reproducibility
This commit is contained in:
parent
c91961cff5
commit
a4584b79db
|
@ -83,7 +83,8 @@ class GridSearchQ(BaseQuantifier):
|
|||
tinit = time()
|
||||
|
||||
hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)]
|
||||
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), n_jobs=self.n_jobs)
|
||||
#pass a seed to parallel so it is set in clild processes
|
||||
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), seed=qp.environ.get('_R_SEED', None), n_jobs=self.n_jobs)
|
||||
|
||||
for params, score, model in scores:
|
||||
if score is not None:
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
import pickle
|
||||
import urllib
|
||||
from pathlib import Path
|
||||
from contextlib import ExitStack
|
||||
import quapy as qp
|
||||
|
||||
import numpy as np
|
||||
|
@ -36,7 +37,7 @@ def map_parallel(func, args, n_jobs):
|
|||
return list(itertools.chain.from_iterable(results))
|
||||
|
||||
|
||||
def parallel(func, args, n_jobs):
|
||||
def parallel(func, args, n_jobs, seed = None):
|
||||
"""
|
||||
A wrapper of multiprocessing:
|
||||
|
||||
|
@ -44,14 +45,20 @@ def parallel(func, args, n_jobs):
|
|||
>>> delayed(func)(args_i) for args_i in args
|
||||
>>> )
|
||||
|
||||
that takes the `quapy.environ` variable as input silently
|
||||
that takes the `quapy.environ` variable as input silently.
|
||||
Seeds the child processes to ensure reproducibility when n_jobs>1
|
||||
"""
|
||||
def func_dec(environ, *args):
|
||||
def func_dec(environ, seed, *args):
|
||||
qp.environ = environ.copy()
|
||||
qp.environ['N_JOBS'] = 1
|
||||
#set a context with a temporal seed to ensure results are reproducibles in parallel
|
||||
with ExitStack() as stack:
|
||||
if seed is not None:
|
||||
stack.enter_context(qp.util.temp_seed(seed))
|
||||
return func(*args)
|
||||
|
||||
return Parallel(n_jobs=n_jobs)(
|
||||
delayed(func_dec)(qp.environ, args_i) for args_i in args
|
||||
delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
|
||||
)
|
||||
|
||||
|
||||
|
@ -66,6 +73,8 @@ def temp_seed(random_state):
|
|||
:param random_state: the seed to set within the "with" context
|
||||
"""
|
||||
state = np.random.get_state()
|
||||
#save the seed just in case is needed (for instance for setting the seed to child processes)
|
||||
qp.environ['_R_SEED'] = random_state
|
||||
np.random.seed(random_state)
|
||||
try:
|
||||
yield
|
||||
|
|
Loading…
Reference in New Issue