forked from moreo/QuaPy
changing app to use prevalence_linspace function with smooth limits
This commit is contained in:
parent
cf7d37c793
commit
02dd2846ff
|
@ -132,15 +132,17 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the
|
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the
|
||||||
grid (default is 21)
|
grid (default is 21)
|
||||||
:param repeats: number of copies for each valid prevalence vector (default is 10)
|
:param repeats: number of copies for each valid prevalence vector (default is 10)
|
||||||
|
:param smooth_limits_epsilon: the quantity to add and subtract to the limits 0 and 1
|
||||||
:param random_state: allows replicating samples across runs (default None)
|
:param random_state: allows replicating samples across runs (default None)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_state=None, return_type='sample_prev'):
|
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, smooth_limits_epsilon=0, random_state=None, return_type='sample_prev'):
|
||||||
super(APP, self).__init__(random_state)
|
super(APP, self).__init__(random_state)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.n_prevalences = n_prevalences
|
self.n_prevalences = n_prevalences
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
|
self.smooth_limits_epsilon = smooth_limits_epsilon
|
||||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def prevalence_grid(self):
|
def prevalence_grid(self):
|
||||||
|
@ -159,7 +161,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
in the grid multiplied by `repeat`
|
in the grid multiplied by `repeat`
|
||||||
"""
|
"""
|
||||||
dimensions = self.data.n_classes
|
dimensions = self.data.n_classes
|
||||||
s = np.linspace(0., 1., self.n_prevalences, endpoint=True)
|
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
|
||||||
s = [s] * (dimensions - 1)
|
s = [s] * (dimensions - 1)
|
||||||
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
|
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
|
||||||
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
||||||
|
|
Loading…
Reference in New Issue