forked from moreo/QuaPy
collator functions in protocols for preparing the outputs
This commit is contained in:
parent
bfe4b8b51a
commit
82a01478ec
|
@ -63,10 +63,9 @@ class LabelledCollection:
|
||||||
"""
|
"""
|
||||||
return self.instances.shape[0]
|
return self.instances.shape[0]
|
||||||
|
|
||||||
@property
|
|
||||||
def prevalence(self):
|
def prevalence(self):
|
||||||
"""
|
"""
|
||||||
Returns the prevalence, or relative frequency, of the classes of interest.
|
Returns the prevalence, or relative frequency, of the classes in the codeframe.
|
||||||
|
|
||||||
:return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order
|
:return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order
|
||||||
as listed by `self.classes_`
|
as listed by `self.classes_`
|
||||||
|
@ -75,7 +74,7 @@ class LabelledCollection:
|
||||||
|
|
||||||
def counts(self):
|
def counts(self):
|
||||||
"""
|
"""
|
||||||
Returns the number of instances for each of the classes of interest.
|
Returns the number of instances for each of the classes in the codeframe.
|
||||||
|
|
||||||
:return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order
|
:return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order
|
||||||
as listed by `self.classes_`
|
as listed by `self.classes_`
|
||||||
|
@ -252,7 +251,8 @@ class LabelledCollection:
|
||||||
@property
|
@property
|
||||||
def Xp(self):
|
def Xp(self):
|
||||||
"""
|
"""
|
||||||
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols
|
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from
|
||||||
|
a `LabelledCollection` object.
|
||||||
|
|
||||||
:return: a tuple `(instances, prevalence)` from this collection
|
:return: a tuple `(instances, prevalence)` from this collection
|
||||||
"""
|
"""
|
||||||
|
@ -420,6 +420,16 @@ class Dataset:
|
||||||
"""
|
"""
|
||||||
return len(self.vocabulary)
|
return len(self.vocabulary)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def train_test(self):
|
||||||
|
"""
|
||||||
|
Alias to `self.training` and `self.test`
|
||||||
|
|
||||||
|
:return: the training and test collections
|
||||||
|
:return: the training and test collections
|
||||||
|
"""
|
||||||
|
return self.training, self.test
|
||||||
|
|
||||||
def stats(self, show):
|
def stats(self, show):
|
||||||
"""
|
"""
|
||||||
Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,:
|
Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,:
|
||||||
|
|
|
@ -135,13 +135,13 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
:param random_seed: allows replicating samples across runs (default None)
|
:param random_seed: allows replicating samples across runs (default None)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None):
|
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None, return_type='sample_prev'):
|
||||||
super(APP, self).__init__(random_seed)
|
super(APP, self).__init__(random_seed)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.n_prevalences = n_prevalences
|
self.n_prevalences = n_prevalences
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def prevalence_grid(self):
|
def prevalence_grid(self):
|
||||||
"""
|
"""
|
||||||
|
@ -192,13 +192,13 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||||
:param random_seed: allows replicating samples across runs (default None)
|
:param random_seed: allows replicating samples across runs (default None)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None):
|
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
|
||||||
super(NPP, self).__init__(random_seed)
|
super(NPP, self).__init__(random_seed)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
self.random_seed = random_seed
|
self.random_seed = random_seed
|
||||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def samples_parameters(self):
|
def samples_parameters(self):
|
||||||
indexes = []
|
indexes = []
|
||||||
|
@ -229,13 +229,13 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol)
|
||||||
:param random_seed: allows replicating samples across runs (default None)
|
:param random_seed: allows replicating samples across runs (default None)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None):
|
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
|
||||||
super(USimplexPP, self).__init__(random_seed)
|
super(USimplexPP, self).__init__(random_seed)
|
||||||
self.data = data
|
self.data = data
|
||||||
self.sample_size = sample_size
|
self.sample_size = sample_size
|
||||||
self.repeats = repeats
|
self.repeats = repeats
|
||||||
self.random_seed = random_seed
|
self.random_seed = random_seed
|
||||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def samples_parameters(self):
|
def samples_parameters(self):
|
||||||
indexes = []
|
indexes = []
|
||||||
|
@ -339,7 +339,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
||||||
indexesA, indexesB = indexes
|
indexesA, indexesB = indexes
|
||||||
sampleA = self.A.sampling_from_index(indexesA)
|
sampleA = self.A.sampling_from_index(indexesA)
|
||||||
sampleB = self.B.sampling_from_index(indexesB)
|
sampleB = self.B.sampling_from_index(indexesB)
|
||||||
return sampleA+sampleB
|
return (sampleA+sampleB).Xp
|
||||||
|
|
||||||
def total(self):
|
def total(self):
|
||||||
return self.repeats * len(self.mixture_points)
|
return self.repeats * len(self.mixture_points)
|
||||||
|
|
|
@ -46,9 +46,7 @@ def test_fetch_UCIDataset(dataset_name):
|
||||||
|
|
||||||
@pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS)
|
@pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS)
|
||||||
def test_fetch_lequa2022(dataset_name):
|
def test_fetch_lequa2022(dataset_name):
|
||||||
fetch_lequa2022(dataset_name)
|
train, gen_val, gen_test = fetch_lequa2022(dataset_name)
|
||||||
# dataset = fetch_lequa2022(dataset_name)
|
print(train.stats())
|
||||||
# print(f'Dataset {dataset_name}')
|
print('Val:', gen_val.total())
|
||||||
# print('Training set stats')
|
print('Test:', gen_test.total())
|
||||||
# dataset.training.stats()
|
|
||||||
# print('Test set stats')
|
|
||||||
|
|
|
@ -12,8 +12,8 @@ def mock_labelled_collection(prefix=''):
|
||||||
|
|
||||||
def samples_to_str(protocol):
|
def samples_to_str(protocol):
|
||||||
samples_str = ""
|
samples_str = ""
|
||||||
for sample in protocol():
|
for instances, prev in protocol():
|
||||||
samples_str += f'{sample.instances}\t{sample.labels}\t{sample.prevalence()}\n'
|
samples_str += f'{instances}\t{prev}\n'
|
||||||
return samples_str
|
return samples_str
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue