From 60151dbebbc6be6264a895e32f32bdc975e90f78 Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Sun, 3 Mar 2024 14:52:37 +0100 Subject: [PATCH] added DOC and ATC --- ClassifierAccuracy/util/__init__.py | 0 ClassifierAccuracy/util/plotting.py | 46 +++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 ClassifierAccuracy/util/__init__.py create mode 100644 ClassifierAccuracy/util/plotting.py diff --git a/ClassifierAccuracy/util/__init__.py b/ClassifierAccuracy/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ClassifierAccuracy/util/plotting.py b/ClassifierAccuracy/util/plotting.py new file mode 100644 index 0000000..18ee82e --- /dev/null +++ b/ClassifierAccuracy/util/plotting.py @@ -0,0 +1,46 @@ +from pathlib import Path + +import matplotlib.pyplot as plt +from os import makedirs +from os.path import join + +from ClassifierAccuracy.util.commons import get_method_names, open_results + + +def plot_diagonal(basedir, cls_name, measure_name, dataset_name='*'): + methods = get_method_names() + results = open_results(basedir, cls_name, measure_name, dataset_name=dataset_name, method_name=methods) + methods, xs, ys = [], [], [] + for method_name in results.keys(): + methods.append(method_name) + xs.append(results[method_name]['true_acc']) + ys.append(results[method_name]['estim_acc']) + plotsubdir = 'all' if dataset_name=='*' else dataset_name + save_path = join('plots', basedir, plotsubdir, 'diagonal.png') + _plot_diagonal(methods, xs, ys, save_path, measure_name) + + +def _plot_diagonal(methods_names, true_xs, estim_ys, save_path, measure_name, title=None): + + makedirs(Path(save_path).parent, exist_ok=True) + + # Create scatter plot + plt.figure(figsize=(10, 10)) + plt.xlim(0, 1) + plt.ylim(0, 1) + plt.plot([0, 1], [0, 1], color='black', linestyle='--') + + for (method_name, xs, ys) in zip(methods_names, true_xs, estim_ys): + plt.scatter(xs, ys, label=f'{method_name}', alpha=0.6) + + plt.legend() + + # Add labels and title + if title is not None: + plt.title(title) + plt.xlabel(f'True {measure_name}') + plt.ylabel(f'Estimated {measure_name}') + + # Display the plot + plt.savefig(save_path) + plt.cla()