forked from moreo/QuaPy
Compare commits
3 Commits
Author | SHA1 | Date |
---|---|---|
|
7fd32d5c5f | |
|
bdbe933a41 | |
|
95b21c8bc2 |
|
@ -12,8 +12,8 @@ from os.path import join
|
||||||
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
|
qp.environ['SAMPLE_SIZE'] = settings.SAMPLE_SIZE
|
||||||
plotext='png'
|
plotext='png'
|
||||||
|
|
||||||
resultdir = './results'
|
resultdir = './results_npp'
|
||||||
plotdir = './plots'
|
plotdir = './plots_npp'
|
||||||
os.makedirs(plotdir, exist_ok=True)
|
os.makedirs(plotdir, exist_ok=True)
|
||||||
|
|
||||||
def gather_results(methods, error_name):
|
def gather_results(methods, error_name):
|
||||||
|
@ -33,6 +33,12 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
|
||||||
if path is not None:
|
if path is not None:
|
||||||
path = join(path, f'error_by_drift_{error_name}.{plotext}')
|
path = join(path, f'error_by_drift_{error_name}.{plotext}')
|
||||||
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
|
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
|
||||||
|
|
||||||
|
method_order = ['SVM(AE)' if error_name=='ae' else 'SVM(RAE)',
|
||||||
|
'PCC', 'SVM(KLD)', 'SVM(Q)', 'SVM(NKLD)', 'CC', 'HDy',
|
||||||
|
'E(PACC)$_\\mathrm{Ptr}$',
|
||||||
|
'E(PACC)$_\\mathrm{AE}$' if error_name=='ae' else 'E(PACC)$_\\mathrm{RAE}$',
|
||||||
|
'QuaNet', 'PACC', 'ACC', 'SLD']
|
||||||
qp.plot.error_by_drift(
|
qp.plot.error_by_drift(
|
||||||
method_names,
|
method_names,
|
||||||
true_prevs,
|
true_prevs,
|
||||||
|
@ -43,7 +49,9 @@ def plot_error_by_drift(methods, error_name, logscale=False, path=None):
|
||||||
show_std=False,
|
show_std=False,
|
||||||
logscale=logscale,
|
logscale=logscale,
|
||||||
title=f'Quantification error as a function of distribution shift',
|
title=f'Quantification error as a function of distribution shift',
|
||||||
savepath=path
|
savepath=path,
|
||||||
|
vlines=[0.02, 0.1055],
|
||||||
|
method_order=method_order
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,9 +60,15 @@ def diagonal_plot(methods, error_name, path=None):
|
||||||
if path is not None:
|
if path is not None:
|
||||||
path = join(path, f'diag_{error_name}')
|
path = join(path, f'diag_{error_name}')
|
||||||
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
|
method_names, true_prevs, estim_prevs, tr_prevs = gather_results(methods, error_name)
|
||||||
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=0, title='Negative', legend=False, show_std=False, savepath=f'{path}_neg.{plotext}')
|
method_order = ['SVM(AE)' if error_name == 'ae' else 'SVM(RAE)',
|
||||||
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title='Neutral', legend=False, show_std=False, savepath=f'{path}_neu.{plotext}')
|
'PCC', 'SVM(KLD)', 'SVM(Q)', 'SVM(NKLD)', 'CC', 'HDy',
|
||||||
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=2, title='Positive', legend=True, show_std=False, savepath=f'{path}_pos.{plotext}')
|
'E(PACC)$_\\mathrm{Ptr}$',
|
||||||
|
'E(PACC)$_\\mathrm{AE}$' if error_name == 'ae' else 'E(PACC)$_\\mathrm{RAE}$',
|
||||||
|
'QuaNet', 'PACC', 'ACC', 'SLD']
|
||||||
|
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=0, title='Negative', legend=False, show_std=False, savepath=f'{path}_neg.{plotext}', method_order=method_order)
|
||||||
|
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title='Neutral', legend=False, show_std=False, savepath=f'{path}_neu.{plotext}', method_order=method_order)
|
||||||
|
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=2, title='Positive', legend=False, show_std=False, savepath=f'{path}_pos.{plotext}', method_order=method_order)
|
||||||
|
qp.plot.binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=2, title='Positive', legend=True, show_std=False, savepath=f'{path}_pos.legend.{plotext}', method_order=method_order)
|
||||||
|
|
||||||
|
|
||||||
def binary_bias_global(methods, error_name, path=None):
|
def binary_bias_global(methods, error_name, path=None):
|
||||||
|
@ -84,12 +98,12 @@ new_methods_rae = ['svmmrae' , 'epaccmraeptr', 'epaccmraemrae', 'hdy', 'quanet']
|
||||||
plot_error_by_drift(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
plot_error_by_drift(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
||||||
plot_error_by_drift(gao_seb_methods+new_methods_rae, error_name='rae', logscale=True, path=plotdir)
|
plot_error_by_drift(gao_seb_methods+new_methods_rae, error_name='rae', logscale=True, path=plotdir)
|
||||||
|
|
||||||
diagonal_plot(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
# diagonal_plot(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
||||||
diagonal_plot(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
# diagonal_plot(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
||||||
|
|
||||||
binary_bias_global(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
|
||||||
binary_bias_global(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
|
||||||
|
|
||||||
|
# binary_bias_global(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
||||||
|
# binary_bias_global(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
||||||
|
#
|
||||||
#binary_bias_bins(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
#binary_bias_bins(gao_seb_methods+new_methods_ae, error_name='ae', path=plotdir)
|
||||||
#binary_bias_bins(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
#binary_bias_bins(gao_seb_methods+new_methods_rae, error_name='rae', path=plotdir)
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ from experiments import result_path
|
||||||
from tabular import Table
|
from tabular import Table
|
||||||
|
|
||||||
tables_path = './tables'
|
tables_path = './tables'
|
||||||
|
results_path = './results'
|
||||||
MAXTONE = 50 # sets the intensity of the maximum color reached by the worst (red) and best (green) results
|
MAXTONE = 50 # sets the intensity of the maximum color reached by the worst (red) and best (green) results
|
||||||
|
|
||||||
makedirs(tables_path, exist_ok=True)
|
makedirs(tables_path, exist_ok=True)
|
||||||
|
@ -23,8 +24,8 @@ def save_table(path, table):
|
||||||
foo.write(table)
|
foo.write(table)
|
||||||
|
|
||||||
|
|
||||||
def experiment_errors(path, dataset, method, loss):
|
def experiment_errors(path, dataset, method, optloss, loss):
|
||||||
path = result_path(path, dataset, method, 'm'+loss if not loss.startswith('m') else loss)
|
path = result_path(path, dataset, method, 'm'+optloss if not loss.startswith('m') else optloss)
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
true_prevs, estim_prevs, _, _, _, _ = pickle.load(open(path, 'rb'))
|
true_prevs, estim_prevs, _, _, _, _ = pickle.load(open(path, 'rb'))
|
||||||
err_fn = getattr(qp.error, loss)
|
err_fn = getattr(qp.error, loss)
|
||||||
|
@ -35,13 +36,10 @@ def experiment_errors(path, dataset, method, loss):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(description='Generate tables for Tweeter Sentiment Quantification')
|
|
||||||
parser.add_argument('results', metavar='RESULT_PATH', type=str,
|
|
||||||
help='path to the directory where to store the results')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST
|
datasets = qp.datasets.TWITTER_SENTIMENT_DATASETS_TEST
|
||||||
evaluation_measures = [qp.error.ae, qp.error.rae]
|
evaluation_measures = [qp.error.ae, qp.error.rae]
|
||||||
|
secundary_eval_measures = [qp.error.kld.__name__, qp.error.nkld.__name__, qp.error.se.__name__]
|
||||||
gao_seb_methods = ['cc', 'acc', 'pcc', 'pacc', 'sld', 'svmq', 'svmkld', 'svmnkld']
|
gao_seb_methods = ['cc', 'acc', 'pcc', 'pacc', 'sld', 'svmq', 'svmkld', 'svmnkld']
|
||||||
new_methods = ['hdy', 'quanet']
|
new_methods = ['hdy', 'quanet']
|
||||||
|
|
||||||
|
@ -52,32 +50,29 @@ if __name__ == '__main__':
|
||||||
# Tables evaluation scores for AE and RAE (two tables)
|
# Tables evaluation scores for AE and RAE (two tables)
|
||||||
# ----------------------------------------------------
|
# ----------------------------------------------------
|
||||||
|
|
||||||
eval_name = eval_func.__name__
|
main_eval_name = eval_func.__name__
|
||||||
added_methods = ['svmm' + eval_name, f'epaccm{eval_name}ptr', f'epaccm{eval_name}m{eval_name}'] + new_methods
|
added_methods = ['svmm' + main_eval_name, f'epaccm{main_eval_name}ptr', f'epaccm{main_eval_name}m{main_eval_name}'] + new_methods
|
||||||
methods = gao_seb_methods + added_methods
|
methods = gao_seb_methods + added_methods
|
||||||
nold_methods = len(gao_seb_methods)
|
nold_methods = len(gao_seb_methods)
|
||||||
nnew_methods = len(added_methods)
|
nnew_methods = len(added_methods)
|
||||||
|
|
||||||
|
for eval_name in [main_eval_name] + secundary_eval_measures:
|
||||||
|
|
||||||
# fill data table
|
# fill data table
|
||||||
table = Table(benchmarks=datasets, methods=methods)
|
table = Table(benchmarks=datasets, methods=methods)
|
||||||
for dataset in datasets:
|
for dataset in datasets:
|
||||||
for method in methods:
|
for method in methods:
|
||||||
table.add(dataset, method, experiment_errors(args.results, dataset, method, eval_name))
|
table.add(dataset, method, experiment_errors(results_path, dataset, method, main_eval_name, eval_name))
|
||||||
|
|
||||||
# write the latex table
|
# write the latex table
|
||||||
# tabular = """
|
|
||||||
# \\begin{tabularx}{\\textwidth}{|c||""" + ('Y|'*nold_methods)+ '|' + ('Y|'*nnew_methods) + """} \hline
|
|
||||||
# & \multicolumn{"""+str(nold_methods)+"""}{c||}{Methods tested in~\cite{Gao:2016uq}} &
|
|
||||||
# \multicolumn{"""+str(nnew_methods)+"""}{c|}{} \\\\ \hline
|
|
||||||
# """
|
|
||||||
tabular = """
|
tabular = """
|
||||||
\\resizebox{\\textwidth}{!}{%
|
\\resizebox{\\textwidth}{!}{%
|
||||||
\\begin{tabular}{|c||""" + ('c|' * nold_methods) + '|' + ('c|' * nnew_methods) + """} \hline
|
\\begin{tabular}{|c||""" + ('c|' * nold_methods) + '|' + ('c|' * nnew_methods) + """} \hline
|
||||||
& \multicolumn{""" + str(nold_methods) + """}{c||}{Methods tested in~\cite{Gao:2016uq}} &
|
& \multicolumn{""" + str(nold_methods) + """}{c||}{Methods tested in~\cite{Gao:2016uq}} &
|
||||||
\multicolumn{""" + str(nnew_methods) + """}{c|}{} \\\\ \hline
|
\multicolumn{""" + str(nnew_methods) + """}{c|}{Newly added methods} \\\\ \hline
|
||||||
"""
|
"""
|
||||||
rowreplace={dataset: nicename(dataset) for dataset in datasets}
|
rowreplace={dataset: nicename(dataset) for dataset in datasets}
|
||||||
colreplace={method: nicename(method, eval_name, side=True) for method in methods}
|
colreplace={method: nicename(method, main_eval_name, side=True) for method in methods}
|
||||||
|
|
||||||
tabular += table.latexTabular(benchmark_replace=rowreplace, method_replace=colreplace)
|
tabular += table.latexTabular(benchmark_replace=rowreplace, method_replace=colreplace)
|
||||||
tabular += """
|
tabular += """
|
||||||
|
@ -85,7 +80,9 @@ if __name__ == '__main__':
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
save_table(f'./tables/tab_results_{eval_name}.new.tex', tabular)
|
save_table(f'./tables/tab_results_{main_eval_name}_{eval_name}.tex', tabular)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
# Tables ranks for AE and RAE (two tables)
|
# Tables ranks for AE and RAE (two tables)
|
||||||
# ----------------------------------------------------
|
# ----------------------------------------------------
|
||||||
|
@ -140,6 +137,6 @@ if __name__ == '__main__':
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
save_table(f'./tables/tab_rank_{eval_name}.new.tex', tabular)
|
save_table(f'./tables/tab_rank_{main_eval_name}.{eval_name}.tex', tabular)
|
||||||
|
|
||||||
print("[Done]")
|
print("[Done]")
|
||||||
|
|
|
@ -283,7 +283,7 @@ class Table:
|
||||||
return t
|
return t
|
||||||
|
|
||||||
def dropMethods(self, methods):
|
def dropMethods(self, methods):
|
||||||
drop_index = [self.method_index[m] for m in methods]
|
drop_index = [self.method_index[m] for m in methods if m in self.method_index]
|
||||||
new_methods = np.delete(self.methods, drop_index)
|
new_methods = np.delete(self.methods, drop_index)
|
||||||
new_index = {col:j for j, col in enumerate(new_methods)}
|
new_index = {col:j for j, col in enumerate(new_methods)}
|
||||||
|
|
||||||
|
|
|
@ -6,6 +6,9 @@ nice = {
|
||||||
'mrae':'RAE',
|
'mrae':'RAE',
|
||||||
'ae':'AE',
|
'ae':'AE',
|
||||||
'rae':'RAE',
|
'rae':'RAE',
|
||||||
|
'kld':'KLD',
|
||||||
|
'nkld':'NKLD',
|
||||||
|
'se':'SE',
|
||||||
'svmkld': 'SVM(KLD)',
|
'svmkld': 'SVM(KLD)',
|
||||||
'svmnkld': 'SVM(NKLD)',
|
'svmnkld': 'SVM(NKLD)',
|
||||||
'svmq': 'SVM(Q)',
|
'svmq': 'SVM(Q)',
|
||||||
|
|
|
@ -13,7 +13,7 @@ plt.rcParams['font.size'] = 16
|
||||||
|
|
||||||
|
|
||||||
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,
|
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,
|
||||||
train_prev=None, savepath=None):
|
train_prev=None, savepath=None, method_order=None):
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
ax.set_aspect('equal')
|
ax.set_aspect('equal')
|
||||||
ax.grid()
|
ax.grid()
|
||||||
|
@ -21,7 +21,15 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No
|
||||||
|
|
||||||
method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
|
method_names, true_prevs, estim_prevs = _merge(method_names, true_prevs, estim_prevs)
|
||||||
|
|
||||||
for method, true_prev, estim_prev in zip(method_names, true_prevs, estim_prevs):
|
order = list(zip(method_names, true_prevs, estim_prevs))
|
||||||
|
if method_order is not None:
|
||||||
|
table = {method_name:[true_prev, estim_prev] for method_name, true_prev, estim_prev in order}
|
||||||
|
order = [(method_name, *table[method_name]) for method_name in method_order]
|
||||||
|
|
||||||
|
cm = plt.get_cmap('tab20')
|
||||||
|
NUM_COLORS = len(method_names)
|
||||||
|
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
|
||||||
|
for method, true_prev, estim_prev in order:
|
||||||
true_prev = true_prev[:,pos_class]
|
true_prev = true_prev[:,pos_class]
|
||||||
estim_prev = estim_prev[:,pos_class]
|
estim_prev = estim_prev[:,pos_class]
|
||||||
|
|
||||||
|
@ -44,8 +52,12 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No
|
||||||
|
|
||||||
if legend:
|
if legend:
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
# ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||||
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
# ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
||||||
|
# ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||||
|
ax.legend(loc='lower center',
|
||||||
|
bbox_to_anchor=(1, -0.5),
|
||||||
|
ncol=(len(method_names)+1)//2)
|
||||||
|
|
||||||
save_or_show(savepath)
|
save_or_show(savepath)
|
||||||
|
|
||||||
|
@ -158,10 +170,19 @@ def _merge(method_names, true_prevs, estim_prevs):
|
||||||
return method_order, true_prevs_, estim_prevs_
|
return method_order, true_prevs_, estim_prevs_
|
||||||
|
|
||||||
|
|
||||||
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=True,
|
def _set_colors(ax, n_methods):
|
||||||
|
NUM_COLORS = n_methods
|
||||||
|
cm = plt.get_cmap('tab20')
|
||||||
|
ax.set_prop_cycle(color=[cm(1. * i / NUM_COLORS) for i in range(NUM_COLORS)])
|
||||||
|
|
||||||
|
|
||||||
|
def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, error_name='ae', show_std=False,
|
||||||
|
show_density=True,
|
||||||
logscale=False,
|
logscale=False,
|
||||||
title=f'Quantification error as a function of distribution shift',
|
title=f'Quantification error as a function of distribution shift',
|
||||||
savepath=None):
|
savepath=None,
|
||||||
|
vlines=None,
|
||||||
|
method_order=None):
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
fig, ax = plt.subplots()
|
||||||
ax.grid()
|
ax.grid()
|
||||||
|
@ -171,7 +192,10 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
|
||||||
|
|
||||||
# join all data, and keep the order in which the methods appeared for the first time
|
# join all data, and keep the order in which the methods appeared for the first time
|
||||||
data = defaultdict(lambda:{'x':np.empty(shape=(0)), 'y':np.empty(shape=(0))})
|
data = defaultdict(lambda:{'x':np.empty(shape=(0)), 'y':np.empty(shape=(0))})
|
||||||
|
|
||||||
|
if method_order is None:
|
||||||
method_order = []
|
method_order = []
|
||||||
|
|
||||||
for method, test_prevs_i, estim_prevs_i, tr_prev_i in zip(method_names, true_prevs, estim_prevs, tr_prevs):
|
for method, test_prevs_i, estim_prevs_i, tr_prev_i in zip(method_names, true_prevs, estim_prevs, tr_prevs):
|
||||||
tr_prev_i = np.repeat(tr_prev_i.reshape(1,-1), repeats=test_prevs_i.shape[0], axis=0)
|
tr_prev_i = np.repeat(tr_prev_i.reshape(1,-1), repeats=test_prevs_i.shape[0], axis=0)
|
||||||
|
|
||||||
|
@ -184,9 +208,12 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
|
||||||
if method not in method_order:
|
if method not in method_order:
|
||||||
method_order.append(method)
|
method_order.append(method)
|
||||||
|
|
||||||
|
_set_colors(ax, n_methods=len(method_order))
|
||||||
|
|
||||||
bins = np.linspace(0, 1, n_bins+1)
|
bins = np.linspace(0, 1, n_bins+1)
|
||||||
binwidth = 1 / n_bins
|
binwidth = 1 / n_bins
|
||||||
min_x, max_x = None, None
|
min_x, max_x, min_y, max_y = None, None, None, None
|
||||||
|
npoints = np.zeros(len(bins), dtype=float)
|
||||||
for method in method_order:
|
for method in method_order:
|
||||||
tr_test_drifts = data[method]['x']
|
tr_test_drifts = data[method]['x']
|
||||||
method_drifts = data[method]['y']
|
method_drifts = data[method]['y']
|
||||||
|
@ -194,33 +221,49 @@ def error_by_drift(method_names, true_prevs, estim_prevs, tr_prevs, n_bins=20, e
|
||||||
method_drifts=np.log(1+method_drifts)
|
method_drifts=np.log(1+method_drifts)
|
||||||
|
|
||||||
inds = np.digitize(tr_test_drifts, bins, right=True)
|
inds = np.digitize(tr_test_drifts, bins, right=True)
|
||||||
|
|
||||||
xs, ys, ystds = [], [], []
|
xs, ys, ystds = [], [], []
|
||||||
for ind in range(len(bins)):
|
for p,ind in enumerate(range(len(bins))):
|
||||||
selected = inds==ind
|
selected = inds==ind
|
||||||
if selected.sum() > 0:
|
if selected.sum() > 0:
|
||||||
xs.append(ind*binwidth)
|
xs.append(ind*binwidth-binwidth/2)
|
||||||
ys.append(np.mean(method_drifts[selected]))
|
ys.append(np.mean(method_drifts[selected]))
|
||||||
ystds.append(np.std(method_drifts[selected]))
|
ystds.append(np.std(method_drifts[selected]))
|
||||||
|
npoints[p] += len(method_drifts[selected])
|
||||||
|
|
||||||
xs = np.asarray(xs)
|
xs = np.asarray(xs)
|
||||||
ys = np.asarray(ys)
|
ys = np.asarray(ys)
|
||||||
ystds = np.asarray(ystds)
|
ystds = np.asarray(ystds)
|
||||||
|
|
||||||
min_x_method, max_x_method = xs.min(), xs.max()
|
min_x_method, max_x_method, min_y_method, max_y_method = xs.min(), xs.max(), ys.min(), ys.max()
|
||||||
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
|
min_x = min_x_method if min_x is None or min_x_method < min_x else min_x
|
||||||
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
|
max_x = max_x_method if max_x is None or max_x_method > max_x else max_x
|
||||||
|
max_y = max_y_method if max_y is None or max_y_method > max_y else max_y
|
||||||
|
min_y = min_y_method if min_y is None or min_y_method < min_y else min_y
|
||||||
|
max_y = max_y_method if max_y is None or max_y_method > max_y else max_y
|
||||||
|
|
||||||
|
ax.errorbar(xs, ys, fmt='-', marker='o', color='w', markersize=8, linewidth=4, zorder=1)
|
||||||
|
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=6, linewidth=2, zorder=2)
|
||||||
|
|
||||||
ax.errorbar(xs, ys, fmt='-', marker='o', label=method, markersize=3, zorder=2)
|
|
||||||
if show_std:
|
if show_std:
|
||||||
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
ax.fill_between(xs, ys-ystds, ys+ystds, alpha=0.25)
|
||||||
|
|
||||||
|
if show_density:
|
||||||
|
ax.bar([ind * binwidth-binwidth/2 for ind in range(len(bins))],
|
||||||
|
max_y*npoints/np.max(npoints), alpha=0.15, color='g', width=binwidth, label='density')
|
||||||
|
|
||||||
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
ax.set(xlabel=f'Distribution shift between training set and test sample',
|
||||||
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
ylabel=f'{error_name.upper()} (true distribution, predicted distribution)',
|
||||||
title=title)
|
title=title)
|
||||||
box = ax.get_position()
|
box = ax.get_position()
|
||||||
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
||||||
|
if vlines:
|
||||||
|
for vline in vlines:
|
||||||
|
ax.axvline(vline, 0, 1, linestyle='--', color='k')
|
||||||
|
# ax.axvline(0.02, 0, 1, linestyle='--', color='k')
|
||||||
|
# ax.axvline(0.1055, 0, 1, linestyle='--', color='k')
|
||||||
|
ax.set_xlim(0, max_x)
|
||||||
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
|
||||||
ax.set_xlim(min_x, max_x)
|
|
||||||
|
|
||||||
save_or_show(savepath)
|
save_or_show(savepath)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue