
dys implementation

This commit is contained in:
Pablo González 2022-07-11 12:21:49 +02:00
parent 1742b75504
commit 46e294002f
3 changed files with 84 additions and 2 deletions


@@ -78,6 +78,12 @@ def HellingerDistance(P, Q):
    """
    return np.sqrt(np.sum((np.sqrt(P) - np.sqrt(Q))**2))

def TopsoeDistance(P, Q, epsilon=1e-20):
    """ Topsoe
    """
    return np.sum(P*np.log((2*P+epsilon)/(P+Q+epsilon)) +
                  Q*np.log((2*Q+epsilon)/(P+Q+epsilon)))

def uniform_prevalence_sampling(n_classes, size=1):
    """


@@ -19,6 +19,7 @@ AGGREGATIVE_METHODS = {
    aggregative.PACC,
    aggregative.EMQ,
    aggregative.HDy,
    aggregative.DyS,
    aggregative.X,
    aggregative.T50,
    aggregative.MAX,
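
This one-line registration is what lets DyS be picked up by any code that enumerates the aggregative methods. A quick check (illustrative, not part of the commit):

    from quapy.method import AGGREGATIVE_METHODS, aggregative

    assert aggregative.DyS in AGGREGATIVE_METHODS
    print(sorted(m.__name__ for m in AGGREGATIVE_METHODS))  # now includes 'DyS'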


@@ -1,6 +1,7 @@
from abc import abstractmethod
from copy import deepcopy
import string
from typing import Callable, Union
import numpy as np
from joblib import Parallel, delayed
from sklearn.base import BaseEstimator
@@ -172,7 +173,7 @@ def _training_helper(learner,
    if isinstance(val_split, float):
        if not (0 < val_split < 1):
            raise ValueError(f'train/val split {val_split} out of range, must be in (0,1)')
        train, unused = data.split_stratified(train_prop=1 - val_split, random_state=0)
    elif isinstance(val_split, LabelledCollection):
        train = data
        unused = val_split
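
The added random_state=0 pins the stratified train/validation split, so repeated fits over the same data see the same held-out distribution. Sketch (illustrative; `data` stands for any quapy.data.base.LabelledCollection):

    train1, val1 = data.split_stratified(train_prop=0.6, random_state=0)
    train2, val2 = data.split_stratified(train_prop=0.6, random_state=0)
    # the two splits are now identical across calls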
@@ -637,6 +638,80 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
        return np.asarray([1 - class1_prev, class1_prev])


class DyS(AggregativeProbabilisticQuantifier, BinaryQuantifier):
    """
    `DyS framework <https://ojs.aaai.org/index.php/AAAI/article/view/4376>`_ (DyS).
    DyS is a generalization of the HDy method that uses a ternary search to find the prevalence that
    minimizes the distance between distributions.
    Details of the ternary search are taken from <https://dl.acm.org/doi/pdf/10.1145/3219819.3220059>

    :param learner: a sklearn's Estimator that generates a binary classifier
    :param val_split: a float in range (0,1) indicating the proportion of data to be used as a stratified held-out
        validation distribution, or a :class:`quapy.data.base.LabelledCollection` (the split itself).
    :param n_bins: an int with the number of bins used to compute the histograms.
    :param distance: a str naming a distance already included in the library ('HD' or 'topsoe'), or a function
        that computes the distance between two distributions.
    :param tol: a float with the tolerance for the ternary search algorithm.
    """

    def __init__(self, learner: BaseEstimator, val_split=0.4, n_bins=8, distance: Union[str, Callable]='HD', tol=1e-05):
        self.learner = learner
        self.val_split = val_split
        self.tol = tol
        self.distance = distance
        self.n_bins = n_bins

    def _ternary_search(self, f, left, right, tol):
        """
        Find the minimum of unimodal function f() within [left, right]
        """
        while abs(right - left) >= tol:
            left_third = left + (right - left) / 3
            right_third = right - (right - left) / 3
            if f(left_third) > f(right_third):
                left = left_third
            else:
                right = right_third
        # left and right are the current bounds; the minimum is between them
        return (left + right) / 2

    def _compute_distance(self, Px_train, Px_test, distance: Union[str, Callable]='HD'):
        if distance == 'HD':
            return F.HellingerDistance(Px_train, Px_test)
        elif distance == 'topsoe':
            return F.TopsoeDistance(Px_train, Px_test)
        else:
            return distance(Px_train, Px_test)

    def fit(self, data: LabelledCollection, fit_learner=True, val_split: Union[float, LabelledCollection] = None):
        if val_split is None:
            val_split = self.val_split

        self._check_binary(data, self.__class__.__name__)
        self.learner, validation = _training_helper(
            self.learner, data, fit_learner, ensure_probabilistic=True, val_split=val_split)
        Px = self.classify(validation.instances)[:, 1]  # takes only the P(y=+1|x)
        self.Pxy1 = Px[validation.labels == self.learner.classes_[1]]
        self.Pxy0 = Px[validation.labels == self.learner.classes_[0]]
        self.Pxy1_density = np.histogram(self.Pxy1, bins=self.n_bins, range=(0, 1), density=True)[0]
        self.Pxy0_density = np.histogram(self.Pxy0, bins=self.n_bins, range=(0, 1), density=True)[0]
        return self

    def aggregate(self, classif_posteriors):
        Px = classif_posteriors[:, 1]  # takes only the P(y=+1|x)
        Px_test = np.histogram(Px, bins=self.n_bins, range=(0, 1), density=True)[0]

        def distribution_distance(prev):
            Px_train = prev * self.Pxy1_density + (1 - prev) * self.Pxy0_density
            return self._compute_distance(Px_train, Px_test, self.distance)

        class1_prev = self._ternary_search(f=distribution_distance, left=0, right=1, tol=self.tol)
        return np.asarray([1 - class1_prev, class1_prev])


class ELM(AggregativeQuantifier, BinaryQuantifier):
    """
    Class of Explicit Loss Minimization (ELM) quantifiers.
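
Putting the pieces together, a minimal usage sketch of the new quantifier (the dataset and base learner are illustrative choices, not part of the commit):

    import quapy as qp
    from quapy.method.aggregative import DyS
    from sklearn.linear_model import LogisticRegression

    data = qp.datasets.fetch_reviews('kindle', tfidf=True, min_df=5)  # a binary dataset
    model = DyS(LogisticRegression(), val_split=0.4, n_bins=8, distance='topsoe')
    model.fit(data.training)
    estim_prev = model.quantify(data.test.instances)  # [neg, pos] prevalence estimate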