1
0
Fork 0
QuaPy/quapy/plot.py

256 lines
10 KiB
Python
Raw Normal View History

2021-01-07 17:58:48 +01:00
from collections import defaultdict
2021-01-15 18:32:32 +01:00
2021-01-07 17:58:48 +01:00
import matplotlib.pyplot as plt
import numpy as np
2021-01-15 18:32:32 +01:00
from matplotlib import cm
2021-01-07 17:58:48 +01:00
2021-01-15 18:32:32 +01:00
import quapy as qp
from matplotlib.font_manager import FontProperties
2021-01-07 17:58:48 +01:00
# Default figure size, resolution and font size for the plots produced by this module. These assignments
# run at import time, so they also change matplotlib's global defaults for the rest of the process.
plt.rcParams.update({
    'figure.figsize': [12, 8],
    'figure.dpi': 200,
    'font.size': 16,
})
2021-01-07 17:58:48 +01:00
2021-06-15 10:10:19 +02:00
def _set_colors(ax, n_methods):
    """Install on ``ax`` a colour property cycle with ``n_methods`` colours taken from the 'tab20' colormap.

    The colours are sampled at the evenly spaced colormap positions ``0, 1/n_methods, 2/n_methods, ...``,
    so that the series plotted next on ``ax`` pick their colours from that list in order.

    :param ax: matplotlib axes whose property cycle is replaced
    :param n_methods: number of colours to generate (callers pass the number of methods to be plotted)
    """
    # named `palette` (not `cm`) so it does not shadow the module-level `matplotlib.cm` import
    palette = plt.get_cmap('tab20')
    positions = [k / n_methods for k in range(n_methods)]
    ax.set_prop_cycle(color=list(map(palette, positions)))
2021-06-15 10:10:19 +02:00
2021-02-18 13:48:41 +01:00
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,
                    train_prev=None, savepath=None):
    """Plot estimated vs. true prevalence of ``pos_class`` for each method, against the ideal diagonal.

    A dashed black diagonal (labelled 'ideal') is drawn first. Entries of a method that appears several times
    in ``method_names`` are pooled via ``_merge``. For every distinct true prevalence value, the mean of the
    corresponding estimates is drawn as a line with markers and, optionally, a band of +/- one standard
    deviation.

    :param method_names: list of method names, aligned with ``true_prevs``/``estim_prevs`` (may repeat)
    :param true_prevs: list of 2-D arrays of true prevalences whose columns are classes
    :param estim_prevs: list of 2-D arrays of estimated prevalences, aligned with ``true_prevs``
    :param pos_class: column index of the class whose prevalence is plotted (default 1)
    :param title: optional plot title
    :param show_std: whether to shade the +/- one standard deviation band around each curve
    :param legend: whether to draw the legend to the right of the axes
    :param train_prev: optional training prevalence vector; its ``pos_class`` entry is marked on the diagonal
    :param savepath: path where the figure is saved, or None to show it instead
    """
    _, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.grid()
    ax.plot([0, 1], [0, 1], '--k', label='ideal', zorder=1)

    names, merged_true, merged_estim = _merge(method_names, true_prevs, estim_prevs)
    _set_colors(ax, n_methods=len(names))

    for name, true_all, estim_all in zip(names, merged_true, merged_estim):
        true_pos = true_all[:, pos_class]
        estim_pos = estim_all[:, pos_class]

        levels = np.unique(true_pos)  # np.unique already returns its values sorted
        means, stds = [], []
        for level in levels:
            at_level = estim_pos[true_pos == level]
            means.append(at_level.mean())
            stds.append(at_level.std())
        means = np.asarray(means)
        stds = np.asarray(stds)

        ax.errorbar(levels, means, fmt='-', marker='o', label=name, markersize=3, zorder=2)
        if show_std:
            ax.fill_between(levels, means - stds, means + stds, alpha=0.25)

    if train_prev is not None:
        tr_pos = train_prev[pos_class]
        ax.scatter(tr_pos, tr_pos, c='c', label='tr-prev', linewidth=2, edgecolor='k', s=100, zorder=3)

    ax.set(xlabel='true prevalence', ylabel='estimated prevalence', title=title)
    ax.set_ylim(0, 1)
    ax.set_xlim(0, 1)

    if legend:
        # shrink the axes horizontally to make room for a legend placed outside, on the right
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    save_or_show(savepath)
def binary_bias_global(method_names, true_prevs, estim_prevs, pos_class=1, title=None, savepath=None):
    """Box-plot, for each method, the signed bias (estimated minus true prevalence) of ``pos_class``.

    Entries of a method that appears several times in ``method_names`` are pooled via ``_merge`` into a
    single box. Means are shown in addition to medians.

    :param method_names: list of method names, aligned with ``true_prevs``/``estim_prevs`` (may repeat)
    :param true_prevs: list of 2-D arrays of true prevalences whose columns are classes
    :param estim_prevs: list of 2-D arrays of estimated prevalences, aligned with ``true_prevs``
    :param pos_class: column index of the class whose prevalence is analysed (default 1)
    :param title: optional plot title
    :param savepath: path where the figure is saved, or None to show it instead
    """
    method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)

    fig, ax = plt.subplots()
    ax.grid()
    data, labels = [], []
    for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
        data.append(estim_prev[:, pos_class] - true_prev[:, pos_class])
        labels.append(method)

    # Boxes go at the default positions 1..n. Tick labels are set explicitly instead of through boxplot's
    # `labels` keyword, which matplotlib 3.9 renamed to `tick_labels` (deprecating `labels`); setting them
    # on the axes works on both older and newer matplotlib versions.
    ax.boxplot(data, patch_artist=False, showmeans=True)
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels, rotation=45)

    ax.set(ylabel='error bias', title=title)
    save_or_show(savepath)
