return type in covariate protocol
This commit is contained in:
parent
a7c768bb40
commit
c0c37f0a17
|
@ -301,7 +301,8 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
|||
repeats=1,
|
||||
prevalence=None,
|
||||
mixture_points=11,
|
||||
random_seed=None):
|
||||
random_seed=None,
|
||||
return_type='sample_prev'):
|
||||
super(CovariateShiftPP, self).__init__(random_seed)
|
||||
self.A = domainA
|
||||
self.B = domainB
|
||||
|
@ -322,6 +323,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
|||
assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \
|
||||
'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])'
|
||||
self.random_seed = random_seed
|
||||
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||
|
||||
def samples_parameters(self):
|
||||
indexesA, indexesB = [], []
|
||||
|
@ -339,7 +341,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
|||
indexesA, indexesB = indexes
|
||||
sampleA = self.A.sampling_from_index(indexesA)
|
||||
sampleB = self.B.sampling_from_index(indexesB)
|
||||
return (sampleA+sampleB).Xp
|
||||
return self.collator(sampleA+sampleB)
|
||||
|
||||
def total(self):
|
||||
return self.repeats * len(self.mixture_points)
|
||||
|
|
Loading…
Reference in New Issue