From 775417c8eb5f984fee713d87fd622b6929f84a7b Mon Sep 17 00:00:00 2001 From: Alex Moreo Date: Thu, 18 Feb 2021 13:48:41 +0100 Subject: [PATCH] bugfix in PACC --- quapy/method/aggregative.py | 4 ++-- quapy/plot.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index 1cd98be..392f866 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -207,7 +207,7 @@ class ACC(AggregativeQuantifier): class_count = data.counts() # fit the learner on all data - self.learner.fit(*data.Xy) + self.learner, _ = training_helper(self.learner, data, fit_learner, val_split=None) else: self.learner, val_data = training_helper(self.learner, data, fit_learner, val_split=val_split) @@ -294,7 +294,7 @@ class PACC(AggregativeProbabilisticQuantifier): y_ = np.vstack(y_) # fit the learner on all data - self.learner.fit(*data.Xy) + self.learner, _ = training_helper(self.learner, data, fit_learner, ensure_probabilistic=True, val_split=None) else: self.learner, val_data = training_helper( diff --git a/quapy/plot.py b/quapy/plot.py index 270fb80..0f5a0aa 100644 --- a/quapy/plot.py +++ b/quapy/plot.py @@ -12,7 +12,8 @@ plt.rcParams['figure.dpi'] = 200 plt.rcParams['font.size'] = 16 -def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True, savepath=None): +def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True, + train_prev=None, savepath=None): fig, ax = plt.subplots() ax.set_aspect('equal') ax.grid() @@ -33,6 +34,10 @@ def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=No if show_std: ax.fill_between(x_ticks, y_ave - y_std, y_ave + y_std, alpha=0.25) + if train_prev is not None: + train_prev = train_prev[pos_class] + ax.scatter(train_prev, train_prev, c='c', label='tr-prev', linewidth=2, edgecolor='k', s=100, zorder=3) + ax.set(xlabel='true prevalence', ylabel='estimated prevalence', title=title) ax.set_ylim(0, 1) ax.set_xlim(0, 1)