diff --git a/quapy/plot.py b/quapy/plot.py index cdb9b1e..2e41413 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -4,6 +4,7 @@ from matplotlib.cm import get_cmap import numpy as np from matplotlib import cm from scipy.stats import ttest_ind_from_stats +from matplotlib.ticker import ScalarFormatter import quapy as qp @@ -212,6 +213,7 @@ def binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title=N def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False, show_density=True, + show_legend=True, logscale=False, title=f'Quantification error as a function of distribution shift', vlines=None, @@ -234,6 +236,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, :param error_name: a string representing the name of an error function (as defined in `quapy.error`, default is "ae") :param show_std: whether or not to show standard deviations as color bands (default is False) :param show_density: whether or not to display the distribution of experiments for each bin (default is True) + :param show_density: whether or not to display the legend of the chart (default is True) :param logscale: whether or not to log-scale the y-error measure (default is False) :param title: title of the plot (default is "Quantification error as a function of distribution shift") :param vlines: array-like list of values (default is None). If indicated, highlights some regions of the space @@ -254,6 +257,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, # x_error function) and 'y' is the estim-test shift (computed as according to y_error) data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order) + if method_order is None: + method_order = method_names + _set_colors(ax, n_methods=len(method_order)) bins = np.linspace(0, 1, n_bins+1) @@ -264,7 +270,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, tr_test_drifts = data[method]['x'] method_drifts = data[method]['y'] if logscale: - method_drifts=np.log(1+method_drifts) + ax.set_yscale("log") + ax.yaxis.set_major_formatter(ScalarFormatter()) + ax.yaxis.set_minor_formatter(ScalarFormatter()) + ax.yaxis.get_major_formatter().set_scientific(False) + ax.yaxis.get_minor_formatter().set_scientific(False) inds = np.digitize(tr_test_drifts, bins, right=True) @@ -295,9 +305,15 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25) if show_density: - ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], + ax2 = ax.twinx() + ax2.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density') - + #ax2.set_ylabel("bar data") + ax2.set_ylim(0,1) + ax2.spines['right'].set_color('g') + ax2.tick_params(axis='y', colors='g') + #ax2.yaxis.set_visible(False) + ax.set(xlabel=f'Distribution shift between training set and test sample', ylabel=f'{error_name.upper()} (true distribution, predicted distribution)', title=title) @@ -306,9 +322,13 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, if vlines: for vline in vlines: ax.axvline(vline, 0, 1, linestyle='--', color='k') - ax.set_xlim(0, max_x) - ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + + ax.set_xlim(min_x, max_x) + + if show_legend: + fig.legend(loc='right') + _save_or_show(savepath)