forked from moreo/QuaPy
changing gridsearchQ to ensure reproducibility
This commit is contained in:
parent
c91961cff5
commit
a4584b79db
|
@ -83,7 +83,8 @@ class GridSearchQ(BaseQuantifier):
|
||||||
tinit = time()
|
tinit = time()
|
||||||
|
|
||||||
hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)]
|
hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)]
|
||||||
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), n_jobs=self.n_jobs)
|
#pass a seed to parallel so it is set in clild processes
|
||||||
|
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), seed=qp.environ.get('_R_SEED', None), n_jobs=self.n_jobs)
|
||||||
|
|
||||||
for params, score, model in scores:
|
for params, score, model in scores:
|
||||||
if score is not None:
|
if score is not None:
|
||||||
|
|
|
@ -5,6 +5,7 @@ import os
|
||||||
import pickle
|
import pickle
|
||||||
import urllib
|
import urllib
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from contextlib import ExitStack
|
||||||
import quapy as qp
|
import quapy as qp
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -36,7 +37,7 @@ def map_parallel(func, args, n_jobs):
|
||||||
return list(itertools.chain.from_iterable(results))
|
return list(itertools.chain.from_iterable(results))
|
||||||
|
|
||||||
|
|
||||||
def parallel(func, args, n_jobs):
|
def parallel(func, args, n_jobs, seed = None):
|
||||||
"""
|
"""
|
||||||
A wrapper of multiprocessing:
|
A wrapper of multiprocessing:
|
||||||
|
|
||||||
|
@ -44,14 +45,20 @@ def parallel(func, args, n_jobs):
|
||||||
>>> delayed(func)(args_i) for args_i in args
|
>>> delayed(func)(args_i) for args_i in args
|
||||||
>>> )
|
>>> )
|
||||||
|
|
||||||
that takes the `quapy.environ` variable as input silently
|
that takes the `quapy.environ` variable as input silently.
|
||||||
|
Seeds the child processes to ensure reproducibility when n_jobs>1
|
||||||
"""
|
"""
|
||||||
def func_dec(environ, *args):
|
def func_dec(environ, seed, *args):
|
||||||
qp.environ = environ.copy()
|
qp.environ = environ.copy()
|
||||||
qp.environ['N_JOBS'] = 1
|
qp.environ['N_JOBS'] = 1
|
||||||
return func(*args)
|
#set a context with a temporal seed to ensure results are reproducibles in parallel
|
||||||
|
with ExitStack() as stack:
|
||||||
|
if seed is not None:
|
||||||
|
stack.enter_context(qp.util.temp_seed(seed))
|
||||||
|
return func(*args)
|
||||||
|
|
||||||
return Parallel(n_jobs=n_jobs)(
|
return Parallel(n_jobs=n_jobs)(
|
||||||
delayed(func_dec)(qp.environ, args_i) for args_i in args
|
delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,6 +73,8 @@ def temp_seed(random_state):
|
||||||
:param random_state: the seed to set within the "with" context
|
:param random_state: the seed to set within the "with" context
|
||||||
"""
|
"""
|
||||||
state = np.random.get_state()
|
state = np.random.get_state()
|
||||||
|
#save the seed just in case is needed (for instance for setting the seed to child processes)
|
||||||
|
qp.environ['_R_SEED'] = random_state
|
||||||
np.random.seed(random_state)
|
np.random.seed(random_state)
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
|
|
Loading…
Reference in New Issue