1
0
Fork 0

Add a sanity check to APP to prevent the user from inadvertently running into a never-ending loop of samples being generated

This commit is contained in:
Alejandro Moreo Fernandez 2023-02-22 11:57:22 +01:00
parent bfaa5678d7
commit 140ab3bfc9
2 changed files with 27 additions and 2 deletions

View File

@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
will be the same every time the protocol is called)
:param sanity_check: int, raises an exception warning the user if the number of samples to be generated exceeds
this number; set to None to skip this check
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
to "labelled_collection" to get instead instances of LabelledCollection
"""
def __init__(self, data:LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'):
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
             smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'):
    """
    Store the protocol configuration and, unless the check is disabled, refuse configurations
    that would generate more samples than `sanity_check` allows.

    :param data: the collection to sample from; its `n_classes` is used to count the samples
    :param sample_size: forwarded to `qp._get_sample_size`, whose result is stored
    :param n_prevalences: number of prevalence points used to count the prevalence combinations
    :param repeats: number of repetitions used to count the prevalence combinations
    :param smooth_limits_epsilon: stored as-is
    :param random_state: forwarded to the parent constructor
    :param sanity_check: positive int, the maximum number of samples allowed; None disables the check
    :param return_type: forwarded to `OnLabelledCollectionProtocol.get_collator`
    :raises ValueError: if `sanity_check` is neither None nor a positive integer (booleans are rejected)
    :raises RuntimeError: if the number of samples to be generated exceeds `sanity_check`
    """
    super(APP, self).__init__(random_state)
    self.data = data
    self.sample_size = qp._get_sample_size(sample_size)
    self.n_prevalences = n_prevalences
    self.repeats = repeats
    self.smooth_limits_epsilon = smooth_limits_epsilon

    # bool is a subclass of int, so it has to be excluded explicitly; otherwise sanity_check=True
    # would be silently interpreted as "at most 1 sample"
    is_int_limit = isinstance(sanity_check, int) and not isinstance(sanity_check, bool)
    if sanity_check is not None and not (is_int_limit and sanity_check > 0):
        raise ValueError('param "sanity_check" must either be None or a positive integer')

    if is_int_limit:
        # count the samples up front so that the user is warned before any sample is drawn
        n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats)
        if n > sanity_check:
            raise RuntimeError(
                f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) "
                f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to "
                f"None for bypassing this check, or to a higher number.")

    self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self):

View File

@ -1,5 +1,7 @@
import unittest
import numpy as np
import quapy.functional
from quapy.data import LabelledCollection
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
@ -19,6 +21,17 @@ def samples_to_str(protocol):
class TestProtocols(unittest.TestCase):
def test_app_sanity_check(self):
    """
    APP must raise RuntimeError when the number of samples to generate exceeds `sanity_check`,
    and must be constructible when the limit is exactly met or the check is disabled.
    """
    data = mock_labelled_collection()
    n_prevpoints = 101
    repeats = 10

    # with the default limit, this grid is expected to be too large
    with self.assertRaises(RuntimeError):
        APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42)

    n_combinations = quapy.functional.num_prevalence_combinations(
        n_prevpoints, n_classes=data.n_classes, n_repeats=repeats)

    # `repeats` is passed explicitly everywhere so that the limit and the protocol are always
    # computed from the same configuration, regardless of the default value in APP

    # one below the exact number of samples: the check must still fire
    with self.assertRaises(RuntimeError):
        APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42,
            sanity_check=n_combinations - 1)

    # a limit equal to the exact number of samples must be accepted
    APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42,
        sanity_check=n_combinations)

    # None disables the check altogether
    APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42,
        sanity_check=None)
def test_app_replicate(self):
data = mock_labelled_collection()
p = APP(data, sample_size=5, n_prevalences=11, random_state=42)