forked from moreo/QuaPy
adding a sanity check to APP, in order to prevent the user from inadvertently running into a never-ending loop of samples being generated
This commit is contained in:
parent
bfaa5678d7
commit
140ab3bfc9
|
@ -214,18 +214,30 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
|
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
|
||||||
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
|
:param random_state: allows replicating samples across runs (default 0, meaning that the sequence of samples
|
||||||
will be the same every time the protocol is called)
|
will be the same every time the protocol is called)
|
||||||
|
:param sanity_check: int, raises an exception warning the user when the number of samples to be generated exceeds
|
||||||
|
this number; set to None for skipping this check
|
||||||
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
|
:param return_type: set to "sample_prev" (default) to get the pairs of (sample, prevalence) at each iteration, or
|
||||||
to "labelled_collection" to get instead instances of LabelledCollection
|
to "labelled_collection" to get instead instances of LabelledCollection
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
|
def __init__(self, data: LabelledCollection, sample_size=None, n_prevalences=21, repeats=10,
|
||||||
smooth_limits_epsilon=0, random_state=0, return_type='sample_prev'):
|
smooth_limits_epsilon=0, random_state=0, sanity_check=10000, return_type='sample_prev'):
|
||||||
super(APP, self).__init__(random_state)
|
super(APP, self).__init__(random_state)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = qp._get_sample_size(sample_size)
|
self.sample_size = qp._get_sample_size(sample_size)
|
||||||
self.n_prevalences = n_prevalences
|
self.n_prevalences = n_prevalences
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
self.smooth_limits_epsilon = smooth_limits_epsilon
|
self.smooth_limits_epsilon = smooth_limits_epsilon
|
||||||
|
if not ((isinstance(sanity_check, int) and sanity_check>0) or sanity_check is None):
|
||||||
|
raise ValueError('param "sanity_check" must either be None or a positive integer')
|
||||||
|
if isinstance(sanity_check, int):
|
||||||
|
n = F.num_prevalence_combinations(n_prevpoints=n_prevalences, n_classes=data.n_classes, n_repeats=repeats)
|
||||||
|
if n > sanity_check:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Abort: the number of samples that will be generated by {self.__class__.__name__} ({n}) "
|
||||||
|
f"exceeds the maximum number of allowed samples ({sanity_check = }). Set 'sanity_check' to "
|
||||||
|
f"None for bypassing this check, or to a higher number.")
|
||||||
|
|
||||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def prevalence_grid(self):
|
def prevalence_grid(self):
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import quapy.functional
|
||||||
from quapy.data import LabelledCollection
|
from quapy.data import LabelledCollection
|
||||||
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
|
from quapy.protocol import APP, NPP, UPP, DomainMixer, AbstractStochasticSeededProtocol
|
||||||
|
|
||||||
|
@ -19,6 +21,17 @@ def samples_to_str(protocol):
|
||||||
|
|
||||||
class TestProtocols(unittest.TestCase):
|
class TestProtocols(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_app_sanity_check(self):
|
||||||
|
data = mock_labelled_collection()
|
||||||
|
n_prevpoints = 101
|
||||||
|
repeats = 10
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, repeats=repeats, random_state=42)
|
||||||
|
n_combinations = \
|
||||||
|
quapy.functional.num_prevalence_combinations(n_prevpoints, n_classes=data.n_classes, n_repeats=repeats)
|
||||||
|
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=n_combinations)
|
||||||
|
p = APP(data, sample_size=5, n_prevalences=n_prevpoints, random_state=42, sanity_check=None)
|
||||||
|
|
||||||
def test_app_replicate(self):
|
def test_app_replicate(self):
|
||||||
data = mock_labelled_collection()
|
data = mock_labelled_collection()
|
||||||
p = APP(data, sample_size=5, n_prevalences=11, random_state=42)
|
p = APP(data, sample_size=5, n_prevalences=11, random_state=42)
|
||||||
|
|
Loading…
Reference in New Issue