1
0
Fork 0

adding broken bar plot

This commit is contained in:
Alejandro Moreo Fernandez 2021-11-22 17:21:28 +01:00
parent 1ae45f8b9f
commit 027e18f1e7
3 changed files with 16 additions and 14 deletions

View File

@ -61,6 +61,8 @@ nice = {
'Average': 'Average',
'EMdiag':'EM$_{diag}$', 'EMfull':'EM$_{full}$', 'EMtied':'EM$_{tied}$', 'EMspherical':'EM$_{sph}$',
'VEMdiag':'VEM$_{diag}$', 'VEMfull':'VEM$_{full}$', 'VEMtied':'VEM$_{tied}$', 'VEMspherical':'VEM$_{sph}$',
'epaccmaemae1k': 'E(PACC)$_\mathrm{AE}$',
'quanet': 'QuaNet'
}

View File

@ -18,7 +18,7 @@ os.makedirs(plotdir, exist_ok=True)
N_RUNS = N_FOLDS * N_REPEATS
def gather_results(methods, error_name):
def gather_results(methods, error_name, resultdir):
method_names, true_prevs, estim_prevs, tr_prevs = [], [], [], []
for method in methods:
for run in range(N_RUNS):
@ -35,7 +35,7 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
print('plotting error by drift')
if path is not None:
path = join(path, f'error_by_drift_{error_name}.{plotext}')
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name, resultdir)
qp.plot.error_by_drift(
method_names,
true_prevs,
@ -54,7 +54,7 @@ def diagonal_plot(methods, error_name, path=None):
print('plotting diagonal plots')
if path is not None:
path = join(path, f'diag_{error_name}')
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name, resultdir)
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', legend=True, show_std=True, savepath=f'{path}_pos.{plotext}')
@ -62,7 +62,7 @@ def binary_bias_global(methods, error_name, path=None):
print('plotting bias global')
if path is not None:
path = join(path, f'globalbias_{error_name}')
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name, resultdir)
qp.plot.binary_bias_global(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', savepath=f'{path}_pos.{plotext}')
@ -70,15 +70,15 @@ def binary_bias_bins(methods, error_name, path=None):
print('plotting bias local')
if path is not None:
path = join(path, f'localbias_{error_name}')
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name, resultdir)
qp.plot.binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title='Positive', legend=True, savepath=f'{path}_pos.{plotext}')
if __name__ == '__main__':
plot_error_by_drift(METHODS, error_name='ae', path=plotdir)
plot_error_by_drift(METHODS, error_name='ae', path=plotdir)
diagonal_plot(METHODS, error_name='ae', path=plotdir)
diagonal_plot(METHODS, error_name='ae', path=plotdir)
binary_bias_global(METHODS, error_name='ae', path=plotdir)
binary_bias_global(METHODS, error_name='ae', path=plotdir)
binary_bias_bins(METHODS, error_name='ae', path=plotdir)
binary_bias_bins(METHODS, error_name='ae', path=plotdir)

View File

@ -228,10 +228,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
if show_std:
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
# xs = bins[:-1]
# ys = inds_histogram_global
# print(xs.shape, ys.shape)
# ax.errorbar(xs, ys, label='density')
xs = bins[:-1]
ys = inds_histogram_global
print(xs.shape, ys.shape)
ax.errorbar(xs, ys, label='density')
ax.set(xlabel=f'Distribution shift between training set and test sample',
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',