diff --git a/quapy/protocol.py b/quapy/protocol.py index fec37ca..f8b828f 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -301,7 +301,8 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): repeats=1, prevalence=None, mixture_points=11, - random_seed=None): + random_seed=None, + return_type='sample_prev'): super(CovariateShiftPP, self).__init__(random_seed) self.A = domainA self.B = domainB @@ -322,6 +323,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \ 'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])' self.random_seed = random_seed + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def samples_parameters(self): indexesA, indexesB = [], [] @@ -339,7 +341,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): indexesA, indexesB = indexes sampleA = self.A.sampling_from_index(indexesA) sampleB = self.B.sampling_from_index(indexesB) - return (sampleA+sampleB).Xp + return self.collator(sampleA+sampleB) def total(self): return self.repeats * len(self.mixture_points)