kernax

Kernax package.

KernelHerding

Bases: Module

Greedy MMD-based kernel quantization (herding-style thinning).

Once instantiated, the module can be called with an integer m to select m representative samples from the input dataset.

Parameters:

  • x

    The dataset of shape (N, d) to be subsampled.

  • kernel_fn

    Kernel function of the form k(x: (d,), y: (d,)) -> scalar.

__call__

__call__(m)

Return the indices of a subset of size m, chosen greedily to minimize the MMD.

Parameters:

  • m (int) –

    Number of points to select (must be <= N).

Returns:

  • JaxArray

    Indices of the selected points into x, array of shape (m,).
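The greedy selection described above can be sketched in plain JAX. This is an illustrative sketch only, not kernax's implementation: the names `rbf` and `greedy_mmd_indices` are hypothetical, the Gaussian kernel is just one possible `kernel_fn`, and selecting without replacement is an assumption suggested by the `m <= N` constraint.

```python
import jax
import jax.numpy as jnp


def rbf(x, y, lengthscale=1.0):
    """Example kernel of the form k(x: (d,), y: (d,)) -> scalar."""
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


def greedy_mmd_indices(x, kernel_fn, m):
    """Hypothetical helper: greedy MMD minimization over the rows of x."""
    # Gram matrix K[i, j] = k(x_i, x_j), built by vmapping the pairwise kernel.
    gram = jax.vmap(lambda a: jax.vmap(lambda b: kernel_fn(a, b))(x))(x)
    mean_embedding = gram.mean(axis=1)  # (N,), kernel mean against the full dataset
    diag = jnp.diag(gram)

    def step(carry, t):
        running_sum, taken = carry
        # Candidate-dependent part of (t + 1)^2 * MMD^2 after adding point i
        # to the t points already selected.
        score = diag + 2.0 * running_sum - 2.0 * (t + 1) * mean_embedding
        # Assumption: each point is selected at most once.
        score = jnp.where(taken, jnp.inf, score)
        idx = jnp.argmin(score)
        return (running_sum + gram[:, idx], taken.at[idx].set(True)), idx

    n = x.shape[0]
    init = (jnp.zeros(n), jnp.zeros(n, dtype=bool))
    _, indices = jax.lax.scan(step, init, jnp.arange(m))
    return indices


x = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
indices = greedy_mmd_indices(x, rbf, 10)  # shape (10,)
subset = x[indices]                       # shape (10, 2)
```

The sketch materializes the full N x N Gram matrix, which is simple to read but costs O(N^2) memory; a production implementation may well avoid this.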

RegularizedSteinThinning

Bases: Module

Regularized Stein thinning with the Langevin Stein operator and IMQ kernel.

Adds an entropy-style regularization term based on the target log-density.

Parameters:

  • x (JaxArray) –

    Samples to thin, of shape (N, d).

  • log_p (JaxArray) –

    Target log-density evaluated at x, of shape (N,).

  • score_p (JaxArray) –

    Score of the target evaluated at x, i.e., the gradient of the log-density ∇_x log p(x), of shape (N, d).

  • laplace_log_p (JaxArray) –

    Laplacian of the log-density at x, of shape (N,).

  • lengthscale (float) –

    IMQ kernel lengthscale. If None, uses the median heuristic on x.

  • stein_kernel (Callable) –

    A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
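When the target log-density is available as a differentiable function, the arrays log_p, score_p and laplace_log_p can be obtained by automatic differentiation. The sketch below assumes a standard normal target purely for illustration; `log_density` and `laplacian` are hypothetical helpers, not part of kernax.

```python
import jax
import jax.numpy as jnp


def log_density(z):
    """Hypothetical target: log-density of a standard normal in d dimensions."""
    d = z.shape[0]
    return -0.5 * jnp.sum(z**2) - 0.5 * d * jnp.log(2.0 * jnp.pi)


def laplacian(f):
    """Laplacian of a scalar function, computed as the trace of its Hessian."""
    return lambda z: jnp.trace(jax.hessian(f)(z))


x = jax.random.normal(jax.random.PRNGKey(0), (200, 3))  # (N, d)

log_p = jax.vmap(log_density)(x)                     # (N,)
score_p = jax.vmap(jax.grad(log_density))(x)         # (N, d)
laplace_log_p = jax.vmap(laplacian(log_density))(x)  # (N,)
```

Forming the full Hessian is convenient for small d; for high-dimensional targets a closed-form Laplacian or a trace estimator is likely cheaper.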

__call__

__call__(m, weight_entropy=None)

Select a subset of size m via regularized Stein thinning.

Parameters:

  • m (int) –

    Number of points to select (must be <= N).

  • weight_entropy (float, default: None ) –

    Strength of the entropy regularization. Defaults to 1.0 / m.

Returns:

  • indices ( JaxArray ) –

    Indices of the selected points into x, of shape (m,).

SteinThinning

Bases: Module

Greedy Stein thinning with the Langevin Stein operator and IMQ kernel.

Once instantiated, the module can be called with an integer m to select m representative indices from the dataset by greedily minimizing the Stein objective (Riabiz et al., 2022).

Parameters:

  • x (JaxArray) –

    Samples to thin, of shape (N, d).

  • score_p (JaxArray) –

    Score of the target evaluated at x, i.e., the gradient of the log-density ∇_x log p(x), of shape (N, d).

  • lengthscale (float) –

    IMQ kernel lengthscale. If None, uses the median heuristic on x.

  • stein_kernel (Callable) –

    A Stein kernel function with signature stein_kernel(xi, si, xj, sj, *, lengthscale) -> scalar. Defaults to kernax.kernels.SteinIMQ.
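To make the stein_kernel signature and the lengthscale default concrete, here is a minimal sketch of a Langevin Stein kernel built on an IMQ base kernel, together with one common form of the median heuristic. Both are assumptions for illustration: the base kernel is taken to be k(x, y) = (1 + ||x - y||^2 / lengthscale^2)^(-1/2), and the heuristic is the median pairwise Euclidean distance. The parametrization used by kernax.kernels.SteinIMQ may differ, and `stein_imq` and `median_heuristic` are hypothetical names.

```python
import jax.numpy as jnp


def median_heuristic(x):
    """One common variant: median Euclidean distance over distinct pairs."""
    diffs = x[:, None, :] - x[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs**2, axis=-1))
    rows, cols = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(dists[rows, cols])


def stein_imq(xi, si, xj, sj, *, lengthscale):
    """Langevin Stein kernel for the assumed IMQ base kernel (closed form)."""
    d = xi.shape[0]
    r = xi - xj
    inv_l2 = 1.0 / lengthscale**2
    sq = jnp.sum(r**2)
    q = 1.0 + inv_l2 * sq
    # Trace of the mixed second derivative of the base kernel.
    trace_term = d * inv_l2 * q**-1.5 - 3.0 * inv_l2**2 * sq * q**-2.5
    # <grad_x k, s_j> + <grad_y k, s_i>
    cross_term = inv_l2 * jnp.dot(si - sj, r) * q**-1.5
    # k(x_i, x_j) * <s_i, s_j>
    score_term = jnp.dot(si, sj) * q**-0.5
    return trace_term + cross_term + score_term


x = jnp.array([[0.0], [1.0], [3.0]])
score = -x  # score of a standard normal target, used here for illustration
ell = median_heuristic(x)
value = stein_imq(x[0], score[0], x[1], score[1], lengthscale=ell)
```

On the diagonal (xi equal to xj) this expression reduces to d / lengthscale^2 plus the squared norm of the score, which is a quick sanity check for any alternative implementation.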

__call__

__call__(m)

Select a subset of size m via greedy Stein thinning.

Parameters:

  • m (int) –

    Number of points to select (must be <= N).

Returns:

  • indices ( JaxArray ) –

    Indices of the selected points into x, of shape (m,).
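The greedy rule of Riabiz et al. picks, at each step, the index i minimizing k_p(x_i, x_i) / 2 plus the sum of k_p(x_j, x_i) over the points x_j already selected. The sketch below illustrates that rule in plain JAX and is not kernax's implementation: `imq`, `stein_kernel` and `greedy_stein_thinning` are hypothetical names, the IMQ parametrization is an assumption, the Stein kernel is assembled by automatic differentiation for brevity, and the loop does not forbid repeated indices.

```python
import jax
import jax.numpy as jnp


def imq(a, b, lengthscale):
    """Assumed IMQ base kernel: (1 + ||a - b||^2 / lengthscale^2)^(-1/2)."""
    return (1.0 + jnp.sum((a - b) ** 2) / lengthscale**2) ** -0.5


def stein_kernel(xi, si, xj, sj, *, lengthscale):
    """Langevin Stein kernel of the base kernel, assembled by autodiff."""
    k = imq(xi, xj, lengthscale)
    grad_x = jax.grad(imq, argnums=0)(xi, xj, lengthscale)
    grad_y = jax.grad(imq, argnums=1)(xi, xj, lengthscale)
    # Trace of the mixed second derivative d^2 k / (dx dy).
    mixed = jnp.trace(
        jax.jacfwd(jax.grad(imq, argnums=1), argnums=0)(xi, xj, lengthscale)
    )
    return mixed + grad_x @ sj + grad_y @ si + k * (si @ sj)


def greedy_stein_thinning(x, score_p, m, lengthscale=1.0):
    """Hypothetical helper: greedy minimization of the Stein objective."""

    def row(xi, si):
        return jax.vmap(
            lambda xj, sj: stein_kernel(xi, si, xj, sj, lengthscale=lengthscale)
        )(x, score_p)

    kp = jax.vmap(row)(x, score_p)  # (N, N) Stein Gram matrix
    half_diag = 0.5 * jnp.diag(kp)

    def step(running_sum, _):
        # running_sum[i] holds the sum of k_p(x_j, x_i) over selected j.
        idx = jnp.argmin(half_diag + running_sum)
        return running_sum + kp[:, idx], idx

    _, indices = jax.lax.scan(step, jnp.zeros(x.shape[0]), None, length=m)
    return indices


x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
score_p = -x  # score of a standard normal target, used here for illustration
indices = greedy_stein_thinning(x, score_p, 10)
thinned = x[indices]  # shape (10, 2)
```

With this base kernel the first selected point is the one with the smallest score norm, since the diagonal of the Stein Gram matrix is a constant plus the squared norm of the score.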