Quick start
The example below applies Stein thinning to 1,000 draws from a standard two-dimensional Gaussian and selects 100 of them:
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from kernax.utils import median_heuristic
from kernax import SteinThinning

# Draw 1,000 points from a standard 2D Gaussian
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

# Log-density of the target; its gradient is the score function
def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

score_fn = jax.grad(logprob_fn)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale from the median heuristic
lengthscale = jnp.array([median_heuristic(x)])

# Build the thinning function and request 100 indices into x
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)
To use the regularized variant, add the following lines to the example above:
from kernax.utils import laplace_log_p_softplus
from kernax import RegularizedSteinThinning

# Log-density values at the sample points
log_p = jax.vmap(logprob_fn, 0)(x)
# Laplacian term of log p, as expected by RegularizedSteinThinning
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)

reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
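For intuition about what the thinning step computes, here is a minimal NumPy sketch of greedy Stein thinning. It is an illustration under stated assumptions, not kernax's implementation: it assumes an inverse multiquadric (IMQ) base kernel, uses a fixed lengthscale of 1.0 in place of the median heuristic, and the names `stein_kernel_matrix` and `greedy_stein_thinning` are hypothetical helpers that do not belong to the library.

```python
import numpy as np


def stein_kernel_matrix(x, scores, lengthscale):
    """Langevin Stein kernel on an IMQ base kernel k(x, y) = (1 + |x - y|^2 / l^2)^(-1/2).

    The IMQ choice is an assumption made for this sketch.
    """
    _, d = x.shape
    l2 = lengthscale ** 2
    r = x[:, None, :] - x[None, :, :]              # pairwise differences, shape (n, n, d)
    sq = np.sum(r ** 2, axis=-1)                   # squared distances
    q = 1.0 + sq / l2
    ds = scores[:, None, :] - scores[None, :, :]   # pairwise score differences
    r_dot_ds = np.sum(r * ds, axis=-1)
    s_dot_s = scores @ scores.T
    return (
        d * q ** -1.5 / l2                 # divergence term, part 1
        - 3.0 * sq * q ** -2.5 / l2 ** 2   # divergence term, part 2
        + r_dot_ds * q ** -1.5 / l2        # cross terms between kernel gradient and scores
        + s_dot_s * q ** -0.5              # base kernel times score inner product
    )


def greedy_stein_thinning(x, scores, lengthscale, m):
    """Greedily pick m indices that keep the kernel Stein discrepancy small."""
    kp = stein_kernel_matrix(x, scores, lengthscale)
    objective = 0.5 * np.diag(kp).copy()
    indices = np.empty(m, dtype=int)
    for t in range(m):
        indices[t] = np.argmin(objective)
        # Add the selected point's kernel row so later picks account for it
        objective = objective + kp[indices[t]]
    return indices


rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 2))
scores = -x          # score of a standard Gaussian: grad log p(x) = -x
lengthscale = 1.0    # fixed value, chosen only for illustration

indices = greedy_stein_thinning(x, scores, lengthscale, 100)
thinned = x[indices]  # the thinned sample
```

The returned array indexes into the original sample, so `x[indices]` is the thinned sample. In this greedy scheme the same index can be selected more than once.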