
Stein thinning

Setup

Suppose we want to thin a Gaussian sample using Stein thinning [1] and regularized Stein thinning [2], where the target distribution is also Gaussian.

First generate a toy sample.

import jax.random as jr
import jax.numpy as jnp

rng_key = jr.PRNGKey(42)
x = jr.multivariate_normal(rng_key, mean=jnp.zeros(2), cov=jnp.eye(2), shape=(1_000,))

In order to apply Stein thinning, we need the values of the score function and a lengthscale.

Here, the expression of the target score function is straightforward and can be implemented by hand or with, e.g., jax.scipy.stats.

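import jax
from jax.scipy.stats import multivariate_normal

def logprob_fn(x):
    # Log-density of the standard 2D Gaussian target
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

# The score is the gradient of the log-density, evaluated at each sample point
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)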
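For the by-hand route mentioned above: a standard Gaussian N(0, I) has log-density -||x||^2 / 2 up to a constant, so its score is simply -x. The short check below compares that closed form with the autodiff result. The name score_by_hand and the small 5-point sample are illustrative only.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import multivariate_normal


def logprob_fn(x):
    # Log-density of the standard 2D Gaussian target
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))


def score_by_hand(x):
    # For N(0, I): grad_x log p(x) = -x (hypothetical helper for illustration)
    return -x


# Small illustrative sample, shape (5, 2)
x_small = jr.multivariate_normal(
    jr.PRNGKey(0), mean=jnp.zeros(2), cov=jnp.eye(2), shape=(5,)
)

score_autodiff = jax.vmap(jax.grad(logprob_fn))(x_small)
max_abs_diff = jnp.max(jnp.abs(score_autodiff - score_by_hand(x_small)))
```

Here max_abs_diff should be at the level of floating-point error, which confirms that both routes give the same score values.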

We also need a lengthscale, which is chosen here with the median heuristic.

Note that the utility function median_heuristic only accepts and returns NumPy arrays. This is a known limitation.

import numpy as np
from kernax.utils import median_heuristic

lengthscale = jnp.asarray(median_heuristic(np.asarray(x)))
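To give an idea of what a median heuristic computes, here is a minimal sketch in JAX that takes the median of the pairwise Euclidean distances between sample points. This is one common convention and an assumption on our part: the function median_heuristic_sketch is hypothetical, and the exact definition used by kernax's median_heuristic (for instance, squared distances or an extra scaling factor) may differ.

```python
import jax.numpy as jnp


def median_heuristic_sketch(x):
    # Hypothetical helper: median of pairwise Euclidean distances.
    # Pairwise differences, shape (n, n, d)
    diffs = x[:, None, :] - x[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs**2, axis=-1))
    # Keep each unordered pair once and drop the zero diagonal
    upper = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(dists[upper])


# Three collinear points with pairwise distances 5, 10 and 5
points = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
lengthscale_sketch = median_heuristic_sketch(points)
```

The sketch builds the full n × n distance matrix, so for large samples it may be preferable to compute the median on a random subsample.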

Vanilla algorithm

Stein thinning can be applied as follows to select 100 points from the original sample.

from kernax import SteinThinning

stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
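For intuition about what the call above computes, the following is a minimal sketch of the greedy selection rule described in [1]. It is not kernax's implementation. It assumes an inverse multiquadric base kernel, builds the Langevin Stein kernel with automatic differentiation, and forms the full Gram matrix, which is only practical for small samples. All function names here are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def imq_kernel(x, y, lengthscale):
    # Inverse multiquadric base kernel (an assumed choice for this sketch)
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** (-0.5)


def stein_kernel(x, y, sx, sy, lengthscale):
    # Langevin Stein kernel; sx and sy are the scores at x and y
    k = lambda a, b: imq_kernel(a, b, lengthscale)
    grad_x = jax.grad(k, argnums=0)(x, y)
    grad_y = jax.grad(k, argnums=1)(x, y)
    # Trace of the mixed second derivative: sum_i d^2 k / (dx_i dy_i)
    div_xy = jnp.trace(jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y))
    return div_xy + grad_x @ sy + grad_y @ sx + k(x, y) * (sx @ sy)


def stein_gram(x, scores, lengthscale):
    # Gram matrix of the Stein kernel over all pairs of points, shape (n, n)
    row = jax.vmap(stein_kernel, in_axes=(None, 0, None, 0, None))
    full = jax.vmap(row, in_axes=(0, None, 0, None, None))
    return full(x, x, scores, scores, lengthscale)


def greedy_stein_thinning(x, scores, lengthscale, m):
    # At each step, pick the point minimizing
    # k_p(x_i, x_i) / 2 + sum over already selected j of k_p(x_j, x_i)
    gram = stein_gram(x, scores, lengthscale)
    objective = 0.5 * jnp.diag(gram)
    selected = []
    for _ in range(m):
        idx = int(jnp.argmin(objective))
        selected.append(idx)
        objective = objective + gram[idx]
    return jnp.array(selected)


# Toy usage: standard Gaussian sample, whose score is -x
x_toy = jr.normal(jr.PRNGKey(0), (200, 2))
scores_toy = -x_toy
indices_toy = greedy_stein_thinning(x_toy, scores_toy, lengthscale=1.0, m=20)
x_thinned = x_toy[indices_toy]
```

Note that the greedy rule can select the same point more than once, so the returned indices are not guaranteed to be unique. The thinned sample is obtained by indexing the original array with the returned indices.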

Regularized variant

Regularized Stein thinning can be used in a similar fashion but requires two additional inputs:

  • The log-probability values.
  • The regularization terms introduced in [2].

from kernax.utils import laplace_log_p_softplus

log_p_values = jax.vmap(logprob_fn, 0)(x)
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

from kernax import RegularizedSteinThinning

reg_stein_fn = RegularizedSteinThinning(x, log_p_values, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)

References

[1] Riabiz, M., Chen, W. Y., Cockayne, J., Swietach, P., Niederer, S. A., Mackey, L., & Oates, C. J. (2022). Optimal thinning of MCMC output. Journal of the Royal Statistical Society Series B: Statistical Methodology, 84(4), 1059-1081.

[2] Bénard, C., Staber, B., & Da Veiga, S. (2023). Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization. NeurIPS 2023.