From 428f10fb2d09021b34cc7bf2c8d40199f5943f4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pablo=20Gonz=C3=A1lez?=
Date: Mon, 11 Jul 2022 14:04:28 +0200
Subject: [PATCH] adding SMM

---
 quapy/method/aggregative.py | 41 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 40 insertions(+), 1 deletion(-)

diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py
index ac6fdc3..a2e03ae 100644
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -173,7 +173,7 @@ def _training_helper(learner,
             if isinstance(val_split, float):
                 if not (0 < val_split < 1):
                     raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
-                train, unused = data.split_stratified(train_prop=1 - val_split,random_state=0)
+                train, unused = data.split_stratified(train_prop=1 - val_split)
             elif isinstance(val_split, LabelledCollection):
                 train = data
                 unused = val_split
@@ -712,6 +712,45 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
         return np.asarray([1 - class1_prev, class1_prev])
 
 
+class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
+    """
+    `SMM method `_ (SMM).
+    SMM is a simplification of matching distribution methods where the representation of the examples
+    is created using the mean instead of a histogram.
+
+    :param learner: a sklearn's Estimator that generates a binary classifier.
+    :param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
+     validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
+    """
+
+    def __init__(self, learner: BaseEstimator, val_split=0.4):
+        self.learner = learner
+        self.val_split = val_split
+
+    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
+        if val_split is None:
+            val_split = self.val_split
+
+        self._check_binary(data, self.__class__.__name__)
+        self.learner, validation = _training_helper(
+            self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
+        Px = self.classify(validation.instances)[:, 1]  # takes only the P(y=+1|x)
+        self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
+        self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]
+        self.Pxy1_mean = np.mean(self.Pxy1)
+        self.Pxy0_mean = np.mean(self.Pxy0)
+        return self
+
+    def aggregate(self, classif_posteriors):
+        Px = classif_posteriors[:, 1]  # takes only the P(y=+1|x)
+        Px_mean = np.mean(Px)
+
+        class1_prev = (Px_mean - self.Pxy0_mean)/(self.Pxy1_mean - self.Pxy0_mean)
+        class1_prev = np.clip(class1_prev, 0, 1)
+
+        return np.asarray([1 - class1_prev, class1_prev])
+
+
 class ELM(AggregativeQuantifier, BinaryQuantifier):
     """
     Class of Explicit Loss Minimization (ELM) quantifiers.
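
Usage sketch (not part of the patch): a minimal example of how the new SMM quantifier could be exercised once this change is merged. It assumes quapy's existing aggregative API, namely that SMM inherits quantify() from the AggregativeQuantifier base class and that LabelledCollection(instances, labels) wraps a labelled dataset as elsewhere in the library; the data and learner choices below are illustrative only.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from quapy.data import LabelledCollection
from quapy.method.aggregative import SMM

# synthetic binary classification data, split into training and test portions
X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
train = LabelledCollection(X[:1500], y[:1500])
test_X, test_y = X[1500:], y[1500:]

# 40% of the training data is held out (stratified) to estimate the
# class-conditional mean posteriors Pxy0_mean and Pxy1_mean used by aggregate()
model = SMM(LogisticRegression(), val_split=0.4)
model.fit(train)

# quantify() classifies the test instances and aggregates their posteriors
# into a prevalence estimate [P(y=0), P(y=1)]
estim_prevalence = model.quantify(test_X)
print('estimated:', estim_prevalence, 'true positive prevalence:', test_y.mean())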