2021-01-11 12:55:06 +01:00
def binary_bias_bins(method_names, true_prevs, estim_prevs, pos_class=1, title=None, nbins=5, colormap=cm.tab10,
                     vertical_xticks=False, legend=True, savepath=None):
    """Box-plot the signed bias of ``pos_class`` (estimated minus true prevalence) within bins of true prevalence.

    The true-prevalence range [0, 1] is split into ``nbins`` equal-width bins ``[b0, b1), [b1, b2), ...,
    [b_{nbins-1}, 1]``. These are exactly the intervals shown in the x-tick labels. Within each bin one box
    per method is drawn, coloured from ``colormap``. Entries of a method that appears several times in
    ``method_names`` are pooled via ``_merge``.

    :param method_names: list of method names, aligned with ``true_prevs``/``estim_prevs`` (may repeat)
    :param true_prevs: list of 2-D arrays of true prevalences whose columns are classes
    :param estim_prevs: list of 2-D arrays of estimated prevalences, aligned with ``true_prevs``
    :param pos_class: column index of the class whose prevalence is analysed (default 1)
    :param title: optional plot title
    :param nbins: number of equal-width bins into which [0, 1] is split (default 5)
    :param colormap: colormap exposing a ``colors`` list (e.g. ``cm.tab10``); the i-th method uses
        ``colors[i % len(colors)]``
    :param vertical_xticks: whether the bin labels are rotated vertically
    :param legend: whether to draw the legend to the right of the axes
    :param savepath: path where the figure is saved, or None to show it instead
    """
    fig, ax = plt.subplots()
    ax.grid()
    method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
    _set_colors(ax, n_methods=len(method_names))

    bins = np.linspace(0, 1, nbins + 1)
    binwidth = 1 / nbins
    colors = colormap.colors

    # data[method][b] holds the bias values of the samples whose true prevalence falls in bin b.
    # Digitizing against the interior edges only yields indices 0..nbins-1 for values in [0, 1], with
    # left-closed intervals [b_i, b_{i+1}) and the last bin closed at 1, matching the tick labels below.
    data = {}
    for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
        true_prev = true_prev[:, pos_class]
        estim_prev = estim_prev[:, pos_class]
        bin_ids = np.digitize(true_prev, bins[1:-1])
        per_bin = []
        for b in range(nbins):
            selected = bin_ids == b
            per_bin.append(estim_prev[selected] - true_prev[selected])
        data[method] = per_bin

    nmethods = len(method_names)
    boxwidth = binwidth / (nmethods + 4)
    for b, left_edge in enumerate(bins[:-1]):
        boxdata = [data[method][b] for method in method_names]
        positions = [left_edge + (j * boxwidth) + 2 * boxwidth for j in range(nmethods)]
        box = ax.boxplot(boxdata, showmeans=False, positions=positions, widths=boxwidth, sym='+', patch_artist=True)
        for j in range(nmethods):
            c = colors[j % len(colors)]
            plt.setp(box['fliers'][j], color=c, marker='+', markersize=3., markeredgecolor=c)
            plt.setp(box['boxes'][j], color=c)
            plt.setp(box['medians'][j], color='k')

    # Unlabelled major ticks mark the bin edges; minor ticks at the bin centres carry the interval labels.
    major_xticks_positions = list(bins[:-1])
    minor_xticks_positions = [edge + binwidth / 2 for edge in bins[:-1]]
    minor_xticks_labels = [
        f'[{bins[b]:.2f}-{bins[b + 1]:.2f}' + (')' if b < nbins - 1 else ']') for b in range(nbins)
    ]
    ax.set_xticks(major_xticks_positions)
    ax.set_xticks(minor_xticks_positions, minor=True)
    ax.set_xticklabels([''] * nbins)
    ax.set_xticklabels(minor_xticks_labels, minor=True, rotation='vertical' if vertical_xticks else 'horizontal')
    if vertical_xticks:
        # Pad margins so that markers don't get clipped by the axes
        ax.margins(0.2)
        # Tweak spacing to prevent clipping of tick-labels
        fig.subplots_adjust(bottom=0.15)

    if legend:
        # Legend handles: an "ideal" black line at zero bias (a real line from (0,0) to (1,0)), followed by one
        # dummy square-marker line per method coloured like that method's boxes. The dummies only provide
        # legend entries and are hidden once the legend has been created.
        handles = [ax.plot([0, 1], [0, 0], '-k', zorder=2)[0]]
        for j in range(nmethods):
            color = colors[j % len(colors)]
            dummy, = ax.plot([0, 0], '-s', markerfacecolor=color, color='k', mec=color, linewidth=1.)
            handles.append(dummy)
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(handles, ['ideal'] + method_names, loc='center left', bbox_to_anchor=(1, 0.5))
        for dummy in handles[1:]:
            dummy.set_visible(False)

    # x-axis and y-axis labels and limits
    ax.set(xlabel='prevalence', ylabel='error bias', title=title)
    ax.set_xlim(0, 1)
    save_or_show(savepath)
