kernax.thinning
Module for Stein thinning algorithms based on the Langevin Stein operator and the inverse multiquadric (IMQ) kernel.
SteinThinning
Bases: Module
Greedy Stein thinning with the Langevin Stein operator and IMQ kernel.
Once instantiated, the module can be called with an integer m to select
m representative indices from the dataset. It does this by greedily minimizing the
kernel Stein discrepancy (KSD) between the selected points and the target distribution (Riabiz et al., 2022).
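For reference, the greedy rule can be written as below. Here s(x) = ∇_x log p(x). The IMQ base kernel is assumed to have exponent −1/2 and a single lengthscale ℓ, which is an assumption about how SteinIMQ is parameterized. At each step the rule adds the sample that most reduces the squared KSD of the selected set, (1/n²) Σ_{j,j'} k_p(x_{π_j}, x_{π_j'}), and it allows repeated indices.

```latex
k_p(x, y) = \nabla_x \cdot \nabla_y k(x, y)
          + \nabla_x k(x, y) \cdot s(y)
          + \nabla_y k(x, y) \cdot s(x)
          + k(x, y)\, s(x) \cdot s(y),
\qquad
k(x, y) = \Bigl(1 + \frac{\lVert x - y \rVert^2}{\ell^2}\Bigr)^{-1/2}

\pi_n \in \operatorname*{arg\,min}_{i \in \{1, \dots, N\}}
  \Bigl[ \tfrac{1}{2}\, k_p(x_i, x_i) + \sum_{j=1}^{n-1} k_p(x_{\pi_j}, x_i) \Bigr],
\qquad n = 1, \dots, m
```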
Parameters:
- x (JaxArray) – Samples to thin, of shape (N, d).
- score_p (JaxArray) – Scores of the target log-density evaluated at x (i.e., ∇_x log p(x)), of shape (N, d).
- lengthscale (float or None) – IMQ kernel lengthscale. If None, the median heuristic on x is used.
- stein_kernel (Callable) – A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
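The stein_kernel parameter expects a callable with the signature shown above. The sketch below illustrates one such function, together with a median-heuristic lengthscale. Both `stein_imq_sketch` and `median_heuristic` are hypothetical names and are not kernax functions. The IMQ exponent beta = −1/2 and the "median pairwise distance" convention are assumptions, and kernax.kernels.SteinIMQ may be parameterized differently. The sketch uses NumPy, whose API jax.numpy mirrors.

```python
import numpy as np


def stein_imq_sketch(xi, si, xj, sj, *, lengthscale, beta=-0.5):
    """Hypothetical Langevin Stein kernel on the IMQ base kernel
    k(x, y) = (1 + ||x - y||^2 / lengthscale^2) ** beta (beta = -1/2 assumed).

    xi, xj: points of shape (d,); si, sj: scores grad log p at those points.
    Returns a Python float.
    """
    r = xi - xj
    sq = float(r @ r)
    d = xi.shape[0]
    l2 = lengthscale ** 2
    u = 1.0 + sq / l2
    # div_x div_y k: sum of the mixed second derivatives d^2 k / (dx_a dy_a)
    div_term = (-2.0 * beta * d / l2) * u ** (beta - 1) \
        - (4.0 * beta * (beta - 1) / l2 ** 2) * sq * u ** (beta - 2)
    # grad_x k . s(y) + grad_y k . s(x), using grad_y k = -grad_x k
    cross_term = (2.0 * beta / l2) * u ** (beta - 1) * float(r @ (sj - si))
    # k(x, y) s(x) . s(y)
    score_term = u ** beta * float(si @ sj)
    return div_term + cross_term + score_term


def median_heuristic(x):
    """Median pairwise Euclidean distance over i < j (one common convention)."""
    diff = x[:, None, :] - x[None, :, :]
    dist = np.sqrt(np.sum(diff ** 2, axis=-1))
    upper = np.triu_indices(x.shape[0], k=1)
    return float(np.median(dist[upper]))


# Usage with a standard Gaussian target, whose score is -x.
rng = np.random.default_rng(0)
x = rng.normal(size=(50, 2))
s = -x
ell = median_heuristic(x)
k01 = stein_imq_sketch(x[0], s[0], x[1], s[1], lengthscale=ell)
```

On the diagonal this kernel reduces to d / ℓ² + ||s(x)||². A greedy criterion built on it therefore favors samples where the score is small.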
__call__
__call__(m)
Select a subset of size m via greedy Stein thinning.
Parameters:
- m (int) – Number of points to select (must be <= N).

Returns:
- indices (JaxArray) – Indices of the selected points into x, of shape (m,).
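The greedy selection behind __call__(m) can be sketched on a precomputed Stein Gram matrix K, where K[i, j] = k_p(x_i, x_j). Such a matrix could be obtained, for example, by evaluating stein_kernel on all pairs. This is a minimal NumPy sketch of the rule from Riabiz et al. (2022), not kernax's implementation. `greedy_stein_thinning_sketch` is a hypothetical name. Like the original algorithm, the sketch allows repeated indices; this reference does not say whether kernax permits repeats.

```python
import numpy as np


def greedy_stein_thinning_sketch(K, m):
    """Greedy Stein thinning on a precomputed Stein Gram matrix.

    K: symmetric array of shape (N, N) with K[i, j] = k_p(x_i, x_j).
    Returns an integer array of m indices (repeats allowed in this sketch).
    """
    half_diag = 0.5 * np.diag(K)
    running = np.zeros(K.shape[0])  # running[i] = sum of K[j, i] over selected j
    indices = np.empty(m, dtype=np.int64)
    for t in range(m):
        # Pick the point whose addition most reduces the squared KSD.
        i = int(np.argmin(half_diag + running))
        indices[t] = i
        running += K[i]  # K is symmetric, so row i equals column i
    return indices


# Toy usage: a random symmetric PSD matrix stands in for a Stein Gram matrix.
rng = np.random.default_rng(0)
A = rng.normal(size=(30, 5))
K = A @ A.T
idx = greedy_stein_thinning_sketch(K, m=10)
```

Keeping a running column sum makes each step O(N) once K is available. The O(N²) cost of building K dominates.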
RegularizedSteinThinning
Bases: Module
Regularized Stein thinning with the Langevin Stein operator and IMQ kernel.
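Adds an entropy-style regularization term based on the target log-density to the greedy Stein objective. For this reason the constructor also takes the log-density and its Laplacian at each sample.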
Parameters:
- x (JaxArray) – Samples to thin, of shape (N, d).
- log_p (JaxArray) – Target log-density evaluated at x, of shape (N,).
- score_p (JaxArray) – Scores of the target log-density at x (i.e., ∇_x log p(x)), of shape (N, d).
- laplace_log_p (JaxArray) – Laplacian of the log-density at x, of shape (N,).
- lengthscale (float or None) – IMQ kernel lengthscale. If None, the median heuristic on x is used.
- stein_kernel (Callable) – A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
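For a Gaussian target N(μ, Σ), the three per-sample inputs have closed forms: log p(x), the score −Σ⁻¹(x − μ), and the constant Laplacian −tr(Σ⁻¹). The NumPy sketch below computes them. `gaussian_inputs` is a hypothetical helper and is not part of kernax.

```python
import numpy as np


def gaussian_inputs(x, mean, cov):
    """Closed-form log p, grad log p and Laplacian of log p for N(mean, cov).

    x: samples of shape (N, d). Returns arrays of shapes (N,), (N, d), (N,).
    """
    prec = np.linalg.inv(cov)
    diff = x - mean
    _, logdet = np.linalg.slogdet(2.0 * np.pi * cov)
    # log p(x) = -0.5 (x - mu)^T prec (x - mu) - 0.5 log det(2 pi cov)
    log_p = -0.5 * np.einsum("nd,de,ne->n", diff, prec, diff) - 0.5 * logdet
    # prec is symmetric, so each row equals -(prec @ (x_n - mu))
    score_p = -diff @ prec
    # The Hessian of log p is -prec everywhere, so the Laplacian is constant.
    laplace_log_p = np.full(x.shape[0], -np.trace(prec))
    return log_p, score_p, laplace_log_p


rng = np.random.default_rng(0)
mean = np.array([1.0, -1.0])
cov = np.array([[2.0, 0.3], [0.3, 1.0]])
x = rng.multivariate_normal(mean, cov, size=100)
log_p, score_p, laplace_log_p = gaussian_inputs(x, mean, cov)
```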
__call__
__call__(m, weight_entropy=None)
Select a subset of size m via regularized Stein thinning.
Parameters:
- m (int) – Number of points to select (must be <= N).
- weight_entropy (float or None, default: None) – Strength of the entropy regularization. If None, 1.0 / m is used.

Returns:
- indices (JaxArray) – Indices of the selected points into x, of shape (m,).
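This reference does not spell out the exact regularized criterion. The sketch below therefore uses an assumed form, written here with λ = weight_entropy. At step n it picks argmin_i of k_p(x_i, x_i) + Δ log p(x_i) + 2 Σ_{j<n} k_p(x_{π_j}, x_i) − λ n log p(x_i). The Laplacian term adjusts the diagonal, and the entropy term favors high-density samples. With λ = 0 and no Laplacian term, the rule reduces to plain greedy Stein thinning. kernax's actual scaling may differ. `regularized_stein_thinning_sketch` is a hypothetical name, and K is a precomputed Stein Gram matrix of shape (N, N).

```python
import numpy as np


def regularized_stein_thinning_sketch(K, log_p, laplace_log_p, m, weight_entropy=None):
    """Hypothetical regularized greedy rule (assumed form, see lead-in).

    K: symmetric (N, N) Stein Gram matrix; log_p, laplace_log_p: shape (N,).
    """
    if weight_entropy is None:
        weight_entropy = 1.0 / m  # default documented for __call__
    base = np.diag(K) + laplace_log_p
    running = np.zeros(K.shape[0])  # running[i] = sum of K[j, i] over selected j
    indices = np.empty(m, dtype=np.int64)
    for t in range(m):
        n = t + 1  # 1-based step counter scaling the entropy term
        crit = base + 2.0 * running - weight_entropy * n * log_p
        i = int(np.argmin(crit))
        indices[t] = i
        running += K[i]
    return indices


# Toy usage: random arrays stand in for a Stein Gram matrix and target quantities.
rng = np.random.default_rng(0)
A = rng.normal(size=(30, 5))
K = A @ A.T
log_p = rng.normal(size=30)
lap = rng.normal(size=30)
idx = regularized_stein_thinning_sketch(K, log_p, lap, m=10)
```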