forked from moreo/QuaPy
adding SMM
This commit is contained in:
parent
46e294002f
commit
428f10fb2d
|
@ -173,7 +173,7 @@ def _training_helper(learner,
|
||||||
if isinstance(val_split, float):
|
if isinstance(val_split, float):
|
||||||
if not (0 < val_split < 1):
|
if not (0 < val_split < 1):
|
||||||
raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
|
raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
|
||||||
train, unused = data.split_stratified(train_prop=1 - val_split,random_state=0)
|
train, unused = data.split_stratified(train_prop=1 - val_split)
|
||||||
elif isinstance(val_split, LabelledCollection):
|
elif isinstance(val_split, LabelledCollection):
|
||||||
train = data
|
train = data
|
||||||
unused = val_split
|
unused = val_split
|
||||||
|
@ -712,6 +712,45 @@ class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
|
||||||
return np.asarray([1 - class1_prev, class1_prev])
|
return np.asarray([1 - class1_prev, class1_prev])
|
||||||
|
|
||||||
|
|
||||||
|
class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
    """
    `SMM method <https://ieeexplore.ieee.org/document/9260028>`_ (SMM).
    SMM is a simplification of matching distribution methods where the representation of the examples
    is created using the mean instead of a histogram.

    The positive prevalence of a sample is estimated by locating the mean positive posterior of the
    sample between the mean positive posteriors observed, on held-out validation data, for the negative
    and for the positive examples.

    :param learner: a sklearn's Estimator that generates a binary classifier.
    :param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
        validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
    """

    def __init__(self, learner: BaseEstimator, val_split=0.4):
        self.learner = learner
        self.val_split = val_split

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
        """
        Trains the quantifier and stores the class-conditional mean posteriors of the validation data.

        :param data: the training set; must be binary (checked via `_check_binary`)
        :param fit_learner: forwarded unchanged to `_training_helper`
        :param val_split: overrides the `val_split` given at construction time when not None
        :return: self
        """
        if val_split is None:
            val_split = self.val_split

        self._check_binary(data, self.__class__.__name__)
        self.learner, validation = _training_helper(
            self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)

        Px = self.classify(validation.instances)[:, 1]  # takes only the P(y=+1|x)
        self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
        self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]

        # NOTE: np.mean of an empty slice is NaN; `aggregate` copes with that case
        self.Pxy1_mean = np.mean(self.Pxy1)
        self.Pxy0_mean = np.mean(self.Pxy0)
        return self

    def aggregate(self, classif_posteriors):
        """
        Estimates the class prevalences from the posterior probabilities of a sample.

        :param classif_posteriors: array whose column 1 holds the posteriors P(y=+1|x)
        :return: np.ndarray `[1 - p, p]` where `p` is the estimated positive prevalence, clipped to [0, 1]
        """
        Px = classif_posteriors[:, 1]  # takes only the P(y=+1|x)
        Px_mean = np.mean(Px)

        denominator = self.Pxy1_mean - self.Pxy0_mean
        if denominator == 0 or not np.isfinite(denominator):
            # The validation means do not separate the classes (or one class was absent from
            # the validation data), so the adjustment is undefined and would produce NaN;
            # fall back to the unadjusted mean posterior instead.
            class1_prev = Px_mean
        else:
            class1_prev = (Px_mean - self.Pxy0_mean) / denominator
        class1_prev = np.clip(class1_prev, 0, 1)

        return np.asarray([1 - class1_prev, class1_prev])
|
||||||
|
|
||||||
|
|
||||||
class ELM(AggregativeQuantifier, BinaryQuantifier):
|
class ELM(AggregativeQuantifier, BinaryQuantifier):
|
||||||
"""
|
"""
|
||||||
Class of Explicit Loss Minimization (ELM) quantifiers.
|
Class of Explicit Loss Minimization (ELM) quantifiers.
|
||||||
|
|
Loading…
Reference in New Issue