forked from moreo/QuaPy
collator functions in protocols for preparing the outputs
This commit is contained in:
parent
bfe4b8b51a
commit
82a01478ec
|
@ -63,10 +63,9 @@ class LabelledCollection:
|
|||
"""
|
||||
return self.instances.shape[0]
|
||||
|
||||
@property
|
||||
def prevalence(self):
|
||||
"""
|
||||
Returns the prevalence, or relative frequency, of the classes of interest.
|
||||
Returns the prevalence, or relative frequency, of the classes in the codeframe.
|
||||
|
||||
:return: a np.ndarray of shape `(n_classes)` with the relative frequencies of each class, in the same order
|
||||
as listed by `self.classes_`
|
||||
|
@ -75,7 +74,7 @@ class LabelledCollection:
|
|||
|
||||
def counts(self):
|
||||
"""
|
||||
Returns the number of instances for each of the classes of interest.
|
||||
Returns the number of instances for each of the classes in the codeframe.
|
||||
|
||||
:return: a np.ndarray of shape `(n_classes)` with the number of instances of each class, in the same order
|
||||
as listed by `self.classes_`
|
||||
|
@ -252,7 +251,8 @@ class LabelledCollection:
|
|||
@property
|
||||
def Xp(self):
|
||||
"""
|
||||
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols
|
||||
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols from
|
||||
a `LabelledCollection` object.
|
||||
|
||||
:return: a tuple `(instances, prevalence)` from this collection
|
||||
"""
|
||||
|
@ -420,6 +420,16 @@ class Dataset:
|
|||
"""
|
||||
return len(self.vocabulary)
|
||||
|
||||
@property
|
||||
def train_test(self):
|
||||
"""
|
||||
Alias to `self.training` and `self.test`
|
||||
|
||||
:return: the training and test collections
|
||||
:return: the training and test collections
|
||||
"""
|
||||
return self.training, self.test
|
||||
|
||||
def stats(self, show):
|
||||
"""
|
||||
Returns (and eventually prints) a dictionary with some stats of this dataset. E.g.,:
|
||||
|
|
|
@ -135,13 +135,13 @@ class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
:param random_seed: allows replicating samples across runs (default None)
|
||||
"""
|
||||
|
||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None):
|
||||
def __init__(self, data:LabelledCollection, sample_size, n_prevalences=21, repeats=10, random_seed=None, return_type='sample_prev'):
|
||||
super(APP, self).__init__(random_seed)
|
||||
self.data = data
|
||||
self.sample_size = sample_size
|
||||
self.n_prevalences = n_prevalences
|
||||
self.repeats = repeats
|
||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def prevalence_grid(self):
|
||||
"""
|
||||
|
@ -192,13 +192,13 @@ class NPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
|||
:param random_seed: allows replicating samples across runs (default None)
|
||||
"""
|
||||
|
||||
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None):
|
||||
def __init__(self, data:LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
|
||||
super(NPP, self).__init__(random_seed)
|
||||
self.data = data
|
||||
self.sample_size = sample_size
|
||||
self.repeats = repeats
|
||||
self.random_seed = random_seed
|
||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def samples_parameters(self):
|
||||
indexes = []
|
||||
|
@ -229,13 +229,13 @@ class USimplexPP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol)
|
|||
:param random_seed: allows replicating samples across runs (default None)
|
||||
"""
|
||||
|
||||
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None):
|
||||
def __init__(self, data: LabelledCollection, sample_size, repeats=100, random_seed=None, return_type='sample_prev'):
|
||||
super(USimplexPP, self).__init__(random_seed)
|
||||
self.data = data
|
||||
self.sample_size = sample_size
|
||||
self.repeats = repeats
|
||||
self.random_seed = random_seed
|
||||
self.set_collator(collator_fn=lambda x: (x.instances, x.prevalence()))
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def samples_parameters(self):
|
||||
indexes = []
|
||||
|
@ -339,7 +339,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
|||
indexesA, indexesB = indexes
|
||||
sampleA = self.A.sampling_from_index(indexesA)
|
||||
sampleB = self.B.sampling_from_index(indexesB)
|
||||
return sampleA+sampleB
|
||||
return (sampleA+sampleB).Xp
|
||||
|
||||
def total(self):
|
||||
return self.repeats * len(self.mixture_points)
|
||||
|
|
|
@ -46,9 +46,7 @@ def test_fetch_UCIDataset(dataset_name):
|
|||
|
||||
@pytest.mark.parametrize('dataset_name', LEQUA2022_TASKS)
|
||||
def test_fetch_lequa2022(dataset_name):
|
||||
fetch_lequa2022(dataset_name)
|
||||
# dataset = fetch_lequa2022(dataset_name)
|
||||
# print(f'Dataset {dataset_name}')
|
||||
# print('Training set stats')
|
||||
# dataset.training.stats()
|
||||
# print('Test set stats')
|
||||
train, gen_val, gen_test = fetch_lequa2022(dataset_name)
|
||||
print(train.stats())
|
||||
print('Val:', gen_val.total())
|
||||
print('Test:', gen_test.total())
|
||||
|
|
|
@ -12,8 +12,8 @@ def mock_labelled_collection(prefix=''):
|
|||
|
||||
def samples_to_str(protocol):
|
||||
samples_str = ""
|
||||
for sample in protocol():
|
||||
samples_str += f'{sample.instances}\t{sample.labels}\t{sample.prevalence()}\n'
|
||||
for instances, prev in protocol():
|
||||
samples_str += f'{instances}\t{prev}\n'
|
||||
return samples_str
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue