Refactor solving routine

This commit is contained in:
Paweł Czyż 2024-03-15 17:58:23 +01:00
parent 4dd66b1921
commit d34b086a76
2 changed files with 71 additions and 21 deletions

View File

@ -1,4 +1,5 @@
import itertools
import warnings
from collections import defaultdict
from typing import Literal, Union, Callable
@ -426,3 +427,66 @@ def clip_prevalence(p: np.ndarray, method: Literal[None, "none", "clip", "projec
return _project_onto_probability_simplex(p)
else:
raise ValueError(f"Method {method} not known.")
def solve_adjustment(
    p_c_y: np.ndarray,
    p_c: np.ndarray,
    method: Literal["inversion", "invariant-ratio"],
    solver: Literal["exact", "minimize", "exact-raise", "exact-cc"],
) -> np.ndarray:
    """
    Function finding the prevalence vector by adjusting
    the classifier predictions.

    :param p_c_y: array of shape `(n_classes, n_classes,)` with entry `(c,y)` being the estimate
        of :math:`P(C=c|Y=y)`, that is, the probability that an instance that belongs to class :math:`y`
        ends up being classified as belonging to class :math:`c`
    :param p_c: classifier predictions, where the entry `c` is the estimate of :math:`P(C=c)`.
        Shape `(n_classes,)`
    :param method: adjustment method to be used:

        * 'inversion': matrix inversion method based on the matrix equality :math:`P(C)=P(C|Y)P(Y)`,
          which tries to invert the `P(C|Y)` matrix.
        * 'invariant-ratio': invariant ratio estimator of
          `Vaz et al. <https://jmlr.org/papers/v20/18-456.html>`_,
          which replaces the last equation with the normalization condition
          (the entries of the solution sum up to one).
    :param solver: the method to use for solving the system of linear equations. Valid options are:

        * 'exact-raise': tries to solve the system with `np.linalg.solve`. The `LinAlgError` raised by
          NumPy (e.g., when the matrix is singular) is propagated to the caller.
        * 'exact-cc': tries to solve the system with `np.linalg.solve`. If NumPy raises `LinAlgError`,
          returns `p_c` as the estimates, which corresponds to no adjustment (i.e., the classify and
          count method. See :class:`quapy.method.aggregative.CC`)
        * 'exact': deprecated, defaults to 'exact-cc'
        * 'minimize': minimizes the norm of the residual `A @ prev - B` via `optim_minimize`
    :return: array with the adjusted prevalence estimates. The values are not clipped nor projected
        onto the probability simplex by this function.
    :raises ValueError: if `method` or `solver` is not one of the supported values
    :raises numpy.linalg.LinAlgError: if `solver` is 'exact-raise' and the system cannot be solved
    """
    if solver == "exact":
        warnings.warn(
            "The 'exact' solver is deprecated. Use 'exact-raise' or 'exact-cc'",
            DeprecationWarning,
            stacklevel=2,
        )
        solver = "exact-cc"

    # np.array always copies, so the in-place edits below never touch the caller's data
    A = np.array(p_c_y, dtype=float)
    B = np.array(p_c, dtype=float)

    if method == "inversion":
        pass  # We leave A and B unchanged
    elif method == "invariant-ratio":
        # Replace the last equation with the normalization condition sum(prev) == 1
        A[-1, :] = 1.0
        B[-1] = 1.0
    else:
        raise ValueError(f"Flavour {method} not known.")

    if solver == "minimize":
        def loss(prev):
            return np.linalg.norm(A @ prev - B)

        return optim_minimize(loss, n_classes=A.shape[0])

    # Validate the solver up-front: otherwise an unknown name would be silently
    # accepted whenever the matrix happens to be invertible
    if solver not in ("exact-raise", "exact-cc"):
        raise ValueError(f"Solver {solver} not known.")

    # Solvers based on matrix inversion, so we use try/except block
    try:
        return np.linalg.solve(A, B)
    except np.linalg.LinAlgError:
        # The system could not be solved.
        # Depending on the solver, we either raise an error
        # or return the classifier predictions without adjustment
        if solver == "exact-raise":
            raise
        # 'exact-cc': fall back to the unadjusted predictions, as an array
        return np.asarray(p_c, dtype=float)

View File

@ -435,27 +435,13 @@ class ACC(AggregativeCrispQuantifier):
:return: an adjusted `np.ndarray` of shape `(n_classes,)` with the corrected class prevalence estimates
"""
A = PteCondEstim
B = prevs_estim
if solver == 'exact':
# attempts an exact solution of the linear system (may fail)
try:
adjusted_prevs = np.linalg.solve(A, B)
adjusted_prevs = F.clip_prevalence(adjusted_prevs, method="clip")
except np.linalg.LinAlgError:
adjusted_prevs = prevs_estim # no way to adjust them!
return adjusted_prevs
elif solver == 'minimize':
# poses the problem as an optimization one, and tries to minimize the norm of the differences
def loss(prev):
return np.linalg.norm(A @ prev - B)
return F.optim_minimize(loss, n_classes=A.shape[0])
estimate = F.solve_adjustment(
p_c_y=PteCondEstim,
p_c=prevs_estim,
solver=solver,
method='inversion',
)
return F.clip_prevalence(estimate, method="clip")
class PCC(AggregativeSoftQuantifier):