adding density to error-by-drift plot

This commit is contained in:
Alejandro Moreo Fernandez 2021-11-22 17:34:36 +01:00
parent bdbe933a41
commit 7fd32d5c5f
2 changed files with 14 additions and 6 deletions

View File

@ -12,8 +12,8 @@ from os.path import join
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
plotext='png'
resultdir = './results'
plotdir = './plots'
resultdir = './results_npp'
plotdir = './plots_npp'
os.makedirs(plotdir, exist_ok=True)
def gather_results(methods, error_name):
@ -50,6 +50,7 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
logscale=logscale,
title=f'Quantification error as a function of distribution shift',
savepath=path,
vlines=[0.02, 0.1055],
method_order=method_order
)

View File

@ -176,10 +176,12 @@ def _set_colors(ax, n_methods):
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True,
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False,
show_density=True,
logscale=False,
title=f'Quantification error as a function of distribution shift',
savepath=None,
vlines=None,
method_order=None):
fig, ax = plt.subplots()
@ -246,15 +248,20 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
if show_std:
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
if show_density:
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
ax.set(xlabel=f'Distribution shift between training set and test sample',
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
title=title)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.axvline(0.02, 0, 1, linestyle='--', color='k')
ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
if vlines:
for vline in vlines:
ax.axvline(vline, 0, 1, linestyle='--', color='k')
# ax.axvline(0.02, 0, 1, linestyle='--', color='k')
# ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
ax.set_xlim(0, max_x)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))