QuaPy/quapy/plot.py

570 lines
27 KiB
Python
Raw Normal View History

2021-01-07 17:58:48 +01:00
from collections import defaultdict
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
2021-01-07 17:58:48 +01:00
import numpy as np
2021-01-15 18:32:32 +01:00
from matplotlib import cm
from scipy.stats import ttest_ind_from_stats
2023-01-18 14:53:46 +01:00
from matplotlib.ticker import ScalarFormatter
2023-01-18 16:05:40 +01:00
import math
2021-01-07 17:58:48 +01:00
2021-01-15 18:32:32 +01:00
import quapy as qp
2021-01-07 17:58:48 +01:00
2022-05-20 11:52:59 +02:00
plt.rcParams['figure.figsize'] = [10, 6]
2021-01-07 17:58:48 +01:00
plt.rcParams['figure.dpi'] = 200
2022-05-20 11:52:59 +02:00
plt.rcParams['font.size'] = 18
2021-01-07 17:58:48 +01:00
2021-02-18 13:48:41 +01:00
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,
train_prev=None, savepath=None, method_order=None):
"""
The diagonal plot displays the predicted prevalence values (along the y-axis) as a function of the true prevalence
values (along the x-axis). The optimal quantifier is described by the diagonal (0,0)-(1,1) of the plot (hence the
name). It is convenient for binary quantification problems, though it can be used for multiclass problems by
indicating which class is to be taken as the positive class. (For multiclass quantification problems, other plots
like the :meth:`error_by_drift` might be preferable though).
:param method_names: array-like with the method names for each experiment
:param true_prevs: array-like with the true prevalence values (each being a ndarray with n_classes components) for
each experiment
:param estim_prevs: array-like with the estimated prevalence values (each being a ndarray with n_classes components)
for each experiment
:param pos_class: index of the positive class
:param title: the title to be displayed in the plot
:param show_std: whether or not to show standard deviations (represented by color bands). This might be inconvenient
for cases in which many methods are compared, or when the standard deviations are high -- default True)
:param legend: whether or not to display the leyend (default True)
:param train_prev: if indicated (default is None), the training prevalence (for the positive class) is hightlighted
in the plot. This is convenient when all the experiments have been conducted in the same dataset.
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
:param method_order: if indicated (default is None), imposes the order in which the methods are processed (i.e.,
listed in the legend and associated with matplotlib colors).
"""
2021-01-07 17:58:48 +01:00
fig, ax = plt.subplots()
ax.set_aspect('equal')
ax.grid()
ax.plot([0, 1], [0, 1], '--k', label='ideal', zorder=1)
method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
order = list(zip(method_names, true_prevs, estim_prevs))
if method_order is not None:
table = {method_name:[true_prev, estim_prev] for method_name, true_prev, estim_prev in order}
order = [(method_name, *table[method_name]) for method_name in method_order]
NUM_COLORS = len(method_names)
2023-02-14 17:00:50 +01:00
if NUM_COLORS>10:
cm = plt.get_cmap('tab20')
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
for method, true_prev, estim_prev in order:
2021-01-07 17:58:48 +01:00
true_prev = true_prev[:,pos_class]
estim_prev = estim_prev[:,pos_class]
x_ticks = np.unique(true_prev)
x_ticks.sort()
y_ave = np.asarray([estim_prev[true_prev == x].mean() for x in x_ticks])
y_std = np.asarray([estim_prev[true_prev == x].std() for x in x_ticks])
ax.errorbar(x_ticks, y_ave, fmt='-', marker='o', label=method, markersize=3, zorder=2)
if show_std:
ax.fill_between(x_ticks, y_ave - y_std, y_ave + y_std, alpha=0.25)
2021-01-07 17:58:48 +01:00
2021-02-18 13:48:41 +01:00
if train_prev is not None:
train_prev = train_prev[pos_class]
ax.scatter(train_prev, train_prev, c='c', label='tr-prev', linewidth=2, edgecolor='k', s=100, zorder=3)
2021-01-07 17:58:48 +01:00
ax.set(xlabel='true prevalence', ylabel='estimated prevalence', title=title)
ax.set_ylim(0, 1)
ax.set_xlim(0, 1)
if legend:
2022-05-20 11:52:59 +02:00
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
# box = ax.get_position()
# ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
2023-02-14 17:00:50 +01:00
# ax.legend(loc='lower center',
# bbox_to_anchor=(1, -0.5),
# ncol=(len(method_names)+1)//2)
2021-01-07 17:58:48 +01:00
2021-11-24 11:20:42 +01:00
_save_or_show(savepath)
2021-01-07 17:58:48 +01:00
def binary_bias_global(method_names, true_prevs, estim_prevs, pos_class=1, title=None, savepath=None):
"""
Box-plots displaying the global bias (i.e., signed error computed as the estimated value minus the true value)
for each quantification method with respect to a given positive class.
:param method_names: array-like with the method names for each experiment
:param true_prevs: array-like with the true prevalence values (each being a ndarray with n_classes components) for
each experiment
:param estim_prevs: array-like with the estimated prevalence values (each being a ndarray with n_classes components)
for each experiment
:param pos_class: index of the positive class
:param title: the title to be displayed in the plot
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
"""
method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
2021-01-07 17:58:48 +01:00
fig, ax = plt.subplots()
ax.grid()
data, labels = [], []
for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
true_prev = true_prev[:,pos_class]
estim_prev = estim_prev[:,pos_class]
data.append(estim_prev-true_prev)
labels.append(method)
ax.boxplot(data, labels=labels, patch_artist=False, showmeans=True)
plt.xticks(rotation=45)
2021-01-07 17:58:48 +01:00
ax.set(ylabel='error bias', title=title)
2021-11-24 11:20:42 +01:00
_save_or_show(savepath)
2021-01-07 17:58:48 +01:00
2021-01-11 12:55:06 +01:00
def binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title=None, nbins=5, colormap=cm.tab10,
vertical_xticks=False, legend=True, savepath=None):
"""
2021-11-24 11:20:42 +01:00
Box-plots displaying the local bias (i.e., signed error computed as the estimated value minus the true value)
for different bins of (true) prevalence of the positive classs, for each quantification method.
:param method_names: array-like with the method names for each experiment
:param true_prevs: array-like with the true prevalence values (each being a ndarray with n_classes components) for
each experiment
:param estim_prevs: array-like with the estimated prevalence values (each being a ndarray with n_classes components)
for each experiment
:param pos_class: index of the positive class
:param title: the title to be displayed in the plot
:param nbins: number of bins
:param colormap: the matplotlib colormap to use (default cm.tab10)
2021-11-24 11:20:42 +01:00
:param vertical_xticks: whether or not to add secondary grid (default is False)
:param legend: whether or not to display the legend (default is True)
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
"""
2021-01-07 17:58:48 +01:00
from pylab import boxplot, plot, setp
fig, ax = plt.subplots()
ax.grid()
method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
2021-01-11 12:55:06 +01:00
bins = np.linspace(0, 1, nbins+1)
binwidth = 1/nbins
2021-01-07 17:58:48 +01:00
data = {}
for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
true_prev = true_prev[:,pos_class]
estim_prev = estim_prev[:,pos_class]
2021-01-07 17:58:48 +01:00
data[method] = []
inds = np.digitize(true_prev, bins, right=True)
2021-01-07 17:58:48 +01:00
for ind in range(len(bins)):
selected = inds==ind
data[method].append(estim_prev[selected] - true_prev[selected])
nmethods = len(method_names)
boxwidth = binwidth/(nmethods+4)
for i,bin in enumerate(bins[:-1]):
2021-01-07 17:58:48 +01:00
boxdata = [data[method][i] for method in method_names]
positions = [bin+(i*boxwidth)+2*boxwidth for i,_ in enumerate(method_names)]
box = boxplot(boxdata, showmeans=False, positions=positions, widths = boxwidth, sym='+', patch_artist=True)
2021-01-07 17:58:48 +01:00
for boxid in range(len(method_names)):
c = colormap.colors[boxid%len(colormap.colors)]
2021-01-07 17:58:48 +01:00
setp(box['fliers'][boxid], color=c, marker='+', markersize=3., markeredgecolor=c)
setp(box['boxes'][boxid], color=c)
setp(box['medians'][boxid], color='k')
major_xticks_positions, minor_xticks_positions = [], []
major_xticks_labels, minor_xticks_labels = [], []
for i,b in enumerate(bins[:-1]):
major_xticks_positions.append(b)
minor_xticks_positions.append(b + binwidth / 2)
major_xticks_labels.append('')
minor_xticks_labels.append(f'[{bins[i]:.2f}-{bins[i + 1]:.2f})')
2021-01-07 17:58:48 +01:00
ax.set_xticks(major_xticks_positions)
ax.set_xticks(minor_xticks_positions, minor=True)
ax.set_xticklabels(major_xticks_labels)
ax.set_xticklabels(minor_xticks_labels, minor=True, rotation='vertical' if vertical_xticks else 'horizontal')
if vertical_xticks:
# Pad margins so that markers don't get clipped by the axes
plt.margins(0.2)
# Tweak spacing to prevent clipping of tick-labels
plt.subplots_adjust(bottom=0.15)
if legend:
# adds the legend to the list hs, initialized with the "ideal" quantifier (one that has 0 bias across all bins. i.e.
# a line from (0,0) to (1,0). The other elements are simply labelled dot-plots that are to be removed (setting
# set_visible to False for all but the first element) after the legend has been placed
hs=[ax.plot([0, 1], [0, 0], '-k', zorder=2)[0]]
for colorid in range(len(method_names)):
color=colormap.colors[colorid % len(colormap.colors)]
h, = plot([0, 0], '-s', markerfacecolor=color, color='k',mec=color, linewidth=1.)
hs.append(h)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(hs, ['ideal']+method_names, loc='center left', bbox_to_anchor=(1, 0.5))
[h.set_visible(False) for h in hs[1:]]
2021-01-07 17:58:48 +01:00
# x-axis and y-axis labels and limits
ax.set(xlabel='prevalence', ylabel='error bias', title=title)
ax.set_xlim(0, 1)
2021-11-24 11:20:42 +01:00
_save_or_show(savepath)
2021-01-07 17:58:48 +01:00
2021-11-24 11:20:42 +01:00
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
n_bins=20, error_name='ae', show_std=False,
show_density=True,
2023-01-16 17:00:24 +01:00
show_legend=True,
logscale=False,
title=f'Quantification error as a function of distribution shift',
vlines=None,
2021-11-24 11:20:42 +01:00
method_order=None,
savepath=None):
"""
Plots the error (along the x-axis, as measured in terms of `error_name`) as a function of the train-test shift
(along the y-axis, as measured in terms of :meth:`quapy.error.ae`). This plot is useful especially for multiclass
problems, in which "diagonal plots" may be cumbersone, and in order to gain understanding about how methods
fare in different regions of the prior probability shift spectrum (e.g., in the low-shift regime vs. in the
high-shift regime).
:param method_names: array-like with the method names for each experiment
:param true_prevs: array-like with the true prevalence values (each being a ndarray with n_classes components) for
each experiment
:param estim_prevs: array-like with the estimated prevalence values (each being a ndarray with n_classes components)
for each experiment
:param tr_prevs: training prevalence of each experiment
:param n_bins: number of bins in which the y-axis is to be divided (default is 20)
:param error_name: a string representing the name of an error function (as defined in `quapy.error`, default is "ae")
:param show_std: whether or not to show standard deviations as color bands (default is False)
:param show_density: whether or not to display the distribution of experiments for each bin (default is True)
2023-01-16 17:00:24 +01:00
:param show_density: whether or not to display the legend of the chart (default is True)
2021-11-24 11:20:42 +01:00
:param logscale: whether or not to log-scale the y-error measure (default is False)
:param title: title of the plot (default is "Quantification error as a function of distribution shift")
:param vlines: array-like list of values (default is None). If indicated, highlights some regions of the space
using vertical dotted lines.
:param method_order: if indicated (default is None), imposes the order in which the methods are processed (i.e.,
listed in the legend and associated with matplotlib colors).
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
"""
2021-01-07 17:58:48 +01:00
fig, ax = plt.subplots()
ax.grid()
x_error = qp.error.ae
y_error = getattr(qp.error, error_name)
# get all data as a dictionary {'m':{'x':ndarray, 'y':ndarray}} where 'm' is a method name (in the same
# order as in method_order (if specified), and where 'x' are the train-test shifts (computed as according to
# x_error function) and 'y' is the estim-test shift (computed as according to y_error)
2021-11-24 11:20:42 +01:00
data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order)
2021-01-07 17:58:48 +01:00
2023-01-18 13:06:38 +01:00
if method_order is None:
method_order = method_names
2021-06-15 10:10:19 +02:00
_set_colors(ax, n_methods=len(method_order))
2021-01-11 12:55:06 +01:00
bins = np.linspace(0, 1, n_bins+1)
binwidth = 1 / n_bins
min_x, max_x, min_y, max_y = None, None, None, None
npoints = np.zeros(len(bins), dtype=float)
2021-01-07 17:58:48 +01:00
for method in method_order:
tr_test_drifts = data[method]['x']
method_drifts = data[method]['y']
if logscale:
2023-01-18 14:53:46 +01:00
ax.set_yscale("log")
ax.yaxis.set_major_formatter(ScalarFormatter())
ax.yaxis.get_major_formatter().set_scientific(False)
2023-01-18 16:05:40 +01:00
ax.minorticks_off()
2021-01-07 17:58:48 +01:00
inds = np.digitize(tr_test_drifts, bins, right=True)
2021-06-15 07:49:16 +02:00
2021-01-07 17:58:48 +01:00
xs, ys, ystds = [], [], []
for p,ind in enumerate(range(len(bins))):
2021-01-07 17:58:48 +01:00
selected = inds==ind
if selected.sum() > 0:
xs.append(ind*binwidth-binwidth/2)
2021-01-07 17:58:48 +01:00
ys.append(np.mean(method_drifts[selected]))
ystds.append(np.std(method_drifts[selected]))
npoints[p] += len(method_drifts[selected])
2021-01-07 17:58:48 +01:00
xs = np.asarray(xs)
ys = np.asarray(ys)
ystds = np.asarray(ystds)
min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
2021-01-07 17:58:48 +01:00
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
max_y = max_y_method if max_y is None or max_y_method > max_y else max_y
min_y = min_y_method if min_y is None or min_y_method < min_y else min_y
max_y = max_y_method if max_y is None or max_y_method > max_y else max_y
ax.errorbar(xs, ys, fmt='-', marker='o', color='w', markersize=8, linewidth=4, zorder=1)
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
2021-01-07 17:58:48 +01:00
if show_std:
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
if show_density:
2023-01-18 14:53:46 +01:00
ax2 = ax.twinx()
2023-01-18 16:12:38 +01:00
densities = npoints/np.sum(npoints)
2023-01-18 14:53:46 +01:00
ax2.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
2023-01-18 16:12:38 +01:00
densities, alpha=0.15, color='g', width=binwidth, label='density')
ax2.set_ylim(0,max(densities))
2023-01-18 14:53:46 +01:00
ax2.spines['right'].set_color('g')
ax2.tick_params(axis='y', colors='g')
2023-01-18 13:06:38 +01:00
2021-01-07 17:58:48 +01:00
ax.set(xlabel=f'Distribution shift between training set and test sample',
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
title=title)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
if vlines:
for vline in vlines:
ax.axvline(vline, 0, 1, linestyle='--', color='k')
2023-01-16 17:00:24 +01:00
ax.set_xlim(min_x, max_x)
2023-01-18 16:05:40 +01:00
if logscale:
#nice scale for the logaritmic axis
ax.set_ylim(0,10 ** math.ceil(math.log10(max_y)))
2023-01-16 17:17:02 +01:00
2023-01-17 11:03:52 +01:00
if show_legend:
2023-01-18 16:12:38 +01:00
fig.legend(loc='lower center',
bbox_to_anchor=(1, 0.5),
ncol=(len(method_names)+1)//2)
2023-01-18 14:53:46 +01:00
2021-11-24 11:20:42 +01:00
_save_or_show(savepath)
2021-11-24 11:20:42 +01:00
def brokenbar_supremacy_by_drift(method_names, true_prevs, estim_prevs, tr_prevs,
n_bins=20, binning='isomerous',
x_error='ae', y_error='ae', ttest_alpha=0.005, tail_density_threshold=0.005,
method_order=None,
savepath=None):
2021-11-24 11:20:42 +01:00
"""
Displays (only) the top performing methods for different regions of the train-test shift in form of a broken
bar chart, in which each method has bars only for those regions in which either one of the following conditions
hold: (i) it is the best method (in average) for the bin, or (ii) it is not statistically significantly different
(in average) as according to a two-sided t-test on independent samples at confidence `ttest_alpha`.
The binning can be made "isometric" (same size), or "isomerous" (same number of experiments -- default). A second
plot is displayed on top, that displays the distribution of experiments for each bin (when binning="isometric") or
the percentiles points of the distribution (when binning="isomerous").
:param method_names: array-like with the method names for each experiment
:param true_prevs: array-like with the true prevalence values (each being a ndarray with n_classes components) for
each experiment
:param estim_prevs: array-like with the estimated prevalence values (each being a ndarray with n_classes components)
for each experiment
:param tr_prevs: training prevalence of each experiment
:param n_bins: number of bins in which the y-axis is to be divided (default is 20)
:param binning: type of binning, either "isomerous" (default) or "isometric"
:param x_error: a string representing the name of an error function (as defined in `quapy.error`) to be used for
measuring the amount of train-test shift (default is "ae")
:param y_error: a string representing the name of an error function (as defined in `quapy.error`) to be used for
measuring the amount of error in the prevalence estimations (default is "ae")
:param ttest_alpha: the confidence interval above which a p-value (two-sided t-test on independent samples) is
to be considered as an indicator that the two means are not statistically significantly different. Default is
0.005, meaning that a `p-value > 0.005` indicates the two methods involved are to be considered similar
:param tail_density_threshold: sets a threshold on the density of experiments (over the total number of experiments)
below which a bin in the tail (i.e., the right-most ones) will be discarded. This is in order to avoid some
bins to be shown for train-test outliers.
:param method_order: if indicated (default is None), imposes the order in which the methods are processed (i.e.,
listed in the legend and associated with matplotlib colors).
:param savepath: path where to save the plot. If not indicated (as default), the plot is shown.
:return:
"""
assert binning in ['isomerous', 'isometric'], 'unknown binning type; valid types are "isomerous" and "isometric"'
x_error = getattr(qp.error, x_error)
y_error = getattr(qp.error, y_error)
# get all data as a dictionary {'m':{'x':ndarray, 'y':ndarray}} where 'm' is a method name (in the same
# order as in method_order (if specified), and where 'x' are the train-test shifts (computed as according to
# x_error function) and 'y' is the estim-test shift (computed as according to y_error)
2021-11-24 11:20:42 +01:00
data = _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order)
if method_order is None:
method_order = method_names
if binning == 'isomerous':
# take bins containing the same amount of examples
tr_test_drifts = np.concatenate([data[m]['x'] for m in method_order])
bins = np.quantile(tr_test_drifts, q=np.linspace(0, 1, n_bins+1)).flatten()
else:
# take equidistant bins
bins = np.linspace(0, 1, n_bins+1)
bins[0] = -0.001
bins[-1] += 0.001
# we use this to keep track of how many datapoits contribute to each bin
inds_histogram_global = np.zeros(n_bins, dtype=float)
n_methods = len(method_order)
buckets = np.zeros(shape=(n_methods, n_bins, 3))
for i, method in enumerate(method_order):
tr_test_drifts = data[method]['x']
method_drifts = data[method]['y']
inds = np.digitize(tr_test_drifts, bins, right=False)
inds_histogram_global += np.histogram(tr_test_drifts, density=False, bins=bins)[0]
for j in range(len(bins)):
selected = inds == j
if selected.sum() > 0:
buckets[i, j-1, 0] = np.mean(method_drifts[selected])
buckets[i, j-1, 1] = np.std(method_drifts[selected])
buckets[i, j-1, 2] = selected.sum()
# cancel last buckets with low density
histogram = inds_histogram_global / inds_histogram_global.sum()
for tail in reversed(range(len(histogram))):
if histogram[tail] < tail_density_threshold:
buckets[:,tail,2] = 0
else:
break
salient_methods = set()
best_methods = []
for bucket in range(buckets.shape[1]):
nc = buckets[:, bucket, 2].sum()
if nc == 0:
best_methods.append([])
continue
order = np.argsort(buckets[:, bucket, 0])
rank1 = order[0]
best_bucket_methods = [method_order[rank1]]
best_mean, best_std, best_nc = buckets[rank1, bucket, :]
for method_index in order[1:]:
method_mean, method_std, method_nc = buckets[method_index, bucket, :]
_, pval = ttest_ind_from_stats(best_mean, best_std, best_nc, method_mean, method_std, method_nc)
if pval > ttest_alpha:
best_bucket_methods.append(method_order[method_index])
best_methods.append(best_bucket_methods)
salient_methods.update(best_bucket_methods)
print(best_bucket_methods)
if binning=='isomerous':
fig, axes = plt.subplots(2, 1, gridspec_kw={'height_ratios': [0.2, 1]}, figsize=(20, len(salient_methods)))
else:
fig, axes = plt.subplots(2, 1, gridspec_kw={'height_ratios': [1, 1]}, figsize=(20, len(salient_methods)))
ax = axes[1]
high_from = 0
yticks, yticks_method_names = [], []
color = get_cmap('Accent').colors
vlines = []
bar_high = 1
for method in [m for m in method_order if m in salient_methods]:
broken_paths = []
path_start, path_end = None, None
for i, best_bucket_methods in enumerate(best_methods):
if method in best_bucket_methods:
if path_start is None:
path_start = bins[i]
path_end = bins[i+1]-path_start
else:
path_end += bins[i+1]-bins[i]
else:
if path_start is not None:
broken_paths.append(tuple((path_start, path_end)))
path_start, path_end = None, None
if path_start is not None:
broken_paths.append(tuple((path_start, path_end)))
ax.broken_barh(broken_paths, (high_from, bar_high), facecolors=color[len(yticks_method_names)])
yticks.append(high_from+bar_high/2)
high_from += bar_high
yticks_method_names.append(method)
for path_start, path_end in broken_paths:
vlines.extend([path_start, path_start+path_end])
vlines = np.unique(vlines)
vlines = sorted(vlines)
for v in vlines[1:-1]:
ax.axvline(x=v, color='k', linestyle='--')
ax.set_ylim(0, high_from)
ax.set_xlim(vlines[0], vlines[-1])
ax.set_xlabel('Distribution shift between training set and sample')
ax.set_yticks(yticks)
ax.set_yticklabels(yticks_method_names)
# upper plot (explaining distribution)
ax = axes[0]
if binning == 'isometric':
# show the density for each region
bins[0]=0
y_pos = [b+(bins[i+1]-b)/2 for i,b in enumerate(bins[:-1]) if histogram[i]>0]
bar_width = [bins[i+1]-bins[i] for i in range(len(bins[:-1])) if histogram[i]>0]
ax.bar(y_pos, [n for n in histogram if n>0], bar_width, align='center', alpha=0.5, color='silver')
ax.set_ylabel('shift\ndistribution', rotation=0, ha='right', va='center')
ax.set_xlim(vlines[0], vlines[-1])
ax.get_xaxis().set_visible(False)
plt.subplots_adjust(wspace=0, hspace=0.1)
else:
# show the percentiles of the distribution
cumsum = np.cumsum(histogram)
for i in range(len(bins[:-1])):
start, width = bins[i], bins[i+1]-bins[i]
ax.broken_barh([tuple((start, width))], (0, 1), facecolors='whitesmoke' if i%2==0 else 'silver')
if i < len(bins)-2:
ax.text(bins[i+1], 0.5, '$P_{'+f'{int(np.round(cumsum[i]*100))}'+'}$', ha='center')
ax.set_ylim(0, 1)
ax.set_xlim(vlines[0], vlines[-1])
ax.get_yaxis().set_visible(False)
ax.get_xaxis().set_visible(False)
plt.subplots_adjust(wspace=0, hspace=0)
2021-01-07 17:58:48 +01:00
2021-11-24 11:20:42 +01:00
_save_or_show(savepath)
def _merge(method_names, true_prevs, estim_prevs):
ndims = true_prevs[0].shape[1]
data = defaultdict(lambda: {'true': np.empty(shape=(0, ndims)), 'estim': np.empty(shape=(0, ndims))})
method_order=[]
for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
data[method]['true'] = np.concatenate([data[method]['true'], true_prev])
data[method]['estim'] = np.concatenate([data[method]['estim'], estim_prev])
if method not in method_order:
method_order.append(method)
true_prevs_ = [data[m]['true'] for m in method_order]
estim_prevs_ = [data[m]['estim'] for m in method_order]
return method_order, true_prevs_, estim_prevs_
def _set_colors(ax, n_methods):
NUM_COLORS = n_methods
cm = plt.get_cmap('tab20')
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
2021-01-07 17:58:48 +01:00
2021-11-24 11:20:42 +01:00
def _save_or_show(savepath):
2021-01-07 17:58:48 +01:00
# if savepath is specified, then saves the plot in that path; otherwise the plot is shown
if savepath is not None:
qp.util.create_parent_dir(savepath)
# plt.tight_layout()
plt.savefig(savepath, bbox_inches='tight')
2021-01-07 17:58:48 +01:00
else:
plt.show()
2021-11-24 11:20:42 +01:00
def _join_data_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, x_error, y_error, method_order):
data = defaultdict(lambda: {'x': np.empty(shape=(0)), 'y': np.empty(shape=(0))})
if method_order is None:
method_order = []
for method, test_prevs_i, estim_prevs_i, tr_prev_i in zip(method_names, true_prevs, estim_prevs, tr_prevs):
tr_prev_i = np.repeat(tr_prev_i.reshape(1, -1), repeats=test_prevs_i.shape[0], axis=0)
tr_test_drifts = x_error(test_prevs_i, tr_prev_i)
data[method]['x'] = np.concatenate([data[method]['x'], tr_test_drifts])
method_drifts = y_error(test_prevs_i, estim_prevs_i)
data[method]['y'] = np.concatenate([data[method]['y'], method_drifts])
if method not in method_order:
method_order.append(method)
return data