forked from moreo/QuaPy
running ensembles
This commit is contained in:
parent
03cf73aff6
commit
2001c6d852
|
@ -63,7 +63,8 @@ def quantification_ensembles():
|
||||||
'n_jobs': settings.ENSEMBLE_N_JOBS,
|
'n_jobs': settings.ENSEMBLE_N_JOBS,
|
||||||
'param_grid': lr_params,
|
'param_grid': lr_params,
|
||||||
'param_mod_sel': param_mod_sel,
|
'param_mod_sel': param_mod_sel,
|
||||||
'val_split': 0.4
|
'val_split': 0.4,
|
||||||
|
'min_pos': 10
|
||||||
}
|
}
|
||||||
|
|
||||||
# hyperparameters will be evaluated within each quantifier of the ensemble, and so the typical model selection
|
# hyperparameters will be evaluated within each quantifier of the ensemble, and so the typical model selection
|
||||||
|
@ -71,13 +72,13 @@ def quantification_ensembles():
|
||||||
hyper_none = None
|
hyper_none = None
|
||||||
yield 'epaccmaeptr', EPACC(newLR(), optim='mae', policy='ptr', **common), hyper_none
|
yield 'epaccmaeptr', EPACC(newLR(), optim='mae', policy='ptr', **common), hyper_none
|
||||||
yield 'epaccmaemae', EPACC(newLR(), optim='mae', policy='mae', **common), hyper_none
|
yield 'epaccmaemae', EPACC(newLR(), optim='mae', policy='mae', **common), hyper_none
|
||||||
yield 'esldmaeptr', EEMQ(newLR(), optim='mae', policy='ptr', **common), hyper_none
|
#yield 'esldmaeptr', EEMQ(newLR(), optim='mae', policy='ptr', **common), hyper_none
|
||||||
yield 'esldmaemae', EEMQ(newLR(), optim='mae', policy='mae', **common), hyper_none
|
#yield 'esldmaemae', EEMQ(newLR(), optim='mae', policy='mae', **common), hyper_none
|
||||||
|
|
||||||
yield 'epaccmraeptr', EPACC(newLR(), optim='mrae', policy='ptr', **common), hyper_none
|
yield 'epaccmraeptr', EPACC(newLR(), optim='mrae', policy='ptr', **common), hyper_none
|
||||||
yield 'epaccmraemrae', EPACC(newLR(), optim='mrae', policy='mrae', **common), hyper_none
|
yield 'epaccmraemrae', EPACC(newLR(), optim='mrae', policy='mrae', **common), hyper_none
|
||||||
yield 'esldmraeptr', EEMQ(newLR(), optim='mrae', policy='ptr', **common), hyper_none
|
#yield 'esldmraeptr', EEMQ(newLR(), optim='mrae', policy='ptr', **common), hyper_none
|
||||||
yield 'esldmraemrae', EEMQ(newLR(), optim='mrae', policy='mrae', **common), hyper_none
|
#yield 'esldmraemrae', EEMQ(newLR(), optim='mrae', policy='mrae', **common), hyper_none
|
||||||
|
|
||||||
|
|
||||||
def evaluate_experiment(true_prevalences, estim_prevalences):
|
def evaluate_experiment(true_prevalences, estim_prevalences):
|
||||||
|
@ -178,8 +179,8 @@ def run(experiment):
|
||||||
benchmark_eval.training.prevalence(), test_true_prevalence, test_estim_prevalence,
|
benchmark_eval.training.prevalence(), test_true_prevalence, test_estim_prevalence,
|
||||||
best_params)
|
best_params)
|
||||||
|
|
||||||
if isinstance(model, QuaNet):
|
#if isinstance(model, QuaNet):
|
||||||
model.clean_checkpoint_dir()
|
#model.clean_checkpoint_dir()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -195,24 +196,24 @@ if __name__ == '__main__':
|
||||||
print(f'Result folder: {args.results}')
|
print(f'Result folder: {args.results}')
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
||||||
optim_losses = ['mae'] # ['mae', 'mrae']
|
optim_losses = ['mae', 'mrae']
|
||||||
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TRAIN
|
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TRAIN
|
||||||
|
|
||||||
#models = quantification_models()
|
models = quantification_models()
|
||||||
#Parallel(n_jobs=settings.N_JOBS)(
|
Parallel(n_jobs=settings.N_JOBS)(
|
||||||
# delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
||||||
#)
|
)
|
||||||
|
|
||||||
#models = quantification_cuda_models()
|
models = quantification_cuda_models()
|
||||||
#Parallel(n_jobs=settings.CUDA_N_JOBS)(
|
Parallel(n_jobs=settings.CUDA_N_JOBS)(
|
||||||
# delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
||||||
#)
|
)
|
||||||
|
|
||||||
models = quantification_ensembles()
|
models = quantification_ensembles()
|
||||||
Parallel(n_jobs=1)(
|
Parallel(n_jobs=1)(
|
||||||
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
delayed(run)(experiment) for experiment in itertools.product(optim_losses, datasets, models)
|
||||||
)
|
)
|
||||||
|
|
||||||
shutil.rmtree(args.checkpointdir, ignore_errors=True)
|
#shutil.rmtree(args.checkpointdir, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
|
||||||
N_JOBS = -2 #multiprocessing.cpu_count()
|
N_JOBS = -2 #multiprocessing.cpu_count()
|
||||||
CUDA_N_JOBS = 1
|
CUDA_N_JOBS = 2
|
||||||
ENSEMBLE_N_JOBS = -2
|
ENSEMBLE_N_JOBS = -2
|
||||||
|
|
||||||
SAMPLE_SIZE = 100
|
SAMPLE_SIZE = 100
|
||||||
|
|
|
@ -72,7 +72,8 @@ class Ensemble(BaseQuantifier):
|
||||||
|
|
||||||
# randomly chooses the prevalences for each member of the ensemble (preventing classes with less than
|
# randomly chooses the prevalences for each member of the ensemble (preventing classes with less than
|
||||||
# min_pos positive examples)
|
# min_pos positive examples)
|
||||||
prevs = [_draw_simplex(ndim=data.n_classes, min_val=self.min_pos / len(data)) for _ in range(self.size)]
|
sample_size = len(data) if self.max_sample_size is None else min(self.max_sample_size, len(data))
|
||||||
|
prevs = [_draw_simplex(ndim=data.n_classes, min_val=self.min_pos / sample_size) for _ in range(self.size)]
|
||||||
|
|
||||||
posteriors = None
|
posteriors = None
|
||||||
if self.policy == 'ds':
|
if self.policy == 'ds':
|
||||||
|
@ -80,7 +81,7 @@ class Ensemble(BaseQuantifier):
|
||||||
posteriors, self.post_proba_fn = self.ds_policy_get_posteriors(data)
|
posteriors, self.post_proba_fn = self.ds_policy_get_posteriors(data)
|
||||||
|
|
||||||
is_static_policy = (self.policy in qp.error.QUANTIFICATION_ERROR_NAMES)
|
is_static_policy = (self.policy in qp.error.QUANTIFICATION_ERROR_NAMES)
|
||||||
sample_size = len(data) if self.max_sample_size is None else min(self.max_sample_size, len(data))
|
|
||||||
self.ensemble = Parallel(n_jobs=self.n_jobs)(
|
self.ensemble = Parallel(n_jobs=self.n_jobs)(
|
||||||
delayed(_delayed_new_instance)(
|
delayed(_delayed_new_instance)(
|
||||||
self.base_quantifier, data, val_split, prev, posteriors, keep_samples=is_static_policy,
|
self.base_quantifier, data, val_split, prev, posteriors, keep_samples=is_static_policy,
|
||||||
|
|
Loading…
Reference in New Issue