parallelization added
This commit is contained in:
parent
8705f2b3c0
commit
27a384c1a1
|
|
@ -15,7 +15,7 @@ import quacc.error
|
||||||
from quacc.data import ExtendedCollection, ExtendedData
|
from quacc.data import ExtendedCollection, ExtendedData
|
||||||
from quacc.environment import env
|
from quacc.environment import env
|
||||||
from quacc.evaluation import evaluate
|
from quacc.evaluation import evaluate
|
||||||
from quacc.logger import SubLogger
|
from quacc.logger import Logger, SubLogger
|
||||||
from quacc.method.base import (
|
from quacc.method.base import (
|
||||||
BaseAccuracyEstimator,
|
BaseAccuracyEstimator,
|
||||||
BinaryQuantifierAccuracyEstimator,
|
BinaryQuantifierAccuracyEstimator,
|
||||||
|
|
@ -32,7 +32,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
error: Union[Callable, str] = qc.error.maccd,
|
error: Union[Callable, str] = qc.error.maccd,
|
||||||
refit=True,
|
refit=True,
|
||||||
# timeout=-1,
|
# timeout=-1,
|
||||||
# n_jobs=None,
|
n_jobs=None,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
@ -40,7 +40,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
self.protocol = protocol
|
self.protocol = protocol
|
||||||
self.refit = refit
|
self.refit = refit
|
||||||
# self.timeout = timeout
|
# self.timeout = timeout
|
||||||
# self.n_jobs = qp._get_njobs(n_jobs)
|
self.n_jobs = qc._get_njobs(n_jobs)
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.__check_error(error)
|
self.__check_error(error)
|
||||||
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
|
assert isinstance(protocol, AbstractProtocol), "unknown protocol"
|
||||||
|
|
@ -92,10 +92,16 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
dict(zip(params_keys, val)) for val in itertools.product(*params_values)
|
dict(zip(params_keys, val)) for val in itertools.product(*params_values)
|
||||||
]
|
]
|
||||||
|
|
||||||
# self._sout(f"starting model selection with {self.n_jobs =}")
|
self._sout(f"starting model selection with {self.n_jobs =}")
|
||||||
self._sout("starting model selection")
|
# self._sout("starting model selection")
|
||||||
|
|
||||||
scores = [self.__params_eval(params, training) for params in hyper]
|
# scores = [self.__params_eval((params, training)) for params in hyper]
|
||||||
|
scores = qc.utils.parallel(
|
||||||
|
self._params_eval,
|
||||||
|
((params, training) for params in hyper),
|
||||||
|
seed=env._R_SEED,
|
||||||
|
n_jobs=self.n_jobs,
|
||||||
|
)
|
||||||
|
|
||||||
for params, score, model in scores:
|
for params, score, model in scores:
|
||||||
if score is not None:
|
if score is not None:
|
||||||
|
|
@ -118,7 +124,7 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
level=1,
|
level=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
log = SubLogger.logger()
|
log = Logger.logger()
|
||||||
log.debug(
|
log.debug(
|
||||||
f"[{self.model.__class__.__name__}] "
|
f"[{self.model.__class__.__name__}] "
|
||||||
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
|
f"optimization finished: best params {self.best_params_} (score={self.best_score_:.5f}) "
|
||||||
|
|
@ -137,7 +143,8 @@ class GridSearchAE(BaseAccuracyEstimator):
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __params_eval(self, params, training):
|
def _params_eval(self, args):
|
||||||
|
params, training = args
|
||||||
protocol = self.protocol
|
protocol = self.protocol
|
||||||
error = self.error
|
error = self.error
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue