From 29db15ae25b879c88d6f70c637bd28cedb6e84ba Mon Sep 17 00:00:00 2001
From: Alejandro Moreo
Date: Thu, 9 Nov 2023 18:13:54 +0100
Subject: [PATCH] added DMx and DMy, with classmethods that return HDx and HDy respectively

---
 examples/comparing_HDy_HDx.py   |   4 +-
 examples/ifcb_experiments.py    |   2 +-
 examples/model_selection.py     |   4 +-
 quapy/functional.py             |  54 +++++++++++++
 quapy/method/aggregative.py     |  61 ++++++++------
 quapy/method/non_aggregative.py | 136 +++++++++++---------------------
 quapy/tests/test_methods.py     |  10 +--
 7 files changed, 144 insertions(+), 127 deletions(-)

diff --git a/examples/comparing_HDy_HDx.py b/examples/comparing_HDy_HDx.py
index 025f7cd..e7a32ef 100644
--- a/examples/comparing_HDy_HDx.py
+++ b/examples/comparing_HDy_HDx.py
@@ -6,7 +6,7 @@ from tqdm import tqdm
 import quapy as qp
 from quapy.protocol import APP
 from quapy.method.aggregative import HDy
-from quapy.method.non_aggregative import HDx
+from quapy.method.non_aggregative import DMx
 
 
 """
@@ -42,7 +42,7 @@ for dataset_name in tqdm(qp.datasets.UCI_DATASETS, total=len(qp.datasets.UCI_DAT
 
     # HDx............................................
     tinit = time()
-    hdx = HDx().fit(train)
+    hdx = DMx.HDx(n_jobs=-1).fit(train)
     t_hdx_train = time() - tinit
 
     tinit = time()
diff --git a/examples/ifcb_experiments.py b/examples/ifcb_experiments.py
index 913fdb8..4cf9448 100644
--- a/examples/ifcb_experiments.py
+++ b/examples/ifcb_experiments.py
@@ -12,7 +12,7 @@ quantifiers = [
     ('ACC', qp.method.aggregative.ACC(newLR())),
     ('PCC', qp.method.aggregative.PCC(newLR())),
     ('PACC', qp.method.aggregative.PACC(newLR())),
-    ('HDy', qp.method.aggregative.DistributionMatching(newLR())),
+    ('HDy', qp.method.aggregative.DMy(newLR())),
     ('EMQ', qp.method.aggregative.EMQ(newLR()))
 ]
 
diff --git a/examples/model_selection.py b/examples/model_selection.py
index b9b4903..ae7fb6a 100644
--- a/examples/model_selection.py
+++ b/examples/model_selection.py
@@ -1,6 +1,6 @@
 import quapy as qp
 from quapy.protocol import APP
-from quapy.method.aggregative import DistributionMatching
+from quapy.method.aggregative import DMy
 from sklearn.linear_model import LogisticRegression
 import numpy as np
 
@@ -8,7 +8,7 @@ import numpy as np
 In this example, we show how to perform model selection on a DistributionMatching quantifier.
 """
 
-model = DistributionMatching(LogisticRegression())
+model = DMy(LogisticRegression())
 
 qp.environ['SAMPLE_SIZE'] = 100
 qp.environ['N_JOBS'] = -1
diff --git a/quapy/functional.py b/quapy/functional.py
index 2f64c2b..e29466f 100644
--- a/quapy/functional.py
+++ b/quapy/functional.py
@@ -291,3 +291,57 @@ def get_divergence(divergence: Union[str, Callable]):
         return divergence
     else:
         raise ValueError(f'argument "divergence" not understood; use a str or a callable function')
+
+
+def argmin_prevalence(loss, n_classes, method='optim_minimize'):
+    if method == 'optim_minimize':
+        return optim_minimize(loss, n_classes)
+    elif method == 'linear_search':
+        return linear_search(loss, n_classes)
+    elif method == 'ternary_search':
+        raise NotImplementedError()
+    else:
+        raise NotImplementedError()
+
+
+def optim_minimize(loss, n_classes):
+    """
+    Searches for the optimal prevalence values, i.e., an `n_classes`-dimensional vector of the (`n_classes`-1)-simplex
+    that yields the smallest loss. This optimization is carried out by means of a constrained search using scipy's
+    SLSQP routine.
+
+    :param loss: (callable) the function to minimize
+    :param n_classes: (int) the number of classes, i.e., the dimensionality of the prevalence vector
+    :return: (ndarray) the best prevalence vector found
+    """
+    from scipy import optimize
+
+    # the initial point is set as the uniform distribution
+    uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
+
+    # solutions are bounded to those contained in the unit-simplex
+    bounds = tuple((0, 1) for _ in range(n_classes))  # values in [0,1]
+    constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
+    r = optimize.minimize(loss, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
+    return r.x
+
+
+def linear_search(loss, n_classes):
+    """
+    Performs a linear search for the best prevalence value in binary problems. The search is carried out by exploring
+    the range [0,1] stepping by 0.01. This search is inefficient, and is added only for completeness (some of the
+    early methods in quantification literature used it, e.g., HDy). A more powerful alternative is `optim_minimize`.
+
+    :param loss: (callable) the function to minimize
+    :param n_classes: (int) the number of classes, i.e., the dimensionality of the prevalence vector
+    :return: (ndarray) the best prevalence vector found
+    """
+    assert n_classes==2, 'linear search is only available for binary problems'
+
+    prev_selected, min_score = None, None
+    for prev in prevalence_linspace(n_prevalences=100, repeats=1, smooth_limits_epsilon=0.0):
+        score = loss(np.asarray([1 - prev, prev]))
+        if min_score is None or score < min_score:
+            prev_selected, min_score = prev, score
+
+    return np.asarray([1 - prev_selected, prev_selected])
\ No newline at end of file
diff --git a/quapy/method/aggregative.py b/quapy/method/aggregative.py
index 526414c..232a92b 100644
--- a/quapy/method/aggregative.py
+++ b/quapy/method/aggregative.py
@@ -568,10 +568,11 @@ class HDy(AggregativeProbabilisticQuantifier, BinaryQuantifier):
         self.Pxy0 = Px[validation.labels == self.classifier.classes_[0]]
         # pre-compute the histogram for positive and negative examples
         self.bins = np.linspace(10, 110, 11, dtype=int)  # [10, 20, 30, ..., 100, 110]
-        self.Pxy1_density = {bins: np.histogram(self.Pxy1, bins=bins, range=(0, 1), density=True)[0] for bins in
-                             self.bins}
-        self.Pxy0_density = {bins: np.histogram(self.Pxy0, bins=bins, range=(0, 1), density=True)[0] for bins in
-                             self.bins}
+        def hist(P, bins):
+            h = np.histogram(P, bins=bins, range=(0, 1), density=True)[0]
+            return h / h.sum()
+        self.Pxy1_density = {bins: hist(self.Pxy1, bins) for bins in self.bins}
+        self.Pxy0_density = {bins: hist(self.Pxy0, bins) for bins in self.bins}
         return self
 
     def aggregate(self, classif_posteriors):
@@ -712,7 +713,7 @@ class SMM(AggregativeProbabilisticQuantifier, BinaryQuantifier):
         return np.asarray([1 - class1_prev, class1_prev])
 
 
-class DistributionMatching(AggregativeProbabilisticQuantifier):
+class DMy(AggregativeProbabilisticQuantifier):
     """
     Generic Distribution Matching quantifier for binary or multiclass quantification based on the space of posterior
     probabilities. This implementation takes the number of bins, the divergence, and the possibility to work on CDF
@@ -733,14 +734,24 @@
     :param n_jobs: number of parallel workers (default None)
     """
 
-    def __init__(self, classifier, val_split=0.4, nbins=8, divergence: Union[str, Callable]='HD', cdf=False, n_jobs=None):
+    def __init__(self, classifier, val_split=0.4, nbins=8, divergence: Union[str, Callable]='HD',
+                 cdf=False, search='optim_minimize', n_jobs=None):
         self.classifier = classifier
         self.val_split = val_split
         self.nbins = nbins
         self.divergence = divergence
         self.cdf = cdf
+        self.search = search
         self.n_jobs = n_jobs
 
+    @classmethod
+    def HDy(cls, classifier, val_split=0.4, n_jobs=None):
+        from quapy.method.meta import MedianEstimator
+
+        hdy = DMy(classifier=classifier, val_split=val_split, search='linear_search', divergence='HD')
+        hdy = MedianEstimator(hdy, param_grid={'nbins': np.linspace(10, 110, 11).astype(int)}, n_jobs=n_jobs)
+        return hdy
+
     def __get_distributions(self, posteriors):
         histograms = []
         post_dims = posteriors.shape[1]
@@ -794,26 +805,20 @@
         `n` channels (proper distributions of binned posterior probabilities), on which the divergence is computed
         independently. The matching is computed as an average of the divergence across all channels.
 
-        :param instances: instances in the sample
+        :param posteriors: posterior probabilities of the instances in the sample
         :return: a vector of class prevalence estimates
         """
         test_distribution = self.__get_distributions(posteriors)
         divergence = get_divergence(self.divergence)
         n_classes, n_channels, nbins = self.validation_distribution.shape
-        def match(prev):
+        def loss(prev):
             prev = np.expand_dims(prev, axis=0)
             mixture_distribution = (prev @ self.validation_distribution.reshape(n_classes,-1)).reshape(n_channels, -1)
             divs = [divergence(test_distribution[ch], mixture_distribution[ch]) for ch in range(n_channels)]
             return np.mean(divs)
 
-        # the initial point is set as the uniform distribution
-        uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
+        return F.argmin_prevalence(loss, n_classes, method=self.search)
 
-        # solutions are bounded to those contained in the unit-simplex
-        bounds = tuple((0, 1) for x in range(n_classes))  # values in [0,1]
-        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
-        r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
-        return r.x
 
 
 def newELM(svmperf_base=None, loss='01', C=1):
@@ -1215,17 +1220,6 @@ class MS2(MS):
         return np.median(tprs), np.median(fprs)
 
 
-ClassifyAndCount = CC
-AdjustedClassifyAndCount = ACC
-ProbabilisticClassifyAndCount = PCC
-ProbabilisticAdjustedClassifyAndCount = PACC
-ExpectationMaximizationQuantifier = EMQ
-SLD = EMQ
-HellingerDistanceY = HDy
-MedianSweep = MS
-MedianSweep2 = MS2
-
-
 class OneVsAllAggregative(OneVsAllGeneric, AggregativeQuantifier):
     """
     Allows any binary quantifier to perform quantification on single-label datasets.
@@ -1283,3 +1277,18 @@
 
         # the estimation for the positive class prevalence
         return self.dict_binary_quantifiers[c].aggregate(classif_predictions[:, c])[1]
+
+#---------------------------------------------------------------
+# aliases
+#---------------------------------------------------------------
+
+ClassifyAndCount = CC
+AdjustedClassifyAndCount = ACC
+ProbabilisticClassifyAndCount = PCC
+ProbabilisticAdjustedClassifyAndCount = PACC
+ExpectationMaximizationQuantifier = EMQ
+DistributionMatchingY = DMy
+SLD = EMQ
+HellingerDistanceY = HDy
+MedianSweep = MS
+MedianSweep2 = MS2
diff --git a/quapy/method/non_aggregative.py b/quapy/method/non_aggregative.py
index 8768c92..87e59fb 100644
--- a/quapy/method/non_aggregative.py
+++ b/quapy/method/non_aggregative.py
@@ -1,7 +1,5 @@
 from typing import Union, Callable
-
 import numpy as np
-from scipy import optimize
 from functional import get_divergence
 from quapy.data import LabelledCollection
 
@@ -41,81 +39,7 @@ class MaximumLikelihoodPrevalenceEstimation(BaseQuantifier):
         return self.estimated_prevalence
 
-
-class HDx(BinaryQuantifier):
-    """
-    `Hellinger Distance x `_ (HDx).
-    HDx is a method for training binary quantifiers, that models quantification as the problem of
-    minimizing the average divergence (in terms of the Hellinger Distance) across the feature-specific normalized
-    histograms of two representations, one for the unlabelled examples, and another generated from the training
-    examples as a mixture model of the class-specific representations. The parameters of the mixture thus represent
-    the estimates of the class prevalence values. The method computes all matchings for nbins in [10, 20, ..., 110]
-    and reports the mean of the median. The best prevalence is searched via linear search, from 0 to 1 steppy by 0.01.
-    """
-
-    def __init__(self):
-        self.feat_ranges = None
-
-    def covariate_histograms(self, X, nbins):
-        assert self.feat_ranges is not None, 'quantify called before fit'
-
-        histograms = []
-        for col_idx in range(self.nfeats):
-            feature = X[:,col_idx]
-            feat_range = self.feat_ranges[col_idx]
-            histograms.append(np.histogram(feature, bins=nbins, range=feat_range, density=True)[0])
-
-        return np.vstack(histograms).T
-
-    def fit(self, data: LabelledCollection):
-        """
-        Trains a HDx quantifier.
-
-        :param data: the training set
-        :return: self
-        """
-
-        self._check_binary(data, self.__class__.__name__)
-        X, y = data.Xy
-
-        self.nfeats = X.shape[1]
-        self.feat_ranges = _get_features_range(X)
-
-        # pre-compute the representation for positive and negative examples
-        self.bins = np.linspace(10, 110, 11, dtype=int)  # [10, 20, 30, ..., 100, 110]
-        self.H0 = {bins:self.covariate_histograms(X[y == 0], bins) for bins in self.bins}
-        self.H1 = {bins:self.covariate_histograms(X[y == 1], bins) for bins in self.bins}
-        return self
-
-    def quantify(self, X):
-        # "In this work, the number of bins b used in HDx and HDy was chosen from 10 to 110 in steps of 10,
-        # and the final estimated a priori probability was taken as the median of these 11 estimates."
-        # (González-Castro, et al., 2013).
-
-        assert X.shape[1] == self.nfeats, f'wrong shape in quantify; expected {self.nfeats}, found {X.shape[1]}'
-
-        prev_estimations = []
-        for nbins in self.bins:
-            Ht = self.covariate_histograms(X, nbins=nbins)
-            H0 = self.H0[nbins]
-            H1 = self.H1[nbins]
-
-            # the authors proposed to search for the prevalence yielding the best matching as a linear search
-            # at small steps (modern implementations resort to an optimization procedure)
-            prev_selected, min_dist = None, None
-            for prev in F.prevalence_linspace(n_prevalences=100, repeats=1, smooth_limits_epsilon=0.0):
-                Hx = prev * H1 + (1 - prev) * H0
-                hdx = np.mean([F.HellingerDistance(Hx[:,col], Ht[:,col]) for col in range(self.nfeats)])
-
-                if prev_selected is None or hdx < min_dist:
-                    prev_selected, min_dist = prev, hdx
-            prev_estimations.append(prev_selected)
-
-        class1_prev = np.median(prev_estimations)
-        return np.asarray([1 - class1_prev, class1_prev])
-
-class DistributionMatchingX(BaseQuantifier):
+class DMx(BaseQuantifier):
     """
     Generic Distribution Matching quantifier for binary or multiclass quantification based on the space of covariates.
     This implementation takes the number of bins, the divergence, and the possibility to work on CDF as hyperparameters.
@@ -128,22 +52,51 @@
     :param n_jobs: number of parallel workers (default None)
     """
 
-    def __init__(self, nbins=8, divergence: Union[str, Callable]='HD', cdf=False, n_jobs=None):
+    def __init__(self, nbins=8, divergence: Union[str, Callable]='HD', cdf=False, search='optim_minimize', n_jobs=None):
         self.nbins = nbins
         self.divergence = divergence
         self.cdf = cdf
+        self.search = search
         self.n_jobs = n_jobs
 
+    @classmethod
+    def HDx(cls, n_jobs=None):
+        """
+        `Hellinger Distance x `_ (HDx).
+        HDx is a method for training binary quantifiers, that models quantification as the problem of
+        minimizing the average divergence (in terms of the Hellinger Distance) across the feature-specific normalized
+        histograms of two representations, one for the unlabelled examples, and another generated from the training
+        examples as a mixture model of the class-specific representations. The parameters of the mixture thus represent
+        the estimates of the class prevalence values.
+
+        The method computes all matchings for nbins in [10, 20, ..., 110] and reports the median of these estimates.
+        The best prevalence is searched via linear search, from 0 to 1 stepping by 0.01.
+
+        :param n_jobs: number of parallel workers
+        :return: an instance of this class set up to mimic the performance of HDx as originally proposed by
+            González-Castro, Alaiz-Rodríguez, Alegre (2013)
+        """
+        from quapy.method.meta import MedianEstimator
+
+        dmx = DMx(divergence='HD', cdf=False, search='linear_search')
+        nbins = {'nbins': np.linspace(10, 110, 11, dtype=int)}
+        hdx = MedianEstimator(base_quantifier=dmx, param_grid=nbins, n_jobs=n_jobs)
+        return hdx
+
     def __get_distributions(self, X):
+
         histograms = []
         for feat_idx in range(self.nfeats):
-            hist = np.histogram(X[:, feat_idx], bins=self.nbins, range=self.feat_ranges[feat_idx])[0]
-            normhist = hist / hist.sum()
-            histograms.append(normhist)
-
+            feature = X[:, feat_idx]
+            feat_range = self.feat_ranges[feat_idx]
+            hist = np.histogram(feature, bins=self.nbins, range=feat_range)[0]
+            norm_hist = hist / hist.sum()
+            histograms.append(norm_hist)
         distributions = np.vstack(histograms)
+
         if self.cdf:
             distributions = np.cumsum(distributions, axis=1)
+
         return distributions
 
     def fit(self, data: LabelledCollection):
@@ -184,20 +137,14 @@
         test_distribution = self.__get_distributions(instances)
         divergence = get_divergence(self.divergence)
         n_classes, n_feats, nbins = self.validation_distribution.shape
-        def match(prev):
+        def loss(prev):
             prev = np.expand_dims(prev, axis=0)
             mixture_distribution = (prev @ self.validation_distribution.reshape(n_classes,-1)).reshape(n_feats, -1)
             divs = [divergence(test_distribution[feat], mixture_distribution[feat]) for feat in range(n_feats)]
             return np.mean(divs)
 
-        # the initial point is set as the uniform distribution
-        uniform_distribution = np.full(fill_value=1 / n_classes, shape=(n_classes,))
+        return F.argmin_prevalence(loss, n_classes, method=self.search)
 
-        # solutions are bounded to those contained in the unit-simplex
-        bounds = tuple((0, 1) for x in range(n_classes))  # values in [0,1]
-        constraints = ({'type': 'eq', 'fun': lambda x: 1 - sum(x)})  # values summing up to 1
-        r = optimize.minimize(match, x0=uniform_distribution, method='SLSQP', bounds=bounds, constraints=constraints)
-        return r.x
 
 
 def _get_features_range(X):
@@ -206,4 +153,11 @@
     for col_idx in range(ncols):
         feature = X[:,col_idx]
         feat_ranges.append((np.min(feature), np.max(feature)))
-    return feat_ranges
\ No newline at end of file
+    return feat_ranges
+
+
+#---------------------------------------------------------------
+# aliases
+#---------------------------------------------------------------
+
+DistributionMatchingX = DMx
\ No newline at end of file
diff --git a/quapy/tests/test_methods.py b/quapy/tests/test_methods.py
index da5485a..bca34e3 100644
--- a/quapy/tests/test_methods.py
+++ b/quapy/tests/test_methods.py
@@ -10,7 +10,7 @@
 from quapy.data import Dataset, LabelledCollection
 from quapy.method import AGGREGATIVE_METHODS, NON_AGGREGATIVE_METHODS
 from quapy.method.meta import Ensemble
 from quapy.protocol import APP
-from quapy.method.aggregative import DistributionMatching
+from quapy.method.aggregative import DMy
 from quapy.method.meta import MedianEstimator
 
 datasets = [pytest.param(qp.datasets.fetch_twitter('hcr', pickle=True), id='hcr'),
@@ -189,7 +189,7 @@ def test_median_meta():
         errors = []
         for nbins in nbins_grid:
             with qp.util.temp_seed(0):
-                q = DistributionMatching(LogisticRegression(), nbins=nbins)
+                q = DMy(LogisticRegression(), nbins=nbins)
                 mae, estim_prevs = __fit_test(q, train, test)
                 prevs.append(estim_prevs)
                 errors.append(mae)
@@ -198,7 +198,7 @@ def test_median_meta():
     mae = np.mean(errors)
     print(f'\tMAE={mae:.4f}')
 
-    q = DistributionMatching(LogisticRegression())
+    q = DMy(LogisticRegression())
     q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
     median_mae, prev = __fit_test(q, train, test)
     print(f'\tMAE={median_mae:.4f}')
@@ -220,12 +220,12 @@ def test_median_meta_modsel():
 
     nbins_grid = [2, 4, 5, 10, 15]
 
-    q = DistributionMatching(LogisticRegression())
+    q = DMy(LogisticRegression())
     q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
     median_mae, _ = __fit_test(q, train, test)
     print(f'\tMAE={median_mae:.4f}')
 
-    q = DistributionMatching(LogisticRegression())
+    q = DMy(LogisticRegression())
     lr_params = {'classifier__C': np.logspace(-1, 1, 3)}
     q = MedianEstimator(q, param_grid={'nbins': nbins_grid}, random_state=0, n_jobs=-1)
     q = GridSearchQ(q, param_grid=lr_params, protocol=APP(val), n_jobs=-1)
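
Usage sketch (not part of the patch): the snippet below illustrates how the API introduced above is intended to be called — DMy/DMx with the default SLSQP-based search ('optim_minimize'), and the HDy/HDx classmethods, which wrap the quantifier in a MedianEstimator over nbins in [10, 20, ..., 110] with 'linear_search'. The dataset name ('haberman') and the printout are illustrative assumptions only, not something the patch prescribes.

    # minimal sketch, assuming a QuaPy installation with this patch applied;
    # 'haberman' is just one example of a binary UCI dataset
    import numpy as np
    import quapy as qp
    import quapy.functional as F
    from sklearn.linear_model import LogisticRegression
    from quapy.method.aggregative import DMy
    from quapy.method.non_aggregative import DMx

    train, test = qp.datasets.fetch_UCIDataset('haberman').train_test

    methods = [
        ('DMy', DMy(LogisticRegression(), nbins=8, divergence='HD')),  # posterior-based DM, SLSQP search
        ('HDy', DMy.HDy(LogisticRegression(), n_jobs=-1)),             # MedianEstimator over nbins, linear search
        ('DMx', DMx(nbins=8, divergence='HD')),                        # covariate-based DM, SLSQP search
        ('HDx', DMx.HDx(n_jobs=-1)),                                   # MedianEstimator over nbins, linear search
    ]

    for name, quantifier in methods:
        quantifier.fit(train)
        estim_prev = quantifier.quantify(test.instances)
        print(f'{name}: true prevalence={test.prevalence()}, estimated prevalence={estim_prev}')

    # the new functional helper can also be called directly with any custom loss
    target = np.asarray([0.2, 0.8])
    loss = lambda prev: np.abs(prev - target).sum()
    print(F.argmin_prevalence(loss, n_classes=2, method='optim_minimize'))

The HDy/HDx constructors reproduce the original bin-median protocol, while plain DMy/DMx expose nbins, divergence, cdf and search as tunable hyperparameters.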