Merge branch 'master' of github.com:HLT-ISTI/QuaPy

This commit is contained in:
Alejandro Moreo Fernandez 2023-02-28 10:25:52 +01:00
commit d0706005d7
3 changed files with 29 additions and 4 deletions

View File

@ -9,9 +9,9 @@ import math
import quapy as qp import quapy as qp
plt.rcParams['figure.figsize'] = [12, 8] plt.rcParams['figure.figsize'] = [10, 6]
plt.rcParams['figure.dpi'] = 200 plt.rcParams['figure.dpi'] = 200
plt.rcParams['font.size'] = 16 plt.rcParams['font.size'] = 18
def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True, def binary_diagonal(method_names, true_prevs, estim_prevs, pos_class=1, title=None, show_std=True, legend=True,

View File

@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1 :param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples :param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
will be the same every time the protocol is called) will be the same every time the protocol is called)
:param sanity_check: int, raises an exception warning the user that the number of examples to be generated exceed
this number; set to None for skipping this check
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or :param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
to "labelled_collection" to get instead instances of LabelledCollection to "labelled_collection" to get instead instances of LabelledCollection
""" """
def __init__(self, data:LabelledCollection, sample_size=None, n_prevalences=21, repeats=10, def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'): smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'):
super(APP, self).__init__(random_state) super(APP, self).__init__(random_state)
self.data = data self.data = data
self.sample_size = qp._get_sample_size(sample_size) self.sample_size = qp._get_sample_size(sample_size)
self.n_prevalences = n_prevalences self.n_prevalences = n_prevalences
self.repeats = repeats self.repeats = repeats
self.smooth_limits_epsilon = smooth_limits_epsilon self.smooth_limits_epsilon = smooth_limits_epsilon
if not ((isinstance(sanity_check, int) and sanity_check>0) or sanity_check is None):
raise ValueError('param "sanity_check" must either be None or a positive integer')
if isinstance(sanity_check, int):
n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats)
if n > sanity_check:
raise RuntimeError(
f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) "
f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to "
f"None for bypassing this check, or to a higher number.")
self.collator = OnLabelledCollectionProtocol.get_collator(return_type) self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self): def prevalence_grid(self):

View File

@ -1,5 +1,7 @@
import unittest import unittest
import numpy as np import numpy as np
import quapy.functional
from quapy.data import LabelledCollection from quapy.data import LabelledCollection
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
@ -19,6 +21,17 @@ def samples_to_str(protocol):
class TestProtocols(unittest.TestCase): class TestProtocols(unittest.TestCase):
def test_app_sanity_check(self):
data = mock_labelled_collection()
n_prevpoints = 101
repeats = 10
with self.assertRaises(RuntimeError):
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42)
n_combinations = \
quapy.functional.num_prevalence_combinations(n_prevpoints, n_classes=data.n_classes, n_repeats=repeats)
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=n_combinations)
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=None)
def test_app_replicate(self): def test_app_replicate(self):
data = mock_labelled_collection() data = mock_labelled_collection()
p = APP(data, sample_size=5, n_prevalences=11, random_state=42) p = APP(data, sample_size=5, n_prevalences=11, random_state=42)