collator functions in protocols for preparing the outputs

This commit is contained in:
Alejandro Moreo Fernandez 2022-06-03 18:02:52 +02:00
parent bfe4b8b51a
commit 82a01478ec
4 changed files with 27 additions and 19 deletions

View File

@ -63,10 +63,9 @@ class LabelledCollection:
"""
return self.instances.shape[0]
@property
def prevalence(self):
"""
Returns the prevalence, or relative frequency, of the classes of interest.
Returns the prevalence, or relative frequency, of the classes in the codeframe.
:return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order
as listed by `self.classes_`
@ -75,7 +74,7 @@ class LabelledCollection:
def counts(self):
"""
Returns the number of instances for each of the classes of interest.
Returns the number of instances for each of the classes in the codeframe.
:return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order
as listed by `self.classes_`
@ -252,7 +251,8 @@ class LabelledCollection:
@property
def Xp(self):
"""
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from
a `LabelledCollection` object.
:return: a tuple `(instances, prevalence)` from this collection
"""
@ -420,6 +420,16 @@ class Dataset:
"""
return len(self.vocabulary)
@property
def train_test(self):
"""
Alias to `self.training` and `self.test`
:return: the training and test collections
:return: the training and test collections
"""
return self.training, self.test
def stats(self, show):
"""
Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,:

View File

@ -135,13 +135,13 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param random_seed: allows replicating samples across runs (default None)
"""
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None):
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None, return_type='sample_prev'):
super(APP, self).__init__(random_seed)
self.data = data
self.sample_size = sample_size
self.n_prevalences = n_prevalences
self.repeats = repeats
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def prevalence_grid(self):
"""
@ -192,13 +192,13 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
:param random_seed: allows replicating samples across runs (default None)
"""
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None):
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
super(NPP, self).__init__(random_seed)
self.data = data
self.sample_size = sample_size
self.repeats = repeats
self.random_seed = random_seed
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def samples_parameters(self):
indexes = []
@ -229,13 +229,13 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol)
:param random_seed: allows replicating samples across runs (default None)
"""
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None):
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
super(USimplexPP, self).__init__(random_seed)
self.data = data
self.sample_size = sample_size
self.repeats = repeats
self.random_seed = random_seed
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def samples_parameters(self):
indexes = []
@ -339,7 +339,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
indexesA, indexesB = indexes
sampleA = self.A.sampling_from_index(indexesA)
sampleB = self.B.sampling_from_index(indexesB)
return sampleA+sampleB
return (sampleA+sampleB).Xp
def total(self):
return self.repeats * len(self.mixture_points)

View File

@ -46,9 +46,7 @@ def test_fetch_UCIDataset(dataset_name):
@pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS)
def test_fetch_lequa2022(dataset_name):
fetch_lequa2022(dataset_name)
# dataset = fetch_lequa2022(dataset_name)
# print(f'Dataset {dataset_name}')
# print('Training set stats')
# dataset.training.stats()
# print('Test set stats')
train, gen_val, gen_test = fetch_lequa2022(dataset_name)
print(train.stats())
print('Val:', gen_val.total())
print('Test:', gen_test.total())

View File

@ -12,8 +12,8 @@ def mock_labelled_collection(prefix=''):
def samples_to_str(protocol):
samples_str = ""
for sample in protocol():
samples_str += f'{sample.instances}\t{sample.labels}\t{sample.prevalence()}\n'
for instances, prev in protocol():
samples_str += f'{instances}\t{prev}\n'
return samples_str