diff --git a/quapy/data/base.py b/quapy/data/base.py index c555692..b22a71f 100644 --- a/quapy/data/base.py +++ b/quapy/data/base.py @@ -63,6 +63,7 @@ class LabelledCollection: """ return self.instances.shape[0] + @property def prevalence(self): """ Returns the prevalence, or relative frequency, of the classes of interest. @@ -248,6 +249,43 @@ class LabelledCollection: """ return self.instances, self.labels + @property + def Xp(self): + """ + Gets the instances and the true prevalence. This is useful when implementing evaluation protocols + + :return: a tuple `(instances, prevalence)` from this collection + """ + return self.instances, self.prevalence() + + @property + def X(self): + """ + An alias to self.instances + + :return: self.instances + """ + return self.instances + + @property + def y(self): + """ + An alias to self.labels + + :return: self.labels + """ + return self.labels + + @property + def p(self): + """ + An alias to self.prevalence() + + :return: self.prevalence() + """ + return self.prevalence() + + def stats(self, show=True): """ Returns (and eventually prints) a dictionary with some stats of this collection. E.g.,: diff --git a/quapy/protocol.py b/quapy/protocol.py index f539830..c55c3ef 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -84,14 +84,16 @@ class AbstractStochasticSeededProtocol(AbstractProtocol): if self.random_seed is not None: stack.enter_context(qp.util.temp_seed(self.random_seed)) for params in self.samples_parameters(): - yield self.collator_fn(self.sample(params)) + yield self.collator(self.sample(params)) - def set_collator(self, collator_fn): - self.collator_fn = collator_fn + def collator(self, sample, *args): + return sample class OnLabelledCollectionProtocol: + RETURN_TYPES = ['sample_prev', 'labelled_collection'] + def get_labelled_collection(self): return self.data @@ -106,6 +108,15 @@ class OnLabelledCollectionProtocol: new = deepcopy(self) return new.on_preclassified_instances(pre_classifications, in_place=True) + @classmethod + def get_collator(cls, return_type='sample_prev'): + assert return_type in cls.RETURN_TYPES, \ + f'unknown return type passed as argument; valid ones are {cls.RETURN_TYPES}' + if return_type=='sample_prev': + return lambda lc:lc.Xp + elif return_type=='labelled_collection': + return lambda lc:lc + class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol): """