Merge pull request #12 from pglez82/protocols

changing app to use prevalence_linspace function with smooth limits
This commit is contained in:
Alejandro Moreo Fernandez 2022-06-24 14:44:50 +02:00 committed by GitHub
commit 1742b75504
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 2 deletions

View File

@@ -223,6 +223,7 @@ def cross_generate_predictions(
# fit the learner on all data
learner.fit(*data.Xy)
y = data.y
classes = data.classes_
else:
learner, val_data = _training_helper(

View File

@@ -132,15 +132,17 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the
grid (default is 21)
:param repeats: number of copies for each valid prevalence vector (default is 10)
:param smooth_limits_epsilon: the quantity to add to the lower limit 0 and to subtract from the upper limit 1 (default is 0)
:param random_state: allows replicating samples across runs (default None)
"""
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_state=None, return_type='sample_prev'):
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, smooth_limits_epsilon=0, random_state=None, return_type='sample_prev'):
super(APP, self).__init__(random_state)
self.data = data
self.sample_size = sample_size
self.n_prevalences = n_prevalences
self.repeats = repeats
self.smooth_limits_epsilon = smooth_limits_epsilon
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self):
@@ -159,7 +161,7 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
in the grid multiplied by `repeat`
"""
dimensions = self.data.n_classes
s = np.linspace(0., 1., self.n_prevalences, endpoint=True)
s = F.prevalence_linspace(self.n_prevalences, repeats=1, smooth_limits_epsilon=self.smooth_limits_epsilon)
s = [s] * (dimensions - 1)
prevs = [p for p in itertools.product(*s, repeat=1) if (sum(p) <= 1.0)]
prevs = np.asarray(prevs).reshape(len(prevs), -1)