From 82a01478ec80eeb1fead5c320bb5e329a7ee9441 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Fri, 3 Jun 2022 18:02:52 +0200 Subject: [PATCH] collator functions in protocols for preparing the outputs --- quapy/data/base.py | 18 ++++++++++++++---- quapy/protocol.py | 14 +++++++------- quapy/tests/test_datasets.py | 10 ++++------ quapy/tests/test_protocols.py | 4 ++-- 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/quapy/data/base.py b/quapy/data/base.py index b22a71f..4601c15 100644 --- a/quapy/data/base.py +++ b/quapy/data/base.py @@ -63,10 +63,9 @@ class LabelledCollection: """ return self.instances.shape[0] - @property def prevalence(self): """ - Returns the prevalence, or relative frequency, of the classes of interest. + Returns the prevalence, or relative frequency, of the classes in the codeframe. :return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order as listed by `self.classes_` @@ -75,7 +74,7 @@ class LabelledCollection: def counts(self): """ - Returns the number of instances for each of the classes of interest. + Returns the number of instances for each of the classes in the codeframe. :return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order as listed by `self.classes_` @@ -252,7 +251,8 @@ class LabelledCollection: @property def Xp(self): """ - Gets the instances and the true prevalence. This is useful when implementing evaluation protocols + Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from + a `LabelledCollection` object. :return: a tuple `(instances, prevalence)` from this collection """ @@ -420,6 +420,16 @@ class Dataset: """ return len(self.vocabulary) + @property + def train_test(self): + """ + Alias to `self.training` and `self.test` + + :return: the training and test collections + :return: the training and test collections + """ + return self.training, self.test + def stats(self, show): """ Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,: diff --git a/quapy/protocol.py b/quapy/protocol.py index c55c3ef..fec37ca 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -135,13 +135,13 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): :param random_seed: allows replicating samples across runs (default None) """ - def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None): + def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None, return_type='sample_prev'): super(APP, self).__init__(random_seed) self.data = data self.sample_size = sample_size self.n_prevalences = n_prevalences self.repeats = repeats - self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence())) + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def prevalence_grid(self): """ @@ -192,13 +192,13 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): :param random_seed: allows replicating samples across runs (default None) """ - def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None): + def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'): super(NPP, self).__init__(random_seed) self.data = data self.sample_size = sample_size self.repeats = repeats self.random_seed = random_seed - self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence())) + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def samples_parameters(self): indexes = [] @@ -229,13 +229,13 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol) :param random_seed: allows replicating samples across runs (default None) """ - def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None): + def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'): super(USimplexPP, self).__init__(random_seed) self.data = data self.sample_size = sample_size self.repeats = repeats self.random_seed = random_seed - self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence())) + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def samples_parameters(self): indexes = [] @@ -339,7 +339,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): indexesA, indexesB = indexes sampleA = self.A.sampling_from_index(indexesA) sampleB = self.B.sampling_from_index(indexesB) - return sampleA+sampleB + return (sampleA+sampleB).Xp def total(self): return self.repeats * len(self.mixture_points) diff --git a/quapy/tests/test_datasets.py b/quapy/tests/test_datasets.py index 8d70fe9..b0c2f7a 100644 --- a/quapy/tests/test_datasets.py +++ b/quapy/tests/test_datasets.py @@ -46,9 +46,7 @@ def test_fetch_UCIDataset(dataset_name): @pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS) def test_fetch_lequa2022(dataset_name): - fetch_lequa2022(dataset_name) - # dataset = fetch_lequa2022(dataset_name) - # print(f'Dataset {dataset_name}') - # print('Training set stats') - # dataset.training.stats() - # print('Test set stats') \ No newline at end of file + train, gen_val, gen_test = fetch_lequa2022(dataset_name) + print(train.stats()) + print('Val:', gen_val.total()) + print('Test:', gen_test.total()) diff --git a/quapy/tests/test_protocols.py b/quapy/tests/test_protocols.py index b68567b..aeb1f4e 100644 --- a/quapy/tests/test_protocols.py +++ b/quapy/tests/test_protocols.py @@ -12,8 +12,8 @@ def mock_labelled_collection(prefix=''): def samples_to_str(protocol): samples_str = "" - for sample in protocol(): - samples_str += f'{sample.instances}\t{sample.labels}\t{sample.prevalence()}\n' + for instances, prev in protocol(): + samples_str += f'{instances}\t{prev}\n' return samples_str