
collator functions in protocols for preparing the outputs

Alejandro Moreo Fernandez 2022-06-03 18:02:52 +02:00
parent bfe4b8b51a
commit 82a01478ec
4 changed files with 27 additions and 19 deletions
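In short, the sampling protocols stop yielding raw `LabelledCollection` objects and instead pass every generated sample through a collator function selected by a new `return_type` argument (default `'sample_prev'`), so that callers receive `(instances, prevalence)` pairs ready for evaluation. As a rough sketch of the idea (only the name `get_collator`, the value `'sample_prev'`, and the `Xp` property appear in the diffs below; the body here is illustrative, not the actual QuaPy implementation):

# Hypothetical collator factory, mirroring how the patch below seems to use it.
def get_collator(return_type='sample_prev'):
    if return_type == 'sample_prev':
        # turn a LabelledCollection sample into an (instances, prevalence) tuple
        return lambda sample: sample.Xp
    # otherwise hand back the LabelledCollection sample unchanged
    return lambda sample: sample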


@@ -63,10 +63,9 @@ class LabelledCollection:
         """
         return self.instances.shape[0]
 
-    @property
     def prevalence(self):
         """
-        Returns the prevalence, or relative frequency, of the classes of interest.
+        Returns the prevalence, or relative frequency, of the classes in the codeframe.
 
         :return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order
             as listed by `self.classes_`
@@ -75,7 +74,7 @@ class LabelledCollection:
     def counts(self):
         """
-        Returns the number of instances for each of the classes of interest.
+        Returns the number of instances for each of the classes in the codeframe.
 
         :return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order
             as listed by `self.classes_`
@@ -252,7 +251,8 @@ class LabelledCollection:
     @property
     def Xp(self):
         """
-        Gets the instances and the true prevalence. This is useful when implementing evaluation protocols
+        Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from
+        a `LabelledCollection` object.
 
         :return: a tuple `(instances, prevalence)` from this collection
         """
@@ -420,6 +420,16 @@ class Dataset:
         """
         return len(self.vocabulary)
 
+    @property
+    def train_test(self):
+        """
+        Alias to `self.training` and `self.test`
+
+        :return: the training and test collections
+        """
+        return self.training, self.test
+
     def stats(self, show):
         """
         Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,:
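A small usage sketch of the two additions above, the `Xp` shortcut and the `train_test` alias; the toy data and the `split_stratified` call are assumptions for illustration, not part of the patch:

import numpy as np
from quapy.data import LabelledCollection, Dataset

# toy binary collection (illustrative data only)
X = np.random.rand(100, 2)
y = np.random.randint(0, 2, size=100)
collection = LabelledCollection(X, y)

instances, prev = collection.Xp                  # instances and true prevalence in one go
training, test = collection.split_stratified(train_prop=0.6)
dataset = Dataset(training, test)
train, test = dataset.train_test                 # alias for (dataset.training, dataset.test)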


@@ -135,13 +135,13 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
     :param random_seed: allows replicating samples across runs (default None)
     """
 
-    def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None):
+    def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None, return_type='sample_prev'):
         super(APP, self).__init__(random_seed)
         self.data = data
         self.sample_size = sample_size
         self.n_prevalences = n_prevalences
         self.repeats = repeats
-        self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
+        self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
 
     def prevalence_grid(self):
         """
@@ -192,13 +192,13 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
     :param random_seed: allows replicating samples across runs (default None)
     """
 
-    def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None):
+    def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
         super(NPP, self).__init__(random_seed)
         self.data = data
         self.sample_size = sample_size
         self.repeats = repeats
         self.random_seed = random_seed
-        self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
+        self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
 
     def samples_parameters(self):
         indexes = []
@@ -229,13 +229,13 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
     :param random_seed: allows replicating samples across runs (default None)
     """
 
-    def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None):
+    def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
         super(USimplexPP, self).__init__(random_seed)
         self.data = data
         self.sample_size = sample_size
         self.repeats = repeats
         self.random_seed = random_seed
-        self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
+        self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
 
     def samples_parameters(self):
         indexes = []
@@ -339,7 +339,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
         indexesA, indexesB = indexes
         sampleA = self.A.sampling_from_index(indexesA)
         sampleB = self.B.sampling_from_index(indexesB)
-        return sampleA+sampleB
+        return (sampleA+sampleB).Xp
 
     def total(self):
         return self.repeats * len(self.mixture_points)
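From the caller's side, the change in this file means a protocol can now be iterated as a stream of `(instances, prevalence)` pairs. A minimal sketch, assuming the classes live in `quapy.protocol` and `quapy.data` and using synthetic data; only the constructor signature and the `'sample_prev'` default come from the diff above:

import numpy as np
from quapy.data import LabelledCollection
from quapy.protocol import APP

# synthetic binary data, for illustration only
X = np.random.rand(1000, 2)
y = np.random.randint(0, 2, size=1000)
data = LabelledCollection(X, y)

# with return_type='sample_prev' the collator turns every sample into (instances, prevalence)
prot = APP(data, sample_size=50, n_prevalences=11, repeats=1, random_seed=42,
           return_type='sample_prev')
for instances, prev in prot():
    print(instances.shape, prev)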


@@ -46,9 +46,7 @@ def test_fetch_UCIDataset(dataset_name):
 
 @pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS)
 def test_fetch_lequa2022(dataset_name):
-    fetch_lequa2022(dataset_name)
-    # dataset = fetch_lequa2022(dataset_name)
-    # print(f'Dataset {dataset_name}')
-    # print('Training set stats')
-    # dataset.training.stats()
-    # print('Test set stats')
+    train, gen_val, gen_test = fetch_lequa2022(dataset_name)
+    print(train.stats())
+    print('Val:', gen_val.total())
+    print('Test:', gen_test.total())


@@ -12,8 +12,8 @@ def mock_labelled_collection(prefix=''):
 
 def samples_to_str(protocol):
     samples_str = ""
-    for sample in protocol():
-        samples_str += f'{sample.instances}\t{sample.labels}\t{sample.prevalence()}\n'
+    for instances, prev in protocol():
+        samples_str += f'{instances}\t{prev}\n'
     return samples_str