forked from moreo/QuaPy
updating properties of labelled collection
This commit is contained in:
parent
45642ad778
commit
bfe4b8b51a
|
@ -63,6 +63,7 @@ class LabelledCollection:
|
|||
"""
|
||||
return self.instances.shape[0]
|
||||
|
||||
@property
|
||||
def prevalence(self):
|
||||
"""
|
||||
Returns the prevalence, or relative frequency, of the classes of interest.
|
||||
|
@ -248,6 +249,43 @@ class LabelledCollection:
|
|||
"""
|
||||
return self.instances, self.labels
|
||||
|
||||
@property
|
||||
def Xp(self):
|
||||
"""
|
||||
Gets the instances and the true prevalence. This is useful when implementing evaluation protocols
|
||||
|
||||
:return: a tuple `(instances, prevalence)` from this collection
|
||||
"""
|
||||
return self.instances, self.prevalence()
|
||||
|
||||
@property
|
||||
def X(self):
|
||||
"""
|
||||
An alias to self.instances
|
||||
|
||||
:return: self.instances
|
||||
"""
|
||||
return self.instances
|
||||
|
||||
@property
|
||||
def y(self):
|
||||
"""
|
||||
An alias to self.labels
|
||||
|
||||
:return: self.labels
|
||||
"""
|
||||
return self.labels
|
||||
|
||||
@property
|
||||
def p(self):
|
||||
"""
|
||||
An alias to self.prevalence()
|
||||
|
||||
:return: self.prevalence()
|
||||
"""
|
||||
return self.prevalence()
|
||||
|
||||
|
||||
def stats(self, show=True):
|
||||
"""
|
||||
Returns (and eventually prints) a dictionary with some stats of this collection. E.g.,:
|
||||
|
|
|
@ -84,14 +84,16 @@ class AbstractStochasticSeededProtocol(AbstractProtocol):
|
|||
if self.random_seed is not None:
|
||||
stack.enter_context(qp.util.temp_seed(self.random_seed))
|
||||
for params in self.samples_parameters():
|
||||
yield self.collator_fn(self.sample(params))
|
||||
yield self.collator(self.sample(params))
|
||||
|
||||
def set_collator(self, collator_fn):
|
||||
self.collator_fn = collator_fn
|
||||
def collator(self, sample, *args):
|
||||
return sample
|
||||
|
||||
|
||||
class OnLabelledCollectionProtocol:
|
||||
|
||||
RETURN_TYPES = ['sample_prev', 'labelled_collection']
|
||||
|
||||
def get_labelled_collection(self):
|
||||
return self.data
|
||||
|
||||
|
@ -106,6 +108,15 @@ class OnLabelledCollectionProtocol:
|
|||
new = deepcopy(self)
|
||||
return new.on_preclassified_instances(pre_classifications, in_place=True)
|
||||
|
||||
@classmethod
|
||||
def get_collator(cls, return_type='sample_prev'):
|
||||
assert return_type in cls.RETURN_TYPES, \
|
||||
f'unknown return type passed as argument; valid ones are {cls.RETURN_TYPES}'
|
||||
if return_type=='sample_prev':
|
||||
return lambda lc:lc.Xp
|
||||
elif return_type=='labelled_collection':
|
||||
return lambda lc:lc
|
||||
|
||||
|
||||
class APP(AbstractStochasticSeededProtocol, OnLabelledCollectionProtocol):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue