1
0
Fork 0

changing gridsearchQ to ensure reproducibility

This commit is contained in:
Pablo González 2022-07-11 16:27:02 +02:00
parent c91961cff5
commit a4584b79db
2 changed files with 16 additions and 6 deletions

View File

@ -83,7 +83,8 @@ class GridSearchQ(BaseQuantifier):
tinit = time() tinit = time()
hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)] hyper = [dict({k: values[i] for i, k in enumerate(params_keys)}) for values in itertools.product(*params_values)]
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), n_jobs=self.n_jobs) #pass a seed to parallel so it is set in clild processes
scores = qp.util.parallel(self._delayed_eval, ((params, training) for params in hyper), seed=qp.environ.get('_R_SEED', None), n_jobs=self.n_jobs)
for params, score, model in scores: for params, score, model in scores:
if score is not None: if score is not None:

View File

@ -5,6 +5,7 @@ import os
import pickle import pickle
import urllib import urllib
from pathlib import Path from pathlib import Path
from contextlib import ExitStack
import quapy as qp import quapy as qp
import numpy as np import numpy as np
@ -36,7 +37,7 @@ def map_parallel(func, args, n_jobs):
return list(itertools.chain.from_iterable(results)) return list(itertools.chain.from_iterable(results))
def parallel(func, args, n_jobs): def parallel(func, args, n_jobs, seed = None):
""" """
A wrapper of multiprocessing: A wrapper of multiprocessing:
@ -44,14 +45,20 @@ def parallel(func, args, n_jobs):
>>> delayed(func)(args_i) for args_i in args >>> delayed(func)(args_i) for args_i in args
>>> ) >>> )
that takes the `quapy.environ` variable as input silently that takes the `quapy.environ` variable as input silently.
Seeds the child processes to ensure reproducibility when n_jobs>1
""" """
def func_dec(environ, *args): def func_dec(environ, seed, *args):
qp.environ = environ.copy() qp.environ = environ.copy()
qp.environ['N_JOBS'] = 1 qp.environ['N_JOBS'] = 1
return func(*args) #set a context with a temporal seed to ensure results are reproducibles in parallel
with ExitStack() as stack:
if seed is not None:
stack.enter_context(qp.util.temp_seed(seed))
return func(*args)
return Parallel(n_jobs=n_jobs)( return Parallel(n_jobs=n_jobs)(
delayed(func_dec)(qp.environ, args_i) for args_i in args delayed(func_dec)(qp.environ, None if seed is None else seed+i, args_i) for i, args_i in enumerate(args)
) )
@ -66,6 +73,8 @@ def temp_seed(random_state):
:param random_state: the seed to set within the "with" context :param random_state: the seed to set within the "with" context
""" """
state = np.random.get_state() state = np.random.get_state()
#save the seed just in case is needed (for instance for setting the seed to child processes)
qp.environ['_R_SEED'] = random_state
np.random.seed(random_state) np.random.seed(random_state)
try: try:
yield yield