return type in covariate protocol

This commit is contained in:
Alejandro Moreo Fernandez 2022-06-16 16:54:15 +02:00
parent a7c768bb40
commit c0c37f0a17
1 changed files with 4 additions and 2 deletions

View File

@ -301,7 +301,8 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
repeats=1,
prevalence=None,
mixture_points=11,
random_seed=None):
random_seed=None,
return_type='sample_prev'):
super(CovariateShiftPP, self).__init__(random_seed)
self.A = domainA
self.B = domainB
@ -322,6 +323,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \
'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])'
self.random_seed = random_seed
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
def samples_parameters(self):
indexesA, indexesB = [], []
@ -339,7 +341,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
indexesA, indexesB = indexes
sampleA = self.A.sampling_from_index(indexesA)
sampleB = self.B.sampling_from_index(indexesB)
return (sampleA+sampleB).Xp
return self.collator(sampleA+sampleB)
def total(self):
return self.repeats * len(self.mixture_points)