forked from moreo/QuaPy
trying to figure out how to refactor protocols meaninguflly
This commit is contained in:
parent
cfdf2e35bd
commit
ba18d00334
|
@ -10,7 +10,7 @@ from . import model_selection
|
|||
from . import classification
|
||||
from quapy.method.base import isprobabilistic, isaggregative
|
||||
|
||||
__version__ = '0.1.6'
|
||||
__version__ = '0.1.7'
|
||||
|
||||
environ = {
|
||||
'SAMPLE_SIZE': None,
|
||||
|
|
|
@ -3,7 +3,7 @@ from scipy.sparse import issparse
|
|||
from scipy.sparse import vstack
|
||||
from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold
|
||||
|
||||
from quapy.functional import artificial_prevalence_sampling, strprev
|
||||
from quapy.functional import strprev
|
||||
|
||||
|
||||
class LabelledCollection:
|
||||
|
@ -120,21 +120,24 @@ class LabelledCollection:
|
|||
assert len(prevs) == self.n_classes, 'unexpected number of prevalences'
|
||||
assert sum(prevs) == 1, f'prevalences ({prevs}) wrong range (sum={sum(prevs)})'
|
||||
|
||||
taken = 0
|
||||
indexes_sample = []
|
||||
for i, class_ in enumerate(self.classes_):
|
||||
if i == self.n_classes - 1:
|
||||
n_requested = size - taken
|
||||
else:
|
||||
n_requested = int(size * prevs[i])
|
||||
# Decide how many instances should be taken for each class in order to satisfy the requested prevalence
|
||||
# accurately, and the number of instances in the sample (exactly). If int(size * prevs[i]) (which is
|
||||
# <= size * prevs[i]) examples are drawn from class i, there could be a remainder number of instances to take
|
||||
# to satisfy the size constrain. The remainder is distributed along the classes with probability = prevs.
|
||||
# (This aims at avoiding the remainder to be placed in a class for which the prevalence requested is 0.)
|
||||
n_requests = {class_: int(size * prevs[i]) for i, class_ in enumerate(self.classes_)}
|
||||
remainder = size - sum(n_requests.values())
|
||||
for rand_class in np.random.choice(self.classes_, size=remainder, p=prevs):
|
||||
n_requests[rand_class] += 1
|
||||
|
||||
indexes_sample = []
|
||||
for class_, n_requested in n_requests.items():
|
||||
n_candidates = len(self.index[class_])
|
||||
index_sample = self.index[class_][
|
||||
np.random.choice(n_candidates, size=n_requested, replace=(n_requested > n_candidates))
|
||||
] if n_requested > 0 else []
|
||||
|
||||
indexes_sample.append(index_sample)
|
||||
taken += n_requested
|
||||
|
||||
indexes_sample = np.concatenate(indexes_sample).astype(int)
|
||||
|
||||
|
@ -152,7 +155,7 @@ class LabelledCollection:
|
|||
:param size: integer, the size of the uniform sample
|
||||
:return: a np.ndarray of shape `(size)` with the indexes
|
||||
"""
|
||||
return np.random.choice(len(self), size, replace=False)
|
||||
return np.random.choice(len(self), size, replace=size > len(self))
|
||||
|
||||
def sampling(self, size, *prevs, shuffle=True):
|
||||
"""
|
||||
|
@ -212,68 +215,6 @@ class LabelledCollection:
|
|||
random_state=random_state)
|
||||
return LabelledCollection(tr_docs, tr_labels), LabelledCollection(te_docs, te_labels)
|
||||
|
||||
def artificial_sampling_generator(self, sample_size, n_prevalences=101, repeats=1):
|
||||
"""
|
||||
A generator of samples that implements the artificial prevalence protocol (APP).
|
||||
The APP consists of exploring a grid of prevalence values containing `n_prevalences` points (e.g.,
|
||||
[0, 0.05, 0.1, 0.15, ..., 1], if `n_prevalences=21`), and generating all valid combinations of
|
||||
prevalence values for all classes (e.g., for 3 classes, samples with [0, 0, 1], [0, 0.05, 0.95], ...,
|
||||
[1, 0, 0] prevalence values of size `sample_size` will be yielded). The number of samples for each valid
|
||||
combination of prevalence values is indicated by `repeats`.
|
||||
|
||||
:param sample_size: the number of instances in each sample
|
||||
:param n_prevalences: the number of prevalence points to be taken from the [0,1] interval (including the
|
||||
limits {0,1}). E.g., if `n_prevalences=11`, then the prevalence points to take are [0, 0.1, 0.2, ..., 1]
|
||||
:param repeats: the number of samples to generate for each valid combination of prevalence values (default 1)
|
||||
:return: yield samples generated at artificially controlled prevalence values
|
||||
"""
|
||||
dimensions = self.n_classes
|
||||
for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats):
|
||||
yield self.sampling(sample_size, *prevs)
|
||||
|
||||
def artificial_sampling_index_generator(self, sample_size, n_prevalences=101, repeats=1):
|
||||
"""
|
||||
A generator of sample indexes implementing the artificial prevalence protocol (APP).
|
||||
The APP consists of exploring
|
||||
a grid of prevalence values (e.g., [0, 0.05, 0.1, 0.15, ..., 1]), and generating all valid combinations of
|
||||
prevalence values for all classes (e.g., for 3 classes, samples with [0, 0, 1], [0, 0.05, 0.95], ...,
|
||||
[1, 0, 0] prevalence values of size `sample_size` will be yielded). The number of sample indexes for each valid
|
||||
combination of prevalence values is indicated by `repeats`
|
||||
|
||||
:param sample_size: the number of instances in each sample (i.e., length of each index)
|
||||
:param n_prevalences: the number of prevalence points to be taken from the [0,1] interval (including the
|
||||
limits {0,1}). E.g., if `n_prevalences=11`, then the prevalence points to take are [0, 0.1, 0.2, ..., 1]
|
||||
:param repeats: the number of samples to generate for each valid combination of prevalence values (default 1)
|
||||
:return: yield the indexes that generate the samples according to APP
|
||||
"""
|
||||
dimensions = self.n_classes
|
||||
for prevs in artificial_prevalence_sampling(dimensions, n_prevalences, repeats):
|
||||
yield self.sampling_index(sample_size, *prevs)
|
||||
|
||||
def natural_sampling_generator(self, sample_size, repeats=100):
|
||||
"""
|
||||
A generator of samples that implements the natural prevalence protocol (NPP). The NPP consists of drawing
|
||||
samples uniformly at random, therefore approximately preserving the natural prevalence of the collection.
|
||||
|
||||
:param sample_size: integer, the number of instances in each sample
|
||||
:param repeats: the number of samples to generate
|
||||
:return: yield instances of :class:`LabelledCollection`
|
||||
"""
|
||||
for _ in range(repeats):
|
||||
yield self.uniform_sampling(sample_size)
|
||||
|
||||
def natural_sampling_index_generator(self, sample_size, repeats=100):
|
||||
"""
|
||||
A generator of sample indexes according to the natural prevalence protocol (NPP). The NPP consists of drawing
|
||||
samples uniformly at random, therefore approximately preserving the natural prevalence of the collection.
|
||||
|
||||
:param sample_size: integer, the number of instances in each sample (i.e., the length of each index)
|
||||
:param repeats: the number of indexes to generate
|
||||
:return: yield `repeats` instances of np.ndarray with shape `(sample_size,)`
|
||||
"""
|
||||
for _ in range(repeats):
|
||||
yield self.uniform_sampling_index(sample_size)
|
||||
|
||||
def __add__(self, other):
|
||||
"""
|
||||
Returns a new :class:`LabelledCollection` as the union of this collection with another collection
|
||||
|
|
|
@ -4,36 +4,6 @@ import scipy
|
|||
import numpy as np
|
||||
|
||||
|
||||
def artificial_prevalence_sampling(dimensions, n_prevalences=21, repeat=1, return_constrained_dim=False):
|
||||
"""
|
||||
Generates vectors of prevalence values artificially drawn from an exhaustive grid of prevalence values. The
|
||||
number of prevalence values explored for each dimension depends on `n_prevalences`, so that, if, for example,
|
||||
`n_prevalences=11` then the prevalence values of the grid are taken from [0, 0.1, 0.2, ..., 0.9, 1]. Only
|
||||
valid prevalence distributions are returned, i.e., vectors of prevalence values that sum up to 1. For each
|
||||
valid vector of prevalence values, `repeat` copies are returned. The vector of prevalence values can be
|
||||
implicit (by setting `return_constrained_dim=False`), meaning that the last dimension (which is constrained
|
||||
to 1 - sum of the rest) is not returned (note that, quite obviously, in this case the vector does not sum up to 1).
|
||||
|
||||
:param dimensions: the number of classes
|
||||
:param n_prevalences: the number of equidistant prevalence points to extract from the [0,1] interval for the grid
|
||||
(default is 21)
|
||||
:param repeat: number of copies for each valid prevalence vector (default is 1)
|
||||
:param return_constrained_dim: set to True to return all dimensions, or to False (default) for ommitting the
|
||||
constrained dimension
|
||||
:return: a `np.ndarray` of shape `(n, dimensions)` if `return_constrained_dim=True` or of shape `(n, dimensions-1)`
|
||||
if `return_constrained_dim=False`, where `n` is the number of valid combinations found in the grid multiplied
|
||||
by `repeat`
|
||||
"""
|
||||
s = np.linspace(0., 1., n_prevalences, endpoint=True)
|
||||
s = [s] * (dimensions - 1)
|
||||
prevs = [p for p in itertools.product(*s, repeat=1) if sum(p)<=1]
|
||||
if return_constrained_dim:
|
||||
prevs = [p+(1-sum(p),) for p in prevs]
|
||||
prevs = np.asarray(prevs).reshape(len(prevs), -1)
|
||||
if repeat>1:
|
||||
prevs = np.repeat(prevs, repeat, axis=0)
|
||||
return prevs
|
||||
|
||||
|
||||
def prevalence_linspace(n_prevalences=21, repeats=1, smooth_limits_epsilon=0.01):
|
||||
"""
|
||||
|
|
|
@ -21,7 +21,7 @@ class GridSearchQ(BaseQuantifier):
|
|||
:param model: the quantifier to optimize
|
||||
:type model: BaseQuantifier
|
||||
:param param_grid: a dictionary with keys the parameter names and values the list of values to explore
|
||||
:param sample_size: the size of the samples to extract from the validation set (ignored if protocl='gen')
|
||||
:param sample_size: the size of the samples to extract from the validation set (ignored if protocol='gen')
|
||||
:param protocol: either 'app' for the artificial prevalence protocol, 'npp' for the natural prevalence
|
||||
protocol, or 'gen' for using a custom sampling generator function
|
||||
:param n_prevpoints: if specified, indicates the number of equally distant points to extract from the interval
|
||||
|
|
Loading…
Reference in New Issue