forked from moreo/QuaPy
adding sanity check to APP, in order to prevent the user unattendedly runs into a never-endting loop of samples being generated
This commit is contained in:
parent
bfaa5678d7
commit
140ab3bfc9
|
@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
|
||||
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
|
||||
will be the same every time the protocol is called)
|
||||
:param sanity_check: int, raises an exception warning the user that the number of examples to be generated exceed
|
||||
this number; set to None for skipping this check
|
||||
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
|
||||
to "labelled_collection" to get instead instances of LabelledCollection
|
||||
"""
|
||||
|
||||
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
|
||||
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'):
|
||||
smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'):
|
||||
super(APP, self).__init__(random_state)
|
||||
self.data = data
|
||||
self.sample_size = qp._get_sample_size(sample_size)
|
||||
self.n_prevalences = n_prevalences
|
||||
self.repeats = repeats
|
||||
self.smooth_limits_epsilon = smooth_limits_epsilon
|
||||
if not ((isinstance(sanity_check, int) and sanity_check>0) or sanity_check is None):
|
||||
raise ValueError('param "sanity_check" must either be None or a positive integer')
|
||||
if isinstance(sanity_check, int):
|
||||
n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats)
|
||||
if n > sanity_check:
|
||||
raise RuntimeError(
|
||||
f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) "
|
||||
f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to "
|
||||
f"None for bypassing this check, or to a higher number.")
|
||||
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def prevalence_grid(self):
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import unittest
|
||||
import numpy as np
|
||||
|
||||
import quapy.functional
|
||||
from quapy.data import LabelledCollection
|
||||
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
|
||||
|
||||
|
@ -19,6 +21,17 @@ def samples_to_str(protocol):
|
|||
|
||||
class TestProtocols(unittest.TestCase):
|
||||
|
||||
def test_app_sanity_check(self):
|
||||
data = mock_labelled_collection()
|
||||
n_prevpoints = 101
|
||||
repeats = 10
|
||||
with self.assertRaises(RuntimeError):
|
||||
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42)
|
||||
n_combinations = \
|
||||
quapy.functional.num_prevalence_combinations(n_prevpoints, n_classes=data.n_classes, n_repeats=repeats)
|
||||
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=n_combinations)
|
||||
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=None)
|
||||
|
||||
def test_app_replicate(self):
|
||||
data = mock_labelled_collection()
|
||||
p = APP(data, sample_size=5, n_prevalences=11, random_state=42)
|
||||
|
|
Loading…
Reference in New Issue