Parallelization changed from multiprocessing.Pool to joblib

This commit is contained in:
Lorenzo Volpi 2023-12-02 02:08:16 +01:00
parent 78c210b15f
commit 8705f2b3c0
1 changed file with 46 additions and 41 deletions

View File

@@ -5,6 +5,7 @@ from traceback import print_exception as traceback
import pandas as pd import pandas as pd
import quapy as qp import quapy as qp
from joblib import Parallel, delayed
from quacc.dataset import Dataset from quacc.dataset import Dataset
from quacc.environment import env from quacc.environment import env
@@ -21,49 +22,53 @@ def evaluate_comparison(dataset: Dataset, estimators=None) -> DatasetReport:
log = Logger.logger() log = Logger.logger()
# with multiprocessing.Pool(1) as pool: # with multiprocessing.Pool(1) as pool:
__pool_size = round(os.cpu_count() * 0.8) __pool_size = round(os.cpu_count() * 0.8)
with multiprocessing.Pool(__pool_size) as pool: # with multiprocessing.Pool(__pool_size) as pool:
dr = DatasetReport(dataset.name) dr = DatasetReport(dataset.name)
log.info(f"dataset {dataset.name} [pool size: {__pool_size}]") log.info(f"dataset {dataset.name} [pool size: {__pool_size}]")
for d in dataset(): for d in dataset():
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started" f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} started"
)
tasks = [
WorkerArgs(
_estimate=estim,
train=d.train,
validation=d.validation,
test=d.test,
_env=env,
q=Logger.queue(),
) )
tasks = [ for estim in CE.func[estimators]
WorkerArgs( ]
_estimate=estim, try:
train=d.train, tstart = time.time()
validation=d.validation, results = Parallel(n_jobs=1)(delayed(estimate_worker)(t) for t in tasks)
test=d.test, results = [r for r in results if r is not None]
_env=env, # # r for r in pool.imap(estimate_worker, tasks) if r is not None
q=Logger.queue(), # r
) # for r in map(estimate_worker, tasks)
for estim in CE.func[estimators] # if r is not None
] # ]
try:
tstart = time.time()
results = [
r for r in pool.imap(estimate_worker, tasks) if r is not None
]
g_time = time.time() - tstart g_time = time.time() - tstart
log.info( log.info(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished " f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} finished "
f"[took {g_time:.4f}s]" f"[took {g_time:.4f}s]"
) )
cr = CompReport( cr = CompReport(
results, results,
name=dataset.name, name=dataset.name,
train_prev=d.train_prev, train_prev=d.train_prev,
valid_prev=d.validation_prev, valid_prev=d.validation_prev,
g_time=g_time, g_time=g_time,
) )
dr += cr dr += cr
except Exception as e: except Exception as e:
log.warning( log.warning(
f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. " f"Dataset sample {d.train_prev[1]:.2f} of dataset {dataset.name} failed. "
f"Exception: {e}" f"Exception: {e}"
) )
traceback(e) traceback(e)
return dr return dr