def _merge(method_names, true_prevs, estim_prevs):
ndims = true_prevs[0].shape[1]
data = defaultdict(lambda: {'true': np.empty(shape=(0, ndims)), 'estim': np.empty(shape=(0, ndims))})
method_order=[]
for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
data[method]['true'] = np.concatenate([data[method]['true'], true_prev])
data[method]['estim'] = np.concatenate([data[method]['estim'], estim_prev])
if method not in method_order:
method_order.append(method)
true_prevs_ = [data[m]['true'] for m in method_order]
estim_prevs_ = [data[m]['estim'] for m in method_order]
return method_order, true_prevs_, estim_prevs_
2021-01-11 12:55:06 +01:00
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True,
                   logscale=False,
                   title='Quantification error as a function of distribution shift',
                   savepath=None):
    """Plot, for each method, the average quantification error as a function of train/test distribution shift.

    The shift of each test sample is ``qp.error.ae`` between its true prevalence and the training prevalence.
    The error of each estimate is ``getattr(qp.error, error_name)`` between true and estimated prevalence.
    Shift values are binned into ``n_bins`` equal-width bins over [0, 1]. For every non-empty bin, the mean
    error (and optionally a +/- one standard deviation band) is drawn at the bin's right edge. Entries of a
    method that appears several times in ``method_names`` are pooled.

    :param method_names: list of method names, possibly with repetitions
    :param true_prevs: list of 2-D arrays of true test prevalences, aligned with ``method_names``
    :param estim_prevs: list of 2-D arrays of estimated prevalences, aligned with ``true_prevs``
    :param tr_prevs: list of training prevalence vectors, one per entry of ``method_names``
    :param n_bins: number of equal-width shift bins over [0, 1] (default 20)
    :param error_name: name of the function in ``qp.error`` used to measure the quantification error
    :param show_std: whether to shade the +/- one standard deviation band around each curve
    :param logscale: if True, ``log(1 + error)`` is plotted instead of the raw error
    :param title: plot title
    :param savepath: path where the figure is saved, or None to show it instead
    """
    fig, ax = plt.subplots()
    ax.grid()
    x_error = qp.error.ae
    y_error = getattr(qp.error, error_name)

    # join all data, and keep the order in which the methods appeared for the first time
    data = defaultdict(lambda: {'x': np.empty(shape=(0)), 'y': np.empty(shape=(0))})
    method_order = []
    for method, test_prevs_i, estim_prevs_i, tr_prev_i in zip(method_names, true_prevs, estim_prevs, tr_prevs):
        # compare every test sample against the same training prevalence vector
        tr_prev_i = np.repeat(tr_prev_i.reshape(1, -1), repeats=test_prevs_i.shape[0], axis=0)
        tr_test_drifts = x_error(test_prevs_i, tr_prev_i)
        data[method]['x'] = np.concatenate([data[method]['x'], tr_test_drifts])
        method_drifts = y_error(test_prevs_i, estim_prevs_i)
        data[method]['y'] = np.concatenate([data[method]['y'], method_drifts])
        if method not in method_order:
            method_order.append(method)

    _set_colors(ax, n_methods=len(method_order))

    bins = np.linspace(0, 1, n_bins + 1)
    binwidth = 1 / n_bins

    min_x, max_x = None, None
    for method in method_order:
        tr_test_drifts = data[method]['x']
        method_drifts = data[method]['y']
        if logscale:
            method_drifts = np.log(1 + method_drifts)

        # right=True: index i >= 1 collects shifts in (bins[i-1], bins[i]] and is drawn at bins[i]; shifts <= 0
        # get index 0; shifts above 1 get index n_bins + 1 and are therefore not plotted
        inds = np.digitize(tr_test_drifts, bins, right=True)
        xs, ys, ystds = [], [], []
        for ind in range(len(bins)):
            selected = inds == ind
            if selected.sum() > 0:
                xs.append(ind * binwidth)
                ys.append(np.mean(method_drifts[selected]))
                ystds.append(np.std(method_drifts[selected]))
        if not xs:
            # nothing to draw for this method (e.g. it received no test samples); xs.min() would fail below
            continue
        xs = np.asarray(xs)
        ys = np.asarray(ys)
        ystds = np.asarray(ystds)
        min_x_method, max_x_method = xs.min(), xs.max()
        min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
        max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
        ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=3, zorder=2)
        if show_std:
            ax.fill_between(xs, ys - ystds, ys + ystds, alpha=0.25)

    ax.set(xlabel=f'Distribution shift between training set and test sample',
           ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
           title=title)
    box = ax.get_position()
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.set_xlim(min_x, max_x)
    save_or_show(savepath)
def save_or_show(savepath):
    """Save the current matplotlib figure to ``savepath``, or display it when ``savepath`` is None.

    Before saving, ``qp.util.create_parent_dir(savepath)`` is called (presumably so that the destination
    directory exists). The figure is written with ``bbox_inches='tight'``, so the saved image is cropped to the
    figure's content, including legends placed outside the axes.

    :param savepath: destination file path, or None to show the figure with ``plt.show()``
    """
    if savepath is None:
        plt.show()
        return
    qp.util.create_parent_dir(savepath)
    plt.savefig(savepath, bbox_inches='tight')