From 2d12ce12b94f1f325062bf74f8f5a78e59f05643 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Mon, 18 Dec 2023 17:15:53 +0100 Subject: [PATCH] bugfix in APP --- examples/model_selection.py | 4 ++-- quapy/protocol.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/model_selection.py b/examples/model_selection.py index 485acd8..4e52784 100644 --- a/examples/model_selection.py +++ b/examples/model_selection.py @@ -46,8 +46,8 @@ with qp.util.temp_seed(0): tinit = time() - model = OLD_GridSearchQ( - # model = qp.model_selection.GridSearchQ( + # model = OLD_GridSearchQ( + model = qp.model_selection.GridSearchQ( model=model, param_grid=param_grid, protocol=protocol, diff --git a/quapy/protocol.py b/quapy/protocol.py index 7d7d1df..36362a9 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -257,8 +257,9 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): """ dimensions = self.data.n_classes s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon) + eps = (s[1]-s[0])/2 # handling floating rounding s = [s] * (dimensions - 1) - prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)] + prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) < (1.+eps))] prevs = np.asarray(prevs).reshape(len(prevs), -1) if self.repeats > 1: prevs = np.repeat(prevs, self.repeats, axis=0)