kernax.bjsamplers
High-level wrappers for BlackJAX samplers.
hmc
hmc(
logprob_fn,
init_positions,
num_samples,
step_size,
inverse_mass_matrix,
num_integration_steps,
rng_key,
)
Wrapper around the HMC algorithm implemented in BlackJAX.
Parameters:
- logprob_fn (LogProbFn) – Callable function that returns the log-probability of the target.
- init_positions (JaxArray) – Initial guess for the HMC algorithm.
- num_samples (int) – Number of iterations (burn-in included).
- step_size (float) – Step size of the leapfrog integrator.
- inverse_mass_matrix (JaxArray) – Flattened inverse mass matrix.
- num_integration_steps (int) – Number of leapfrog steps.
- rng_key (PRNGKeyArray) – A JAX PRNGKey.
Returns:
- tuple[HMCInfo, HMCState] – Tuple containing the sampler states (HMCState) and the associated sampling information (HMCInfo).
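To make the roles of step_size, inverse_mass_matrix and num_integration_steps concrete, here is a minimal NumPy sketch of a standard HMC transition. It is not BlackJAX's implementation and not part of kernax: hmc_sketch and leapfrog are hypothetical names, the gradient is passed explicitly as grad_logprob because NumPy has no automatic differentiation, an integer seed stands in for the JAX rng_key, and the flattened inverse mass matrix is assumed to hold the diagonal of the inverse mass matrix.

```python
import numpy as np


def leapfrog(grad_logprob, q, p, step_size, inv_mass, num_steps):
    """Run `num_steps` leapfrog steps with a diagonal inverse mass matrix."""
    p = p + 0.5 * step_size * grad_logprob(q)  # initial half step for the momentum
    for i in range(num_steps):
        q = q + step_size * inv_mass * p  # full step for the position
        if i < num_steps - 1:
            p = p + step_size * grad_logprob(q)  # full step for the momentum
    p = p + 0.5 * step_size * grad_logprob(q)  # final half step for the momentum
    return q, p


def hmc_sketch(logprob_fn, grad_logprob, init_position, num_samples,
               step_size, inverse_mass_matrix, num_integration_steps, seed=0):
    """Hypothetical NumPy HMC sampler; returns the chain and acceptance flags."""
    rng = np.random.default_rng(seed)  # assumption: integer seed instead of a PRNGKey
    inv_mass = np.asarray(inverse_mass_matrix, dtype=float)  # assumed diagonal
    q = np.asarray(init_position, dtype=float)
    samples = np.empty((num_samples, q.size))
    accepted = np.zeros(num_samples, dtype=bool)
    for t in range(num_samples):
        # Momentum p ~ N(0, M) with M = diag(1 / inv_mass).
        p0 = rng.normal(size=q.shape) / np.sqrt(inv_mass)
        q_new, p_new = leapfrog(grad_logprob, q, p0, step_size, inv_mass,
                                num_integration_steps)
        # Hamiltonian H(q, p) = -log p(q) + 0.5 * p^T M^{-1} p.
        h0 = -logprob_fn(q) + 0.5 * np.sum(inv_mass * p0 ** 2)
        h1 = -logprob_fn(q_new) + 0.5 * np.sum(inv_mass * p_new ** 2)
        # Metropolis test on the change in total energy.
        if np.log(rng.uniform()) < h0 - h1:
            q = q_new
            accepted[t] = True
        samples[t] = q
    return samples, accepted
```

For a standard Gaussian target one would pass `lambda x: -0.5 * np.sum(x ** 2)` as logprob_fn and `lambda x: -x` as its gradient; the first iterations can then be discarded as burn-in, matching the "burn-in included" convention of num_samples.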
nuts
nuts(
logprob_fn,
init_positions,
num_samples,
step_size,
inverse_mass_matrix,
rng_key,
)
Wrapper around the NUTS algorithm implemented in BlackJAX.
Parameters:
- logprob_fn (LogProbFn) – Callable function that returns the log-probability of the target.
- init_positions (JaxArray) – Initial guess for the NUTS algorithm.
- num_samples (int) – Number of iterations (burn-in included).
- step_size (float) – Step size of the leapfrog integrator.
- inverse_mass_matrix (JaxArray) – Flattened inverse mass matrix.
- rng_key (PRNGKeyArray) – A JAX PRNGKey.
Returns:
- tuple[HMCInfo, HMCState] – Tuple containing the sampler states (HMCState) and the associated sampling information (HMCInfo).
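Unlike hmc, nuts takes no num_integration_steps argument: the trajectory length is chosen automatically by repeatedly doubling the trajectory until it starts to turn back on itself. The NumPy sketch below illustrates this with the slice-sampling formulation of NUTS (Algorithm 3 in Hoffman and Gelman's NUTS paper) at a fixed step size. BlackJAX's NUTS may use a different variant of the algorithm, so treat this as an illustration only. nuts_sketch, grad_logprob, seed, max_depth and DELTA_MAX are hypothetical, and the inverse mass matrix is again assumed to be diagonal.

```python
import numpy as np

DELTA_MAX = 1000.0  # divergence threshold on the joint log-density (value from the paper)


def nuts_sketch(logprob_fn, grad_logprob, init_position, num_samples,
                step_size, inverse_mass_matrix, seed=0, max_depth=10):
    """Hypothetical NumPy NUTS (slice variant, fixed step size)."""
    rng = np.random.default_rng(seed)
    inv_mass = np.asarray(inverse_mass_matrix, dtype=float)  # assumed diagonal

    def joint(q, p):  # log p(q) minus the kinetic energy
        return logprob_fn(q) - 0.5 * np.sum(inv_mass * p ** 2)

    def leapfrog(q, p, eps):  # one leapfrog step; eps may be negative
        p = p + 0.5 * eps * grad_logprob(q)
        q = q + eps * inv_mass * p
        p = p + 0.5 * eps * grad_logprob(q)
        return q, p

    def no_u_turn(q_minus, q_plus, p_minus, p_plus):
        dq = q_plus - q_minus  # both ends must still move apart
        return np.dot(dq, inv_mass * p_minus) >= 0 and np.dot(dq, inv_mass * p_plus) >= 0

    def build_tree(q, p, log_u, v, j):
        if j == 0:  # base case: a single leapfrog step in direction v
            q1, p1 = leapfrog(q, p, v * step_size)
            h = joint(q1, p1)
            n1 = int(log_u <= h)  # is the new point inside the slice?
            s1 = log_u < h + DELTA_MAX  # stop on divergence
            return q1, p1, q1, p1, q1, n1, s1
        # Build the first half-subtree, then the second one from its far edge.
        qm, pm, qp, pp, q1, n1, s1 = build_tree(q, p, log_u, v, j - 1)
        if s1:
            if v == -1:
                qm, pm, _, _, q2, n2, s2 = build_tree(qm, pm, log_u, v, j - 1)
            else:
                _, _, qp, pp, q2, n2, s2 = build_tree(qp, pp, log_u, v, j - 1)
            if n1 + n2 > 0 and rng.uniform() < n2 / (n1 + n2):
                q1 = q2  # pick a candidate uniformly among valid points
            s1 = s2 and no_u_turn(qm, qp, pm, pp)
            n1 = n1 + n2
        return qm, pm, qp, pp, q1, n1, s1

    q = np.asarray(init_position, dtype=float)
    samples = np.empty((num_samples, q.size))
    for t in range(num_samples):
        p0 = rng.normal(size=q.shape) / np.sqrt(inv_mass)
        log_u = joint(q, p0) - rng.exponential()  # slice variable, in log space
        qm, qp, pm, pp = q, q, p0, p0
        n, s, j = 1, True, 0
        while s and j < max_depth:  # double the trajectory until it U-turns
            v = rng.choice([-1, 1])
            if v == -1:
                qm, pm, _, _, q1, n1, s1 = build_tree(qm, pm, log_u, v, j)
            else:
                _, _, qp, pp, q1, n1, s1 = build_tree(qp, pp, log_u, v, j)
            if s1 and rng.uniform() < n1 / n:
                q = q1
            n += n1
            s = s1 and no_u_turn(qm, qp, pm, pp)
            j += 1
        samples[t] = q
    return samples
```

The step size still has to be chosen (here it is fixed), but the number of leapfrog steps now varies from iteration to iteration; max_depth caps the trajectory at 2**max_depth steps so that a poorly scaled target cannot make a single iteration run indefinitely.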
mala
mala(
logprob_fn,
init_positions,
num_samples,
step_size,
rng_key,
)
Wrapper around the MALA algorithm implemented in BlackJAX.
Parameters:
- logprob_fn (LogProbFn) – Callable function that returns the log-probability of the target.
- init_positions (JaxArray) – Initial guess for the MALA algorithm.
- num_samples (int) – Number of iterations (burn-in included).
- step_size (float) – Step size of the MALA algorithm.
- rng_key (PRNGKeyArray) – A JAX PRNGKey.
Returns:
- tuple[MALAInfo, MALAState] – Tuple containing the sampler states (MALAState) and the associated sampling information (MALAInfo).
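MALA makes a single Langevin step per iteration and corrects it with a Metropolis-Hastings test, which is why it needs only a step size. The NumPy sketch below uses the proposal `x' = x + step_size * grad_logprob(x) + sqrt(2 * step_size) * xi` with `xi` standard normal; whether BlackJAX scales step_size in exactly this way is an assumption to check against its documentation. mala_sketch, grad_logprob and seed are hypothetical names, and this is not BlackJAX's code.

```python
import numpy as np


def mala_sketch(logprob_fn, grad_logprob, init_position, num_samples,
                step_size, seed=0):
    """Hypothetical NumPy MALA sampler; returns the chain and acceptance flags."""
    rng = np.random.default_rng(seed)  # assumption: integer seed instead of a PRNGKey
    q = np.asarray(init_position, dtype=float)

    def log_proposal(x_to, x_from):
        # Log-density (up to a constant) of N(x_from + step_size * grad, 2 * step_size * I).
        mean = x_from + step_size * grad_logprob(x_from)
        return -np.sum((x_to - mean) ** 2) / (4.0 * step_size)

    samples = np.empty((num_samples, q.size))
    accepted = np.zeros(num_samples, dtype=bool)
    for t in range(num_samples):
        # Langevin proposal: gradient drift plus Gaussian noise.
        noise = rng.normal(size=q.shape)
        q_new = q + step_size * grad_logprob(q) + np.sqrt(2.0 * step_size) * noise
        # The proposal is not symmetric, so both directions enter the MH ratio.
        log_alpha = (logprob_fn(q_new) - logprob_fn(q)
                     + log_proposal(q, q_new) - log_proposal(q_new, q))
        if np.log(rng.uniform()) < log_alpha:
            q = q_new
            accepted[t] = True
        samples[t] = q
    return samples, accepted
```

Without the Metropolis-Hastings correction this would be the unadjusted Langevin algorithm, whose stationary distribution is biased by the discretisation; the acceptance step removes that bias at the cost of occasionally rejecting a move.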