Add projection onto the probability simplex

This commit is contained in:
Paweł Czyż 2024-03-15 17:06:20 +01:00
parent 020530e14f
commit 4dd66b1921
2 changed files with 54 additions and 4 deletions

View File

@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from typing import Union, Callable
from typing import Literal, Union, Callable
import scipy
import numpy as np
@ -375,3 +375,54 @@ def linear_search(loss, n_classes):
prev_selected, min_score = prev, score
return np.asarray([1 - prev_selected, prev_selected])
def _project_onto_probability_simplex(v: np.ndarray) -> np.ndarray:
"""Projects a point onto the probability simplex.
The code is adapted from Mathieu Blondel's BSD-licensed
`implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
which is accompanying the paper
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex,
ICPR 2014, `URL <http://www.mblondel.org/publications/mblondel-icpr2014.pdf>`_
:param v: point in n-dimensional space, shape `(n,)`
:return: projection of `v` onto (n-1)-dimensional probability simplex, shape `(n,)`
"""
v = np.asarray(v)
n = len(v)
# Sort the values in the descending order
u = np.sort(v)[::-1]
cssv = np.cumsum(u) - 1.0
ind = np.arange(1, n + 1)
cond = u - cssv / ind > 0
rho = ind[cond][-1]
theta = cssv[cond][-1] / float(rho)
return np.maximum(v - theta, 0)
def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "project"]) -> np.ndarray:
"""
Clips the proportions vector `p` so that it is a valid probability distribution.
:param p: the proportions vector to be clipped, shape `(n_classes,)`
:param method: the method to use for normalization.
If `None` or `"none"`, no normalization is performed.
If `"clip"`, the values are clipped to the range [0,1] and normalized, so they sum up to 1.
If `"project"`, the values are projected onto the probability simplex.
:return: the normalized prevalence vector, shape `(n_classes,)`
"""
if method is None or method == "none":
return p
elif method == "clip":
adjusted = np.clip(p, 0, 1)
return adjusted / adjusted.sum()
elif method == "project":
return _project_onto_probability_simplex(p)
else:
raise ValueError(f"Method {method} not known.")

View File

@ -443,8 +443,7 @@ class ACC(AggregativeCrispQuantifier):
try:
adjusted_prevs = np.linalg.solve(A, B)
adjusted_prevs = np.clip(adjusted_prevs, 0, 1)
adjusted_prevs /= adjusted_prevs.sum()
adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip")
except np.linalg.LinAlgError:
adjusted_prevs = prevs_estim # no way to adjust them!