Merge pull request #15 from pglez82/protocols

Protocols
This commit is contained in:
Alejandro Moreo Fernandez 2023-01-18 15:19:26 +01:00 committed by GitHub
commit 850f0e25db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 25 additions and 5 deletions

View File

@ -4,6 +4,7 @@ from matplotlib.cm import get_cmap
import numpy as np import numpy as np
from matplotlib import cm from matplotlib import cm
from scipy.stats import ttest_ind_from_stats from scipy.stats import ttest_ind_from_stats
from matplotlib.ticker import ScalarFormatter
import quapy as qp import quapy as qp
@ -212,6 +213,7 @@ def binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title=N
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
n_bins=20, error_name='ae', show_std=False, n_bins=20, error_name='ae', show_std=False,
show_density=True, show_density=True,
show_legend=True,
logscale=False, logscale=False,
title=f'Quantification error as a function of distribution shift', title=f'Quantification error as a function of distribution shift',
vlines=None, vlines=None,
@ -234,6 +236,7 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
:param error_name: a string representing the name of an error function (as defined in `quapy.error`, default is "ae") :param error_name: a string representing the name of an error function (as defined in `quapy.error`, default is "ae")
:param show_std: whether or not to show standard deviations as color bands (default is False) :param show_std: whether or not to show standard deviations as color bands (default is False)
:param show_density: whether or not to display the distribution of experiments for each bin (default is True) :param show_density: whether or not to display the distribution of experiments for each bin (default is True)
:param show_legend: whether or not to display the legend of the chart (default is True)
:param logscale: whether or not to log-scale the y-error measure (default is False) :param logscale: whether or not to log-scale the y-error measure (default is False)
:param title: title of the plot (default is "Quantification error as a function of distribution shift") :param title: title of the plot (default is "Quantification error as a function of distribution shift")
:param vlines: array-like list of values (default is None). If indicated, highlights some regions of the space :param vlines: array-like list of values (default is None). If indicated, highlights some regions of the space
@ -254,6 +257,9 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
# x_error function) and 'y' is the estim-test shift (computed according to y_error) # x_error function) and 'y' is the estim-test shift (computed according to y_error)
data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order) data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order)
if method_order is None:
method_order = method_names
_set_colors(ax, n_methods=len(method_order)) _set_colors(ax, n_methods=len(method_order))
bins = np.linspace(0, 1, n_bins+1) bins = np.linspace(0, 1, n_bins+1)
@ -264,7 +270,11 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
tr_test_drifts = data[method]['x'] tr_test_drifts = data[method]['x']
method_drifts = data[method]['y'] method_drifts = data[method]['y']
if logscale: if logscale:
method_drifts=np.log(1+method_drifts) ax.set_yscale("log")
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.set_minor_formatter(ScalarFormatter())
ax.yaxis.get_major_formatter().set_scientific(False)
ax.yaxis.get_minor_formatter().set_scientific(False)
inds = np.digitize(tr_test_drifts, bins, right=True) inds = np.digitize(tr_test_drifts, bins, right=True)
@ -295,8 +305,14 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25) ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
if show_density: if show_density:
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))], ax2 = ax.twinx()
ax2.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density') max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
#ax2.set_ylabel("bar data")
ax2.set_ylim(0,1)
ax2.spines['right'].set_color('g')
ax2.tick_params(axis='y', colors='g')
#ax2.yaxis.set_visible(False)
ax.set(xlabel=f'Distribution shift between training set and test sample', ax.set(xlabel=f'Distribution shift between training set and test sample',
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)', ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
@ -306,8 +322,12 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
if vlines: if vlines:
for vline in vlines: for vline in vlines:
ax.axvline(vline, 0, 1, linestyle='--', color='k') ax.axvline(vline, 0, 1, linestyle='--', color='k')
ax.set_xlim(0, max_x)
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
ax.set_xlim(min_x, max_x)
if show_legend:
fig.legend(loc='right')
_save_or_show(savepath) _save_or_show(savepath)