1
0
Fork 0

Add a sanity check to APP to prevent the user from inadvertently running into a seemingly never-ending loop of sample generation

This commit is contained in:
Alejandro Moreo Fernandez 2023-02-22 11:57:22 +01:00
parent bfaa5678d7
commit 140ab3bfc9
2 changed files with 27 additions and 2 deletions

View File

@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1 :param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples :param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
will be the same every time the protocol is called) will be the same every time the protocol is called)
:param sanity_check: int, maximum number of samples allowed; an exception is raised if the number of samples
to be generated exceeds this number; set to None to skip this check
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or :param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
to "labelled_collection" to get instead instances of LabelledCollection to "labelled_collection" to get instead instances of LabelledCollection
""" """
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
             smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'):
    """
    Stores the protocol configuration and, unless `sanity_check` is None, verifies that the number of
    samples the protocol would generate does not exceed `sanity_check`.

    :raises ValueError: if `sanity_check` is neither None nor a positive integer (booleans are rejected)
    :raises RuntimeError: if the number of samples reported by `F.num_prevalence_combinations` for the given
        `n_prevalences`, `data.n_classes`, and `repeats` is greater than `sanity_check`
    """
    super(APP, self).__init__(random_state)
    self.data = data
    self.sample_size = qp._get_sample_size(sample_size)
    self.n_prevalences = n_prevalences
    self.repeats = repeats
    self.smooth_limits_epsilon = smooth_limits_epsilon
    if sanity_check is not None:
        # bool is a subclass of int: without the explicit exclusion, sanity_check=True would be
        # silently accepted and interpreted as a limit of 1
        is_positive_int = (
            isinstance(sanity_check, int)
            and not isinstance(sanity_check, bool)
            and sanity_check > 0
        )
        if not is_positive_int:
            raise ValueError('param "sanity_check" must either be None or a positive integer')
        n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats)
        if n > sanity_check:
            raise RuntimeError(
                f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) "
                f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to "
                f"None for bypassing this check, or to a higher number.")
    self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self): def prevalence_grid(self):

View File

@ -1,5 +1,7 @@
import unittest import unittest
import numpy as np import numpy as np
import quapy.functional
from quapy.data import LabelledCollection from quapy.data import LabelledCollection
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
@ -19,6 +21,17 @@ def samples_to_str(protocol):
class TestProtocols(unittest.TestCase): class TestProtocols(unittest.TestCase):
def test_app_sanity_check(self):
    """
    Checks the `sanity_check` guard of APP: construction must fail when the number of samples exceeds the
    limit, and must succeed when the limit equals the exact number of samples or when the check is disabled.
    """
    data = mock_labelled_collection()
    n_prevpoints = 101
    repeats = 10

    # no explicit sanity_check: the constructor's default limit applies and is expected to be exceeded
    with self.assertRaises(RuntimeError):
        APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42)

    n_combinations = quapy.functional.num_prevalence_combinations(
        n_prevpoints, n_classes=data.n_classes, n_repeats=repeats)

    # `repeats` is passed explicitly so that the boundary case does not depend on the local value
    # coinciding with the constructor's default
    APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42,
        sanity_check=n_combinations)
    APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42,
        sanity_check=None)
def test_app_replicate(self): def test_app_replicate(self):
data = mock_labelled_collection() data = mock_labelled_collection()
p = APP(data, sample_size=5, n_prevalences=11, random_state=42) p = APP(data, sample_size=5, n_prevalences=11, random_state=42)