Add projection onto the probability simplex
This commit is contained in:
parent
020530e14f
commit
4dd66b1921
|
@ -1,6 +1,6 @@
|
|||
import itertools
|
||||
from collections import defaultdict
|
||||
from typing import Union, Callable
|
||||
from typing import Literal, Union, Callable
|
||||
|
||||
import scipy
|
||||
import numpy as np
|
||||
|
@ -374,4 +374,55 @@ def linear_search(loss, n_classes):
|
|||
if min_score is None or score < min_score:
|
||||
prev_selected, min_score = prev, score
|
||||
|
||||
return np.asarray([1 - prev_selected, prev_selected])
|
||||
return np.asarray([1 - prev_selected, prev_selected])
|
||||
|
||||
|
||||
def _project_onto_probability_simplex(v: np.ndarray) -> np.ndarray:
|
||||
"""Projects a point onto the probability simplex.
|
||||
|
||||
The code is adapted from Mathieu Blondel's BSD-licensed
|
||||
`implementation <https://gist.github.com/mblondel/6f3b7aaad90606b98f71>`_
|
||||
which is accompanying the paper
|
||||
|
||||
Mathieu Blondel, Akinori Fujino, and Naonori Ueda.
|
||||
Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex,
|
||||
ICPR 2014, `URL <http://www.mblondel.org/publications/mblondel-icpr2014.pdf>`_
|
||||
|
||||
:param v: point in n-dimensional space, shape `(n,)`
|
||||
:return: projection of `v` onto (n-1)-dimensional probability simplex, shape `(n,)`
|
||||
"""
|
||||
v = np.asarray(v)
|
||||
n = len(v)
|
||||
|
||||
# Sort the values in the descending order
|
||||
u = np.sort(v)[::-1]
|
||||
|
||||
cssv = np.cumsum(u) - 1.0
|
||||
ind = np.arange(1, n + 1)
|
||||
cond = u - cssv / ind > 0
|
||||
rho = ind[cond][-1]
|
||||
theta = cssv[cond][-1] / float(rho)
|
||||
return np.maximum(v - theta, 0)
|
||||
|
||||
|
||||
|
||||
def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "project"]) -> np.ndarray:
|
||||
"""
|
||||
Clips the proportions vector `p` so that it is a valid probability distribution.
|
||||
|
||||
:param p: the proportions vector to be clipped, shape `(n_classes,)`
|
||||
:param method: the method to use for normalization.
|
||||
If `None` or `"none"`, no normalization is performed.
|
||||
If `"clip"`, the values are clipped to the range [0,1] and normalized, so they sum up to 1.
|
||||
If `"project"`, the values are projected onto the probability simplex.
|
||||
:return: the normalized prevalence vector, shape `(n_classes,)`
|
||||
"""
|
||||
if method is None or method == "none":
|
||||
return p
|
||||
elif method == "clip":
|
||||
adjusted = np.clip(p, 0, 1)
|
||||
return adjusted / adjusted.sum()
|
||||
elif method == "project":
|
||||
return _project_onto_probability_simplex(p)
|
||||
else:
|
||||
raise ValueError(f"Method {method} not known.")
|
||||
|
|
|
@ -443,8 +443,7 @@ class ACC(AggregativeCrispQuantifier):
|
|||
|
||||
try:
|
||||
adjusted_prevs = np.linalg.solve(A, B)
|
||||
adjusted_prevs = np.clip(adjusted_prevs, 0, 1)
|
||||
adjusted_prevs /= adjusted_prevs.sum()
|
||||
adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip")
|
||||
except np.linalg.LinAlgError:
|
||||
adjusted_prevs = prevs_estim # no way to adjust them!
|
||||
|
||||
|
|
Loading…
Reference in New Issue