QuAcc/baselines/densratio/core.py

71 lines
2.1 KiB
Python
Raw Normal View History

2023-10-27 12:35:25 +02:00
"""
densratio.core
~~~~~~~~~~~~~~
Estimate Density Ratio p(x)/q(y)
"""
from numpy import linspace
from .helpers import to_ndarray
from .RuLSIF import RuLSIF
def _resolve_range(value, name):
    """Return the search grid to use for a range argument of `densratio`.

    Arguments:
        value: user-supplied search range, ``None``, or the string ``"auto"``.
        name: the parameter's name; used only to build the error message.
    Returns:
        ``10 ** linspace(-3, 9, 13)`` (i.e. 10^-3, 10^-2, ..., 10^9) when
        ``value`` is ``None`` or ``"auto"``; otherwise ``value`` unchanged.
    Raises:
        TypeError: if ``value`` is a string other than ``"auto"``.
    """
    if isinstance(value, str):
        if value != "auto":
            raise TypeError(f"Invalid value for {name}.")
        value = None
    if value is None:
        # Build a fresh array on every call so no caller can mutate a shared default.
        return 10 ** linspace(-3, 9, 13)
    return value


def densratio(
    x, y, alpha=0, sigma_range="auto", lambda_range="auto", kernel_num=100, verbose=True
):
    """Estimate alpha-mixture Density Ratio p(x)/(alpha*p(x) + (1 - alpha)*q(x))

    Arguments:
        x: sample from p(x).
        y: sample from q(x).
        alpha: Default 0 - corresponds to ordinary density ratio.
        sigma_range: search range of Gaussian kernel bandwidth.
            Default "auto" (or None) means 10^-3, 10^-2, ..., 10^9.
        lambda_range: search range of regularization parameter for RuLSIF.
            Default "auto" (or None) means 10^-3, 10^-2, ..., 10^9.
        kernel_num: number of kernels. Default 100.
        verbose: indicator to print messages. Default True.

    Returns:
        The object produced by RuLSIF (a densratio.DensityRatio object which
        has `compute_density_ratio()`).

    Raises:
        ValueError: if dimension of x != dimension of y.
        TypeError: if sigma_range or lambda_range is a string other than "auto".

    Usage::
        >>> from scipy.stats import norm
        >>> from densratio import densratio
        >>> x = norm.rvs(size=200, loc=1, scale=1./8)
        >>> y = norm.rvs(size=200, loc=1, scale=1./2)
        >>> result = densratio(x, y, alpha=0.7)
        >>> print(result)
        >>> density_ratio = result.compute_density_ratio(y)
        >>> print(density_ratio)
    """
    x = to_ndarray(x)
    y = to_ndarray(y)
    # Compare feature counts. This relies on to_ndarray returning arrays with at
    # least two dimensions, presumably (n_samples, n_features) -- verify in helpers.
    if x.shape[1] != y.shape[1]:
        raise ValueError("x and y must be same dimensions.")
    # Validated in the original order: sigma_range first, then lambda_range.
    sigma_range = _resolve_range(sigma_range, "sigma_range")
    lambda_range = _resolve_range(lambda_range, "lambda_range")
    return RuLSIF(x, y, alpha, sigma_range, lambda_range, kernel_num, verbose)