71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
|
"""
|
||
|
densratio.core
|
||
|
~~~~~~~~~~~~~~
|
||
|
|
||
|
Estimate Density Ratio p(x)/q(y)
|
||
|
"""
|
||
|
|
||
|
from numpy import linspace
|
||
|
|
||
|
from .helpers import to_ndarray
|
||
|
from .RuLSIF import RuLSIF
|
||
|
|
||
|
|
||
|
def densratio(
|
||
|
x, y, alpha=0, sigma_range="auto", lambda_range="auto", kernel_num=100, verbose=True
|
||
|
):
|
||
|
"""Estimate alpha-mixture Density Ratio p(x)/(alpha*p(x) + (1 - alpha)*q(x))
|
||
|
|
||
|
Arguments:
|
||
|
x: sample from p(x).
|
||
|
y: sample from q(x).
|
||
|
alpha: Default 0 - corresponds to ordinary density ratio.
|
||
|
sigma_range: search range of Gaussian kernel bandwidth.
|
||
|
Default "auto" means 10^-3, 10^-2, ..., 10^9.
|
||
|
lambda_range: search range of regularization parameter for uLSIF.
|
||
|
Default "auto" means 10^-3, 10^-2, ..., 10^9.
|
||
|
kernel_num: number of kernels. Default 100.
|
||
|
verbose: indicator to print messages. Default True.
|
||
|
|
||
|
Returns:
|
||
|
densratio.DensityRatio object which has `compute_density_ratio()`.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: if dimension of x != dimension of y
|
||
|
|
||
|
Usage::
|
||
|
>>> from scipy.stats import norm
|
||
|
>>> from densratio import densratio
|
||
|
|
||
|
>>> x = norm.rvs(size=200, loc=1, scale=1./8)
|
||
|
>>> y = norm.rvs(size=200, loc=1, scale=1./2)
|
||
|
>>> result = densratio(x, y, alpha=0.7)
|
||
|
>>> print(result)
|
||
|
|
||
|
>>> density_ratio = result.compute_density_ratio(y)
|
||
|
>>> print(density_ratio)
|
||
|
"""
|
||
|
|
||
|
x = to_ndarray(x)
|
||
|
y = to_ndarray(y)
|
||
|
|
||
|
if x.shape[1] != y.shape[1]:
|
||
|
raise ValueError("x and y must be same dimensions.")
|
||
|
|
||
|
if isinstance(sigma_range, str) and sigma_range != "auto":
|
||
|
raise TypeError("Invalid value for sigma_range.")
|
||
|
|
||
|
if isinstance(lambda_range, str) and lambda_range != "auto":
|
||
|
raise TypeError("Invalid value for lambda_range.")
|
||
|
|
||
|
if sigma_range is None or (isinstance(sigma_range, str) and sigma_range == "auto"):
|
||
|
sigma_range = 10 ** linspace(-3, 9, 13)
|
||
|
|
||
|
if lambda_range is None or (
|
||
|
isinstance(lambda_range, str) and lambda_range == "auto"
|
||
|
):
|
||
|
lambda_range = 10 ** linspace(-3, 9, 13)
|
||
|
|
||
|
result = RuLSIF(x, y, alpha, sigma_range, lambda_range, kernel_num, verbose)
|
||
|
return result
|