From 02dd2846ff4db54a6e6eedb35b15ffc98dad38bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Gonz=C3=A1lez?= Date: Fri, 24 Jun 2022 14:05:47 +0200 Subject: [PATCH 1/2] changing app to use prevalence_linspace function with smooth limits --- quapy/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/quapy/protocol.py b/quapy/protocol.py index 69b99ad..7652eeb 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -132,15 +132,17 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): :param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the grid (default is 21) :param repeats: number of copies for each valid prevalence vector (default is 10) + :param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1 :param random_state: allows replicating samples across runs (default None) """ - def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_state=None, return_type='sample_prev'): + def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, smooth_limits_epsilon=0, random_state=None, return_type='sample_prev'): super(APP, self).__init__(random_state) self.data = data self.sample_size = sample_size self.n_prevalences = n_prevalences self.repeats = repeats + self.smooth_limits_epsilon = smooth_limits_epsilon self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def prevalence_grid(self): @@ -159,7 +161,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): in the grid multiplied by `repeat` """ dimensions = self.data.n_classes - s = np.linspace(0., 1., self.n_prevalences, endpoint=True) + s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon) s = [s] * (dimensions - 1) prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)] prevs = np.asarray(prevs).reshape(len(prevs), -1) From 750814ef2a4a74e60f5ecd857d784211813d6caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pablo=20Gonz=C3=A1lez?= Date: Fri, 24 Jun 2022 14:20:08 +0200 Subject: [PATCH 2/2] fixing bug in ACC when using cross validation --- quapy/method/aggregative.py | 1 + 1 file changed, 1 insertion(+) diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py index c2f4717..759a853 100644 --- a/quapy/method/aggregative.py +++ b/quapy/method/aggregative.py @@ -223,6 +223,7 @@ def cross_generate_predictions( # fit the learner on all data learner.fit(*data.Xy) + y = data.y classes = data.classes_ else: learner, val_data = _training_helper(