parallelization added

This commit is contained in:
Lorenzo Volpi 2023-12-02 02:11:30 +01:00
parent 8705f2b3c0
commit 27a384c1a1
1 changed files with 15 additions and 8 deletions

View File

@ -15,7 +15,7 @@ import quacc.error
from quacc.data import ExtendedCollection, ExtendedData from quacc.data import ExtendedCollection, ExtendedData
from quacc.environment import env from quacc.environment import env
from quacc.evaluation import evaluate from quacc.evaluation import evaluate
from quacc.logger import SubLogger from quacc.logger import Logger, SubLogger
from quacc.method.base import ( from quacc.method.base import (
BaseAccuracyEstimator, BaseAccuracyEstimator,
BinaryQuantifierAccuracyEstimator, BinaryQuantifierAccuracyEstimator,
@ -32,7 +32,7 @@ class GridSearchAE(BaseAccuracyEstimator):
error: Union[Callable, str] = qc.error.maccd, error: Union[Callable, str] = qc.error.maccd,
refit=True, refit=True,
# timeout=-1, # timeout=-1,
# n_jobs=None, n_jobs=None,
verbose=False, verbose=False,
): ):
self.model = model self.model = model
@ -40,7 +40,7 @@ class GridSearchAE(BaseAccuracyEstimator):
self.protocol = protocol self.protocol = protocol
self.refit = refit self.refit = refit
# self.timeout = timeout # self.timeout = timeout
# self.n_jobs = qp._get_njobs(n_jobs) self.n_jobs = qc._get_njobs(n_jobs)
self.verbose = verbose self.verbose = verbose
self.__check_error(error) self.__check_error(error)
assert isinstance(protocol, AbstractProtocol), "unknown protocol" assert isinstance(protocol, AbstractProtocol), "unknown protocol"
@ -92,10 +92,16 @@ class GridSearchAE(BaseAccuracyEstimator):
dict(zip(params_keys, val)) for val in itertools.product(*params_values) dict(zip(params_keys, val)) for val in itertools.product(*params_values)
] ]
# self._sout(f"starting model selection with {self.n_jobs =}") self._sout(f"starting model selection with {self.n_jobs =}")
self._sout("starting model selection") # self._sout("starting model selection")
scores = [self.__params_eval(params, training) for params in hyper] # scores = [self.__params_eval((params, training)) for params in hyper]
scores = qc.utils.parallel(
self._params_eval,
((params, training) for params in hyper),
seed=env._R_SEED,
n_jobs=self.n_jobs,
)
for params, score, model in scores: for params, score, model in scores:
if score is not None: if score is not None:
@ -118,7 +124,7 @@ class GridSearchAE(BaseAccuracyEstimator):
level=1, level=1,
) )
log = SubLogger.logger() log = Logger.logger()
log.debug( log.debug(
f"[{self.model.__class__.__name__}] " f"[{self.model.__class__.__name__}] "
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) " f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
@ -137,7 +143,8 @@ class GridSearchAE(BaseAccuracyEstimator):
return self return self
def __params_eval(self, params, training): def _params_eval(self, args):
params, training = args
protocol = self.protocol protocol = self.protocol
error = self.error error = self.error