
Quick start

Example usage of Stein thinning: draw 1000 points from a 2-D standard Gaussian and select a subset of 100 of them.

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal
from kernax.utils import median_heuristic
from kernax import SteinThinning

# Draw a sample of 1000 points from the 2-D standard normal distribution.
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, shape=(1000, 2))


def logprob_fn(x):
    """Return the log-density of the 2-D standard normal target N(0, I) at `x`."""
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))


# The score is the gradient of the log-density; vmap evaluates it at every
# sample point (mapping over the leading axis of `x`).
score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn)(x)

# Kernel lengthscale: the median heuristic applied to the sample, wrapped in a
# length-1 array (assumed to be the form SteinThinning expects — confirm in kernax docs).
lengthscale = jnp.array([median_heuristic(x)])
# Build the Stein thinning object from the sample, its scores and the lengthscale.
stein_fn = SteinThinning(x, score_values, lengthscale)
# Request 100 points; `indices` presumably holds the positions in `x` of the selected points.
indices = stein_fn(100)

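To use the regularized variant, which additionally needs the log-density and a Laplacian-based term at each sample point, add the following lines (they reuse `x`, `logprob_fn`, `score_fn`, `score_values` and `lengthscale` from above):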

from kernax.utils import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

# Log-density of the target at every sample point. This must use logprob_fn,
# not score_fn: the scores are passed separately as `score_values` below.
log_p = jax.vmap(logprob_fn, 0)(x)
# Per-point Laplacian term from kernax's helper (presumably a softplus-based
# Laplacian of log p — see kernax.utils documentation).
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
# Request 100 points; `indices` presumably holds the positions in `x` of the selected points.
indices = reg_stein_fn(100)