1
0
Fork 0

changing app to use prevalence_linspace function with smooth limits

This commit is contained in:
Pablo González 2022-06-24 14:05:47 +02:00
parent cf7d37c793
commit 02dd2846ff
1 changed files with 4 additions and 2 deletions

View File

@ -132,15 +132,17 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the
grid (default is 21)
:param repeats: number of copies for each valid prevalence vector (default is 10)
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
:param random_state: allows replicating samples across runs (default None)
"""
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_state=None, return_type='sample_prev'):
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, smooth_limits_epsilon=0, random_state=None, return_type='sample_prev'):
super(APP, self).__init__(random_state)
self.data = data
self.sample_size = sample_size
self.n_prevalences = n_prevalences
self.repeats = repeats
self.smooth_limits_epsilon = smooth_limits_epsilon
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self):
@ -159,7 +161,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
in the grid multiplied by `repeat`
"""
dimensions = self.data.n_classes
s = np.linspace(0., 1., self.n_prevalences, endpoint=True)
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
s = [s] * (dimensions - 1)
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
prevs = np.asarray(prevs).reshape(len(prevs), -1)