From 2e992a0b9acb066ab6f5318296c7053415a7539b Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Fri, 10 Nov 2023 14:22:43 +0100 Subject: [PATCH] choosing plots for paper --- quapy/plot.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/quapy/plot.py b/quapy/plot.py index cdc3bd5..606a07a 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -216,9 +216,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, show_density=True, show_legend=True, logscale=False, - title=f'Quantification error as a function of distribution shift', + title=f'Quantification error as a function of label shift', vlines=None, method_order=None, + fontsize=12, savepath=None): """ Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift @@ -247,6 +248,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, :param savepath: path where to save the plot. If not indicated (as default), the plot is shown. """ + plt.rcParams['font.size'] = fontsize + fig, ax = plt.subplots() ax.grid() @@ -261,7 +264,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, if method_order is None: method_order = method_names - _set_colors(ax, n_methods=len(method_order)) + # _set_colors(ax, n_methods=len(method_order)) bins = np.linspace(0, 1, n_bins+1) binwidth = 1 / n_bins @@ -291,6 +294,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, ys = np.asarray(ys) ystds = np.asarray(ystds) + if ys[-1] max_x else max_x @@ -302,7 +308,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2) if show_std: - ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25) + ax.fill_between(xs, ys-ystds/3, ys+ystds/3, alpha=0.25) if show_density: ax2 = ax.twinx() @@ -313,8 +319,8 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, ax2.spines['right'].set_color('g') ax2.tick_params(axis='y', colors='g') - ax.set(xlabel=f'Distribution shift between training set and test sample', - ylabel=f'{error_name.upper()} (true distribution, predicted distribution)', + ax.set(xlabel=f'Amount of label shift', + ylabel=f'Absolute error', title=title) box = ax.get_position() ax.set_position([box.x0, box.y0, box.width * 0.8, box.height]) @@ -329,10 +335,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, if show_legend: - fig.legend(loc='lower center', - bbox_to_anchor=(1, 0.5), - ncol=(len(method_names)+1)//2) - + ax.legend(loc='center right', bbox_to_anchor=(1.2, 0.5)) + # fig.legend(loc='lower center', + # bbox_to_anchor=(1, 0.5), + # ncol=(len(method_names)+1)//2) + _save_or_show(savepath)