From c0c37f0a178164aacbb626181a1fe43bd3973d37 Mon Sep 17 00:00:00 2001 From: Alejandro Moreo Date: Thu, 16 Jun 2022 16:54:15 +0200 Subject: [PATCH] return type in covariate protocol --- quapy/protocol.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/quapy/protocol.py b/quapy/protocol.py index fec37ca..f8b828f 100644 --- a/quapy/protocol.py +++ b/quapy/protocol.py @@ -301,7 +301,8 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): repeats=1, prevalence=None, mixture_points=11, - random_seed=None): + random_seed=None, + return_type='sample_prev'): super(CovariateShiftPP, self).__init__(random_seed) self.A = domainA self.B = domainB @@ -322,6 +323,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): assert all(np.logical_and(self.mixture_points >= 0, self.mixture_points<=1)), \ 'mixture_model datatype not understood (expected int or a sequence of real values in [0,1])' self.random_seed = random_seed + self.collator = OnLabelledCollectionProtocol.get_collator(return_type) def samples_parameters(self): indexesA, indexesB = [], [] @@ -339,7 +341,7 @@ class CovariateShiftPP(AbstractStochasticSeededProtocol): indexesA, indexesB = indexes sampleA = self.A.sampling_from_index(indexesA) sampleB = self.B.sampling_from_index(indexesB) - return (sampleA+sampleB).Xp + return self.collator(sampleA+sampleB) def total(self): return self.repeats * len(self.mixture_points)