return type in covariate protocol
This commit is contained in:
parent
a7c768bb40
commit
c0c37f0a17
|
@ -301,7 +301,8 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
||||||
repeats=1,
|
repeats=1,
|
||||||
prevalence=None,
|
prevalence=None,
|
||||||
mixture_points=11,
|
mixture_points=11,
|
||||||
random_seed=None):
|
random_seed=None,
|
||||||
|
return_type='sample_prev'):
|
||||||
super(CovariateShiftPP, self).__init__(random_seed)
|
super(CovariateShiftPP, self).__init__(random_seed)
|
||||||
self.A = domainA
|
self.A = domainA
|
||||||
self.B = domainB
|
self.B = domainB
|
||||||
|
@ -322,6 +323,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
||||||
assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \
|
assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \
|
||||||
'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])'
|
'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])'
|
||||||
self.random_seed = random_seed
|
self.random_seed = random_seed
|
||||||
|
self.collator = OnLabelledCollectionProtocol.get_collator(return_type)
|
||||||
|
|
||||||
def samples_parameters(self):
|
def samples_parameters(self):
|
||||||
indexesA, indexesB = [], []
|
indexesA, indexesB = [], []
|
||||||
|
@ -339,7 +341,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol):
|
||||||
indexesA, indexesB = indexes
|
indexesA, indexesB = indexes
|
||||||
sampleA = self.A.sampling_from_index(indexesA)
|
sampleA = self.A.sampling_from_index(indexesA)
|
||||||
sampleB = self.B.sampling_from_index(indexesB)
|
sampleB = self.B.sampling_from_index(indexesB)
|
||||||
return (sampleA+sampleB).Xp
|
return self.collator(sampleA+sampleB)
|
||||||
|
|
||||||
def total(self):
|
def total(self):
|
||||||
return self.repeats * len(self.mixture_points)
|
return self.repeats * len(self.mixture_points)
|
||||||
|
|
Loading…
Reference in New Issue