force all samples to be drawn with replacement in base.LabelledCollection, irrespective of the sample size requested

This commit is contained in:
Alejandro Moreo Fernandez 2024-02-28 08:46:54 +01:00
parent d50a86daf4
commit 75af15ae4a
1 changed file with 6 additions and 10 deletions

View File

@ -108,8 +108,7 @@ class LabelledCollection:
"""
Returns an index to be used to extract a random sample of desired size and desired prevalence values. If the
prevalence values are not specified, then returns the index of a uniform sampling.
For each class, the sampling is drawn with replacement if the requested prevalence is larger than
the actual prevalence of the class, or without replacement otherwise.
For each class, the sampling is drawn with replacement.
:param size: integer, the requested size
:param prevs: the prevalence for each class; the prevalence value for the last class can be left empty since
@ -153,7 +152,7 @@ class LabelledCollection:
for class_, n_requested in n_requests.items():
n_candidates = len(self.index[class_])
index_sample = self.index[class_][
np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates))
np.random.choice(n_candidates, size=n_requested, replace=True)
] if n_requested > 0 else []
indexes_sample.append(index_sample)
@ -168,8 +167,7 @@ class LabelledCollection:
def uniform_sampling_index(self, size, random_state=None):
"""
Returns an index to be used to extract a uniform sample of desired size. The sampling is drawn
with replacement if the requested size is greater than the number of instances, or without replacement
otherwise.
with replacement.
:param size: integer, the size of the uniform sample
:param random_state: if specified, guarantees reproducibility of the split.
@ -179,13 +177,12 @@ class LabelledCollection:
ng = RandomState(seed=random_state)
else:
ng = np.random
return ng.choice(len(self), size, replace=size > len(self))
return ng.choice(len(self), size, replace=True)
def sampling(self, size, *prevs, shuffle=True, random_state=None):
"""
Return a random sample (an instance of :class:`LabelledCollection`) of desired size and desired prevalence
values. For each class, the sampling is drawn without replacement if the requested prevalence is larger than
the actual prevalence of the class, or with replacement otherwise.
values. For each class, the sampling is drawn with replacement.
:param size: integer, the requested size
:param prevs: the prevalence for each class; the prevalence value for the last class can be left empty since
@ -202,8 +199,7 @@ class LabelledCollection:
def uniform_sampling(self, size, random_state=None):
"""
Returns a uniform sample (an instance of :class:`LabelledCollection`) of desired size. The sampling is drawn
with replacement if the requested size is greater than the number of instances, or without replacement
otherwise.
with replacement.
:param size: integer, the requested size
:param random_state: if specified, guarantees reproducibility of the split.