
kernax.thinning

Module for Stein thinning algorithms based on the Langevin Stein operator and the inverse multiquadric (IMQ) kernel.

SteinThinning

Bases: Module

Greedy Stein thinning with the Langevin Stein operator and IMQ kernel.

Once instantiated, the module can be called with an integer m to select m representative indices from x by greedily minimizing the kernel Stein discrepancy (KSD) of the selected points (Riabiz et al., 2022).
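The greedy rule can be sketched in plain JAX. When a point x is added to t already-selected points, the squared KSD of the empirical measure depends on x only through k_p(x, x)/2 + Σ_j k_p(x_j, x), so each step takes the argmin of that quantity over the candidates. This is a sketch, not kernax's implementation: the name greedy_stein_thinning is hypothetical, it assumes a precomputed (N, N) Stein kernel Gram matrix K, and it allows an index to be selected more than once, as in the greedy algorithm of Riabiz et al. (2022). Whether SteinThinning also allows repeats is not stated here.

```python
import jax
import jax.numpy as jnp


def greedy_stein_thinning(K: jax.Array, m: int) -> jax.Array:
    """Hypothetical sketch: greedily pick m indices minimizing the KSD.

    K is an assumed precomputed (N, N) Stein kernel Gram matrix,
    K[i, j] = k_p(x_i, x_j). Indices may repeat.
    """
    diag = jnp.diag(K)

    def step(running_sum, _):
        # Cost of adding candidate i: k_p(x_i, x_i) / 2 + sum_{j selected} k_p(x_j, x_i).
        objective = 0.5 * diag + running_sum
        i = jnp.argmin(objective)
        # K is symmetric, so row i holds k_p(x_i, x_j) for every candidate j.
        return running_sum + K[i], i

    # Carry the running kernel sums over the selected points; emit one index per step.
    _, indices = jax.lax.scan(step, jnp.zeros_like(diag), xs=None, length=m)
    return indices
```

Because m sets the scan length, it must be static under jit, for example jax.jit(greedy_stein_thinning, static_argnums=1). Storing K costs O(N²) memory; an implementation could instead recompute the needed kernel rows on the fly.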

Parameters:

  • x (JaxArray) –

    Samples to thin, of shape (N, d).

  • score_p (JaxArray) –

    Scores of the target log-density evaluated at x (i.e., ∇_x log p(x)), of shape (N, d).

  • lengthscale (float or None) –

    IMQ kernel lengthscale. If None, uses the median heuristic on x.

  • stein_kernel (Callable) –

    A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
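A custom stein_kernel only needs the documented signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. As an illustration (not the source of SteinIMQ), the sketch below builds the Langevin Stein kernel k_p(x, y) = ∇_x·∇_y k(x, y) + ∇_x k(x, y)·s(y) + ∇_y k(x, y)·s(x) + k(x, y) s(x)·s(y) from an IMQ base kernel using automatic differentiation, together with one common median-heuristic variant for the lengthscale. It assumes the base kernel k(x, y) = (1 + ‖x − y‖²/ℓ²)^(−1/2); kernax's exact IMQ parameterization and median-heuristic convention may differ. The names imq, langevin_stein_imq, median_heuristic and stein_gram are hypothetical.

```python
import jax
import jax.numpy as jnp


def imq(x, y, lengthscale):
    # Assumed IMQ base kernel: (1 + ||x - y||^2 / lengthscale^2) ** (-1/2).
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** -0.5


def langevin_stein_imq(xi, si, xj, sj, *, lengthscale):
    # Langevin Stein kernel:
    # k_p = div_x div_y k + grad_x k . s_j + grad_y k . s_i + k * (s_i . s_j)
    k = imq(xi, xj, lengthscale)
    grad_x = jax.grad(imq, argnums=0)(xi, xj, lengthscale)
    grad_y = jax.grad(imq, argnums=1)(xi, xj, lengthscale)
    # Jacobian w.r.t. x of grad_y k, shape (d, d); its trace is div_x div_y k.
    mixed = jax.jacfwd(jax.grad(imq, argnums=1), argnums=0)(xi, xj, lengthscale)
    return jnp.trace(mixed) + grad_x @ sj + grad_y @ si + k * (si @ sj)


def median_heuristic(x):
    # One common variant: median of pairwise Euclidean distances over pairs i < j.
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    rows, cols = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(jnp.sqrt(sq_dists[rows, cols]))


def stein_gram(x, score, lengthscale):
    # (N, N) matrix of k_p(x_i, x_j), built with nested vmap.
    def row(xi, si):
        return jax.vmap(
            lambda xj, sj: langevin_stein_imq(xi, si, xj, sj, lengthscale=lengthscale)
        )(x, score)

    return jax.vmap(row)(x, score)


# Example: standard Gaussian target, whose score is s(x) = -x.
x = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
ell = median_heuristic(x)
K = stein_gram(x, -x, ell)  # shape (50, 50)
```

A quick sanity check on this parameterization: at x = y the gradient terms vanish and k_p(x, x) = d/ℓ² + ‖s(x)‖². Whether kernax applies a user-supplied stein_kernel pairwise in the same way as stein_gram is an assumption.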

__call__

__call__(m)

Select a subset of size m via greedy Stein thinning.

Parameters:

  • m (int) –

    Number of points to select (must be <= N).

Returns:

  • indices (JaxArray) –

    Indices of the selected points into x, of shape (m,).

RegularizedSteinThinning

Bases: Module

Regularized Stein thinning with the Langevin Stein operator and IMQ kernel.

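Extends greedy Stein thinning with an entropy-style regularization term based on the target log-density. In addition to the scores, it requires the log-density (log_p) and its Laplacian (laplace_log_p) at every sample.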
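The regularized objective is not written out here, so the following is only an illustration under an assumed objective: F(S) = (1/n²)[Σ_{i,j∈S} k_p(x_i, x_j) + Σ_{i∈S} Δ⁺ log p(x_i)] − (λ/n) Σ_{i∈S} log p(x_i), where n = |S|, Δ⁺ denotes the positive part of the Laplacian and λ = weight_entropy. When the (t+1)-th point x is added, the x-dependent part of F is proportional to (k_p(x, x) + Δ⁺ log p(x))/2 + Σ_j k_p(x_j, x) − λ(t + 1)/2 · log p(x), which the greedy step below minimizes. The Laplacian term and the scaling of every term are assumptions, not kernax's documented behavior, and regularized_greedy_thinning is a hypothetical name; K is again an assumed precomputed Stein kernel Gram matrix.

```python
import jax
import jax.numpy as jnp


def regularized_greedy_thinning(K, log_p, laplace_log_p, m, weight_entropy=None):
    # Hypothetical sketch under an assumed objective; not kernax's implementation.
    if weight_entropy is None:
        weight_entropy = 1.0 / m  # mirrors the documented default of 1.0 / m
    # Per-candidate constant: (k_p(x, x) + positive part of the Laplacian) / 2.
    base = 0.5 * (jnp.diag(K) + jnp.maximum(laplace_log_p, 0.0))

    def step(running_sum, t):
        # t = number of points already selected; the entropy weight grows with t + 1.
        objective = base + running_sum - 0.5 * weight_entropy * (t + 1) * log_p
        i = jnp.argmin(objective)
        return running_sum + K[i], i

    _, indices = jax.lax.scan(step, jnp.zeros_like(log_p), jnp.arange(m))
    return indices
```

Under this assumed objective, adding a constant to log_p shifts every candidate's objective by the same amount, so an unnormalized log-density selects the same indices.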

Parameters:

  • x (JaxArray) –

    Samples to thin, of shape (N, d).

  • log_p (JaxArray) –

    Target log-density evaluated at x, of shape (N,).

  • score_p (JaxArray) –

    Scores of the target log-density at x (i.e., ∇_x log p(x)), of shape (N, d).

  • laplace_log_p (JaxArray) –

    Laplacian of the log-density at x, of shape (N,).

  • lengthscale (float or None) –

    IMQ kernel lengthscale. If None, uses the median heuristic on x.

  • stein_kernel (Callable) –

    A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
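When the target log-density is available as a JAX function, the three per-sample inputs log_p, score_p and laplace_log_p can be computed with automatic differentiation. A minimal sketch, assuming a hypothetical log_density that maps one sample of shape (d,) to a scalar (here a standard Gaussian up to an additive constant); prepare_inputs is likewise a hypothetical helper, not part of kernax.

```python
import jax
import jax.numpy as jnp


def log_density(x):
    # Hypothetical target: standard Gaussian in d dimensions, up to a constant.
    return -0.5 * jnp.sum(x**2)


def prepare_inputs(log_density, x):
    """Evaluate log p, its score and its Laplacian at each row of x, shape (N, d)."""
    log_p = jax.vmap(log_density)(x)  # (N,)
    score_p = jax.vmap(jax.grad(log_density))(x)  # (N, d)
    # Laplacian of log p = trace of its Hessian, evaluated per sample.
    laplace_log_p = jax.vmap(lambda xi: jnp.trace(jax.hessian(log_density)(xi)))(x)  # (N,)
    return log_p, score_p, laplace_log_p


x = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
log_p, score_p, laplace_log_p = prepare_inputs(log_density, x)
```

Forming the full d × d Hessian per sample costs O(d²) memory, so for high-dimensional targets a cheaper way of computing the Laplacian may be preferable.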

__call__

__call__(m, weight_entropy=None)

Select a subset of size m via regularized Stein thinning.

Parameters:

  • m (int) –

    Number of points to select (must be <= N).

  • weight_entropy (float or None, default: None) –

    Strength of the entropy regularization. If None, defaults to 1.0 / m.

Returns:

  • indices (JaxArray) –

    Indices of the selected points into x, of shape (m,).