Refactor solving routine
This commit is contained in:
parent
4dd66b1921
commit
d34b086a76
|
@ -1,4 +1,5 @@
|
|||
import itertools
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Literal, Union, Callable
|
||||
|
||||
|
@ -426,3 +427,66 @@ def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "projec
|
|||
return _project_onto_probability_simplex(p)
|
||||
else:
|
||||
raise ValueError(f"Method {method} not known.")
|
||||
|
||||
|
||||
def solve_adjustment(
    p_c_y: np.ndarray,
    p_c: np.ndarray,
    method: Literal["inversion", "invariant-ratio"],
    solver: Literal["exact", "minimize", "exact-raise", "exact-cc"],
) -> np.ndarray:
    """
    Function finding the prevalence vector by adjusting
    the classifier predictions.

    The inputs are copied before use, so neither `p_c_y` nor `p_c` is modified.

    :param p_c_y: array of shape `(n_classes, n_classes,)` with entry `(c,y)` being the estimate
        of :math:`P(C=c|Y=y)`, that is, the probability that an instance that belongs to class :math:`y`
        ends up being classified as belonging to class :math:`c`
    :param p_c: classifier predictions, where the entry `c` is the estimate of :math:`P(C=c)`. Shape `(n_classes,)`
    :param method: adjustment method to be used:
        'inversion': matrix inversion method based on the matrix equality :math:`P(C)=P(C|Y)P(Y)`,
        which tries to invert the `P(C|Y)` matrix.
        'invariant-ratio': invariant ratio estimator of `Vaz et al. <https://jmlr.org/papers/v20/18-456.html>`_,
        which replaces the last equation with the normalization condition.
    :param solver: the method to use for solving the system of linear equations. Valid options are:
        'exact-raise': tries to solve the system using matrix inversion. Raises an error if the matrix has
        rank strictly less than `n_classes`.
        'exact-cc': if the matrix is not of full rank, returns `p_c` as the estimates, which corresponds
        to no adjustment (i.e., the classify and count method. See :class:`quapy.method.aggregative.CC`)
        'exact': deprecated, defaults to 'exact-cc'
        'minimize': minimizes a loss, so the solution always exists
    :return: an `np.ndarray` of shape `(n_classes,)` with the adjusted prevalence estimates. The values
        are not clipped or normalized here.
    :raises ValueError: if `method` or `solver` is not one of the valid options
    :raises numpy.linalg.LinAlgError: if `solver` is 'exact-raise' and the system cannot be solved
    """
    if solver == "exact":
        warnings.warn("The 'exact' solver is deprecated. Use 'exact-raise' or 'exact-cc'", DeprecationWarning, stacklevel=2)
        solver = "exact-cc"

    # np.array (not np.asarray) so that we always work on private copies:
    # the 'invariant-ratio' method overwrites entries of A and B in place.
    A = np.array(p_c_y, dtype=float)
    B = np.array(p_c, dtype=float)

    if method == "inversion":
        pass  # We leave A and B unchanged
    elif method == "invariant-ratio":
        # Replace the last equation with the normalization condition,
        # i.e., the sum of the prevalence values must be equal to 1
        A[-1, :] = 1.0
        B[-1] = 1.0
    else:
        raise ValueError(f"Flavour {method} not known.")

    # Validate the solver up front, so that an unknown solver is always reported
    # (and not only in the case in which the matrix happens to be singular)
    if solver not in ("minimize", "exact-raise", "exact-cc"):
        raise ValueError(f"Solver {solver} not known.")

    if solver == "minimize":
        def loss(prev):
            return np.linalg.norm(A @ prev - B)

        return optim_minimize(loss, n_classes=A.shape[0])

    # Solvers based on matrix inversion, so we use try/except block
    try:
        return np.linalg.solve(A, B)
    except np.linalg.LinAlgError:
        # The matrix is not invertible.
        # Depending on the solver, we either raise an error
        # or return the classifier predictions without adjustment
        if solver == "exact-raise":
            raise
        # solver == "exact-cc": return the caller's original predictions (not B, which
        # may have been altered by the 'invariant-ratio' method) as an array
        return np.asarray(p_c, dtype=float)
|
||||
|
|
|
@ -435,27 +435,13 @@ class ACC(AggregativeCrispQuantifier):
|
|||
:return: an adjusted `np.ndarray` of shape `(n_classes,)` with the corrected class prevalence estimates
|
||||
"""
|
||||
|
||||
A = PteCondEstim
|
||||
B = prevs_estim
|
||||
|
||||
if solver == 'exact':
|
||||
# attempts an exact solution of the linear system (may fail)
|
||||
|
||||
try:
|
||||
adjusted_prevs = np.linalg.solve(A, B)
|
||||
adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip")
|
||||
except np.linalg.LinAlgError:
|
||||
adjusted_prevs = prevs_estim # no way to adjust them!
|
||||
|
||||
return adjusted_prevs
|
||||
|
||||
elif solver == 'minimize':
|
||||
# poses the problem as an optimization one, and tries to minimize the norm of the differences
|
||||
|
||||
def loss(prev):
|
||||
return np.linalg.norm(A @ prev - B)
|
||||
|
||||
return F.optim_minimize(loss, n_classes=A.shape[0])
|
||||
estimate = F.solve_adjustment(
|
||||
p_c_y=PteCondEstim,
|
||||
p_c=prevs_estim,
|
||||
solver=solver,
|
||||
method='inversion',
|
||||
)
|
||||
return F.clip_prevalence(estimate, method="clip")
|
||||
|
||||
|
||||
class PCC(AggregativeSoftQuantifier):
|
||||
|
|
Loading…
Reference in New Issue