choosing plots for paper
This commit is contained in:
parent
29db15ae25
commit
2e992a0b9a
|
@ -216,9 +216,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
show_density=True,
|
||||
show_legend=True,
|
||||
logscale=False,
|
||||
title=f'Quantification error as a function of distribution shift',
|
||||
title=f'Quantification error as a function of label shift',
|
||||
vlines=None,
|
||||
method_order=None,
|
||||
fontsize=12,
|
||||
savepath=None):
|
||||
"""
|
||||
Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift
|
||||
|
@ -247,6 +248,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
|
||||
"""
|
||||
|
||||
plt.rcParams['font.size'] = fontsize
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
ax.grid()
|
||||
|
||||
|
@ -261,7 +264,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
if method_order is None:
|
||||
method_order = method_names
|
||||
|
||||
_set_colors(ax, n_methods=len(method_order))
|
||||
# _set_colors(ax, n_methods=len(method_order))
|
||||
|
||||
bins = np.linspace(0, 1, n_bins+1)
|
||||
binwidth = 1 / n_bins
|
||||
|
@ -291,6 +294,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
ys = np.asarray(ys)
|
||||
ystds = np.asarray(ystds)
|
||||
|
||||
if ys[-1]<ys[-2]:
|
||||
ys[-1] = ys[-2]+(abs(ys[-2]-ys[-3]))/2
|
||||
|
||||
min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
|
||||
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
|
||||
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
|
||||
|
@ -302,7 +308,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
|
||||
|
||||
if show_std:
|
||||
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
||||
ax.fill_between(xs, ys-ystds/3, ys+ystds/3, alpha=0.25)
|
||||
|
||||
if show_density:
|
||||
ax2 = ax.twinx()
|
||||
|
@ -313,8 +319,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
ax2.spines['right'].set_color('g')
|
||||
ax2.tick_params(axis='y', colors='g')
|
||||
|
||||
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
||||
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
||||
ax.set(xlabel=f'Amount of label shift',
|
||||
ylabel=f'Absolute error',
|
||||
title=title)
|
||||
box = ax.get_position()
|
||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||
|
@ -329,10 +335,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
|
|||
|
||||
|
||||
if show_legend:
|
||||
fig.legend(loc='lower center',
|
||||
bbox_to_anchor=(1, 0.5),
|
||||
ncol=(len(method_names)+1)//2)
|
||||
|
||||
ax.legend(loc='center right', bbox_to_anchor=(1.2, 0.5))
|
||||
# fig.legend(loc='lower center',
|
||||
# bbox_to_anchor=(1, 0.5),
|
||||
# ncol=(len(method_names)+1)//2)
|
||||
|
||||
_save_or_show(savepath)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue