QuAcc/baselines/mandoline.py

262 lines
7.9 KiB
Python
Raw Normal View History

2023-11-22 19:18:22 +01:00
from functools import partial
from types import SimpleNamespace
from typing import List, Optional
import numpy as np
import scipy.optimize
import scipy.special
import sklearn.metrics.pairwise as skmetrics
def Phi(
    D: np.ndarray,
    edge_list: Optional[np.ndarray] = None,
):
    """
    Given an n x d matrix of (example, slices), calculate the potential
    matrix.

    Includes correlations modeled by the edges in the `edge_list`.

    Args:
        D (np.ndarray): n x d matrix of (example, slice)
        edge_list (Optional[np.ndarray]): k x 2 matrix (or any sequence of
            index pairs) of edge correlations to be modeled.
            edge_list[i, :] should be indices for a pair of columns of D.
            An empty edge list adds no cross-terms.

    Returns:
        Potential matrix. `D` itself when edge_list is None, otherwise a new
        n x (d + k) array made of D followed by one (x_i * x_j) "cross-term"
        column per edge in the `edge_list`.

    Raises:
        ValueError: if a non-empty `edge_list` does not have shape (k, 2).

    Examples:
        >>> D = np.random.choice([-1, 1], size=(100, 6))
        >>> edge_list = np.array([(0, 1), (1, 4)])
        >>> Phi(D, edge_list)
    """
    if edge_list is None:
        return D

    edges = np.asarray(edge_list)
    if edges.size == 0:
        # `np.asarray([])` is a 1-D float array that cannot be used for
        # column indexing; normalise it to an empty (0, 2) index array so the
        # general path below simply appends zero cross-term columns.
        edges = np.empty((0, 2), dtype=np.intp)
    elif edges.ndim != 2 or edges.shape[1] != 2:
        raise ValueError(
            f"`edge_list` must have shape (k, 2), got {edges.shape}."
        )

    # Column-wise fancy indexing yields two (n x k) matrices whose elementwise
    # product is the cross-term for every edge at once.
    pairwise_terms = D[:, edges[:, 0]] * D[:, edges[:, 1]]
    return np.concatenate([D, pairwise_terms], axis=1)
def log_partition_ratio(
    x: np.ndarray,
    Phi_D_src: np.ndarray,
    n_src: int,
):
    """
    Compute the log-partition ratio term of the KLIEP problem.

    Evaluates ``log(n_src) - logsumexp(Phi_D_src . x)``.

    Args:
        x: parameter vector multiplied against the source potentials.
        Phi_D_src: source potential matrix.
        n_src: number of source samples.

    Returns:
        The scalar log-partition ratio.
    """
    source_scores = np.dot(Phi_D_src, x)
    log_normalizer = scipy.special.logsumexp(source_scores)
    return np.log(n_src) - log_normalizer
def mandoline(
    D_src: np.ndarray,
    D_tgt: np.ndarray,
    edge_list: np.ndarray,
    sigma: Optional[float] = None,
):
    """
    Mandoline solver.

    Args:
        D_src: (n_src x d) matrix of (example, slices) for the source distribution.
        D_tgt: (n_tgt x d) matrix of (example, slices) for the target distribution.
        edge_list: list of edge correlations between slices that should be modeled,
            or None to model no correlations.
        sigma: optional parameter that activates RBF kernel-based KLIEP with scale
            `sigma`. A falsy value (None or 0) selects the plain log-linear
            objective.

    Returns: SimpleNamespace that contains
        opt: result of scipy.optimize
        Phi_D_src: source potential matrix used in Mandoline
        Phi_D_tgt: target potential matrix used in Mandoline
        n_src: number of source samples
        n_tgt: number of target samples
        edge_list: the `edge_list` parameter passed as input (converted to an
            np.ndarray when it is not None)
    """
    # Copy so the caller's matrices are never mutated by the re-encoding below.
    D_src, D_tgt = np.copy(D_src), np.copy(D_tgt)

    # Binarize to -1/1. The 0/1 encoding is detected from the smallest value
    # across BOTH matrices: looking at the source alone misses the case where
    # the source happens to contain no 0 while the target does, which would
    # leave the two matrices in different encodings.
    minima = [np.min(M) for M in (D_src, D_tgt) if M.size]
    if minima and min(minima) == 0:
        D_src[D_src == 0] = -1
        D_tgt[D_tgt == 0] = -1

    # Edge list encoding dependencies between gs
    if edge_list is not None:
        edge_list = np.array(edge_list)

    # Create the potential matrices
    Phi_D_tgt, Phi_D_src = Phi(D_tgt, edge_list), Phi(D_src, edge_list)

    # Number of examples
    n_src, n_tgt = Phi_D_src.shape[0], Phi_D_tgt.shape[0]

    def f(x):
        # Negated log-linear KLIEP objective (scipy minimizes).
        obj = Phi_D_tgt.dot(x).sum() - n_tgt * scipy.special.logsumexp(
            Phi_D_src.dot(x)
        )
        return -obj

    if not sigma:
        objective = f
    else:
        # Only build the kernel when it is actually requested.
        kernel = partial(skmetrics.rbf_kernel, gamma=sigma)

        def llkliep_f(x):
            # rbf_kernel needs both arguments to share the feature dimension,
            # so the parameter vector is passed as a single (1 x d) row. A
            # (d x 1) column only matches the potentials when d == 1.
            center = x[np.newaxis, :]
            obj = kernel(Phi_D_tgt, center).sum() - n_tgt * scipy.special.logsumexp(
                kernel(Phi_D_src, center)
            )
            return -obj

        objective = llkliep_f

    # Solve with BFGS from a random standard-normal starting point.
    opt = scipy.optimize.minimize(
        objective, np.random.randn(Phi_D_tgt.shape[1]), method="BFGS"
    )

    return SimpleNamespace(
        opt=opt,
        Phi_D_src=Phi_D_src,
        Phi_D_tgt=Phi_D_tgt,
        n_src=n_src,
        n_tgt=n_tgt,
        edge_list=edge_list,
    )
def log_density_ratio(D, solved):
    """
    Calculate the log density ratio for a solved Mandoline run.

    Args:
        D: matrix of potentials to score. It is passed through `Phi` with no
            edge list, so no cross-terms are added here.
        solved: namespace returned by `mandoline`; its `opt.x`, `Phi_D_src`
            and `n_src` attributes are read.

    Returns:
        One log density ratio per row of `D`.
    """
    theta = solved.opt.x
    potentials = Phi(D, None)
    log_partition = log_partition_ratio(theta, solved.Phi_D_src, solved.n_src)
    return potentials.dot(theta) + log_partition
def get_k_most_unbalanced_gs(D_src, D_tgt, k):
    """
    Get the top k slices that shift most between source and target
    distributions.

    Uses difference in marginals between each slice.

    Args:
        D_src: (n_src x d) matrix of (example, slices) for the source.
        D_tgt: (n_tgt x d) matrix of (example, slices) for the target.
        k: number of slices to return. Values larger than d return all d
            slices; k <= 0 returns two empty lists.

    Returns:
        Tuple `(indices, differences)` of two lists, ordered by increasing
        absolute marginal difference (the most shifted slice is last).
    """
    marginal_diff = np.abs(D_src.mean(axis=0) - D_tgt.mean(axis=0))
    if k <= 0:
        # `arr[-0:]` is the whole array, so without this guard k == 0 would
        # return every slice instead of none.
        return [], []
    # A single argsort gives both the indices and, by indexing, the sorted
    # differences.
    indices = np.argsort(marginal_diff)[-k:]
    differences = marginal_diff[indices]
    return list(indices), list(differences)
def weighted_estimator(weights: Optional[np.ndarray], mat: np.ndarray):
    """
    Calculate a weighted empirical mean over a matrix of samples.

    Args:
        weights (Optional[np.ndarray]):
            length n array of weights that sums to 1 (up to floating point
            rounding). Calculates an unweighted mean if `weights` is None.
        mat (np.ndarray):
            (n x r) matrix of empirical observations that is being averaged.

    Returns:
        Length r np.ndarray of weighted means.

    Raises:
        ValueError: if `weights` is given and does not sum to 1.
    """
    # Handle the unweighted case first; validating `None` would fail.
    if weights is None:
        return np.mean(mat, axis=0)

    weights = np.asarray(weights)
    # Normalised floating point weights rarely sum to exactly 1.0, so compare
    # with an absolute tolerance instead of exact equality.
    if not np.isclose(np.sum(weights), 1.0, rtol=0.0, atol=1e-8):
        raise ValueError("`weights` must sum to 1.")

    return np.sum(weights[:, np.newaxis] * mat, axis=0)
def estimate_performance(
    D_src: np.ndarray,
    D_tgt: np.ndarray,
    edge_list: np.ndarray,
    empirical_mat_list_src: List[np.ndarray],
):
    """
    Estimate performance on a target distribution using slice information from the
    source and target data.

    This function runs Mandoline to calculate the importance weights to reweight
    the source data.

    Args:
        D_src (np.ndarray): (n_src x d) matrix of (example, slices) for the source
            distribution.
        D_tgt (np.ndarray): (n_tgt x d) matrix of (example, slices) for the target
            distribution.
        edge_list (np.ndarray): edge correlations between slices, forwarded
            unchanged to `mandoline` (may be None).
        empirical_mat_list_src (List[np.ndarray]): matrices of empirical
            observations on the source samples, one row per source example
            (e.g. a model's (n x 1) error matrix). Each is averaged with the
            importance weights and with uniform weights.

    Returns:
        SimpleNamespace with 3 attributes
        - `all_estimates` is a list of SimpleNamespace objects (one per input
          matrix) with 2 attributes
            - `weighted` is the estimate for the target distribution
            - `source` is the estimate for the source distribution
        - `solved`: the namespace returned by the `mandoline` solver
        - `weights`: self-normalized importance weights used to weight the source data
    """
    # Run the solver
    solved = mandoline(D_src, D_tgt, edge_list)

    # Log density ratios on the source dataset
    log_ratios = log_density_ratio(solved.Phi_D_src, solved)

    # Self-normalized importance weights, exp(l_i) / sum_j exp(l_j), computed
    # in log space: exponentiating first can overflow to inf for large
    # log-ratios and turn every weight into NaN.
    weights = np.exp(log_ratios - scipy.special.logsumexp(log_ratios))

    # Uniform weights for the plain source estimate; identical for every
    # matrix, so build them once.
    uniform_weights = np.ones(solved.n_src) / solved.n_src

    # Each estimate is a 1-D array with one entry per column of the matrix.
    all_estimates = [
        SimpleNamespace(
            weighted=weighted_estimator(weights, mat_src),
            source=weighted_estimator(uniform_weights, mat_src),
        )
        for mat_src in empirical_mat_list_src
    ]

    return SimpleNamespace(
        all_estimates=all_estimates,
        solved=solved,
        weights=weights,
    )
###########################################################################
def get_entropy(probas):
    """
    Compute the row-wise entropy (natural log) of a 2-D array of probabilities.

    A constant of 1e-20 is added inside the logarithm so that zero
    probabilities do not produce `log(0)`.

    Args:
        probas: 2-D array; the sum is taken along axis 1.

    Returns:
        1-D array with one entropy value per row.
    """
    log_probas = np.log(probas + 1e-20)
    return -(probas * log_probas).sum(axis=1)
def get_slices(probas, n_ent_bins=6):
    """
    Build a -1/1 slice matrix from a matrix of posterior probabilities.

    Two groups of slices are concatenated column-wise:
    - one slice per class, set to 1 for the class with the highest probability;
    - `n_ent_bins` slices that bin the prediction entropy into equal-width
      intervals between 0 and the entropy of the uniform distribution.

    Args:
        probas: (n x n_classes) array of probabilities.
        n_ent_bins: number of entropy bins.

    Returns:
        (n x (n_classes + n_ent_bins)) int64 array with entries in {-1, 1};
        every row has exactly one 1 in each of the two groups.
    """
    n_samples, n_classes = probas.shape
    rows = np.arange(n_samples)

    # One-hot (in -1/1 encoding) of the predicted class
    preds = np.argmax(probas, axis=1)
    pred_slices = np.full((n_samples, n_classes), fill_value=-1, dtype="<i8")
    pred_slices[rows, preds] = 1

    # Entropy bins span [0, entropy of the uniform distribution]
    ent = get_entropy(probas)
    range_top = get_entropy(np.array([np.ones(n_classes) / n_classes]))[0]
    ent_bins = np.linspace(0, range_top, n_ent_bins + 1)
    bins_map = np.digitize(ent, bins=ent_bins, right=True) - 1
    # With right=True an entropy of exactly 0 maps to index -1, which would
    # wrap around to the LAST (highest-entropy) bin, and an entropy marginally
    # above `range_top` maps one past the end. Clamp both into the valid range
    # so they land in the first and last bin respectively.
    bins_map = np.clip(bins_map, 0, n_ent_bins - 1)

    ent_slices = np.full((n_samples, n_ent_bins), fill_value=-1, dtype="<i8")
    ent_slices[rows, bins_map] = 1

    return np.concatenate([pred_slices, ent_slices], axis=1)