
kernax.bjsamplers

High-level wrappers for BlackJAX samplers.

hmc

hmc(
    logprob_fn,
    init_positions,
    num_samples,
    step_size,
    inverse_mass_matrix,
    num_integration_steps,
    rng_key,
)

Wrapper around the Hamiltonian Monte Carlo (HMC) sampler implemented in BlackJAX.

Parameters:

  • logprob_fn (LogProbFn) –

    Callable that returns the log-probability of the target distribution.

  • init_positions (JaxArray) –

    Initial guess for the HMC algorithm.

  • num_samples (int) –

    Number of iterations (burn-in included).

  • step_size (float) –

    Step size of the leapfrog integrator.

  • inverse_mass_matrix (JaxArray) –

    Flattened inverse mass matrix.

  • num_integration_steps (int) –

    Number of leapfrog steps.

  • rng_key (PRNGKeyArray) –

    A JAX PRNGKey.

Returns:

  • tuple[HMCInfo, HMCState]

    A tuple containing the chain states and the associated per-iteration information.
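To make the roles of `step_size`, `inverse_mass_matrix` and `num_integration_steps` concrete, here is a minimal plain-Python sketch of the HMC transition that the wrapper delegates to. It is not the kernax or BlackJAX API: `hmc_sketch` and `grad_logprob_fn` are hypothetical names (in the JAX setting the gradient comes from automatic differentiation), and `inverse_mass_matrix` is assumed to hold the diagonal of M^{-1}, which is how BlackJAX interprets a one-dimensional array. Whether the wrapper's "flattened" matrix means exactly this is an assumption.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def hmc_sketch(
    logprob_fn: Callable[[Sequence[float]], float],
    grad_logprob_fn: Callable[[Sequence[float]], Vector],
    init_position: Sequence[float],
    num_samples: int,
    step_size: float,
    inverse_mass_matrix: Sequence[float],
    num_integration_steps: int,
    seed: int = 0,
) -> Tuple[List[Vector], float]:
    """Hypothetical plain-Python HMC (not the kernax/BlackJAX API).

    Assumes `inverse_mass_matrix` is the diagonal of M^{-1}; `grad_logprob_fn`
    stands in for the gradient JAX would obtain by autodiff.
    Returns the visited positions and the acceptance rate.
    """
    rng = random.Random(seed)
    q = list(init_position)
    samples: List[Vector] = []
    accepted = 0

    def kinetic(p: Vector) -> float:
        # K(p) = 0.5 * p^T M^{-1} p for a diagonal M^{-1}.
        return 0.5 * sum(m * pi * pi for m, pi in zip(inverse_mass_matrix, p))

    for _ in range(num_samples):
        # Fresh momentum p ~ N(0, M): per-coordinate std is 1 / sqrt(M^{-1}_ii).
        p = [rng.gauss(0.0, 1.0 / math.sqrt(m)) for m in inverse_mass_matrix]
        q_new, p_new = list(q), list(p)
        grad = grad_logprob_fn(q_new)
        # Leapfrog: half momentum step, full position step, half momentum step.
        for _ in range(num_integration_steps):
            p_new = [pi + 0.5 * step_size * g for pi, g in zip(p_new, grad)]
            q_new = [qi + step_size * m * pi
                     for qi, m, pi in zip(q_new, inverse_mass_matrix, p_new)]
            grad = grad_logprob_fn(q_new)
            p_new = [pi + 0.5 * step_size * g for pi, g in zip(p_new, grad)]
        # Metropolis correction on the total energy H(q, p) = -log pi(q) + K(p).
        h_old = -logprob_fn(q) + kinetic(p)
        h_new = -logprob_fn(q_new) + kinetic(p_new)
        if math.isfinite(h_new) and rng.random() < math.exp(min(0.0, h_old - h_new)):
            q = q_new
            accepted += 1
        samples.append(list(q))
    return samples, accepted / num_samples


# Usage on a standard 2-D Gaussian (log-density up to an additive constant).
def gaussian_logprob(x: Sequence[float]) -> float:
    return -0.5 * sum(xi * xi for xi in x)


def gaussian_grad(x: Sequence[float]) -> Vector:
    return [-xi for xi in x]


samples, acceptance = hmc_sketch(
    gaussian_logprob, gaussian_grad, init_position=[3.0, -3.0], num_samples=2000,
    step_size=0.2, inverse_mass_matrix=[1.0, 1.0], num_integration_steps=10,
)
```

Because `num_samples` counts burn-in iterations too, the early part of `samples` (here started far from the mode at `[3.0, -3.0]`) is usually discarded before computing summaries.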

nuts

nuts(
    logprob_fn,
    init_positions,
    num_samples,
    step_size,
    inverse_mass_matrix,
    rng_key,
)

Wrapper around the No-U-Turn Sampler (NUTS) implemented in BlackJAX.

Parameters:

  • logprob_fn (LogProbFn) –

    Callable that returns the log-probability of the target distribution.

  • init_positions (JaxArray) –

    Initial guess for the NUTS algorithm.

  • num_samples (int) –

    Number of iterations (burn-in included).

  • step_size (float) –

    Step size of the leapfrog integrator.

  • inverse_mass_matrix (JaxArray) –

    Flattened inverse mass matrix.

  • rng_key (PRNGKeyArray) –

    A JAX PRNGKey.

Returns:

  • tuple[HMCInfo, HMCState]

    A tuple containing the chain states and the associated per-iteration information.
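Unlike `hmc`, the NUTS wrapper takes no `num_integration_steps`, because NUTS picks the trajectory length on the fly and stops once the trajectory starts to double back on itself. The plain-Python sketch below shows only that termination test (the U-turn condition from the original NUTS formulation, with velocities taken as M^{-1} p for an assumed diagonal M^{-1}) applied to a forward-only leapfrog trajectory. `is_u_turn` and `steps_until_u_turn` are hypothetical helpers, not kernax or BlackJAX functions.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def is_u_turn(
    q_minus: Sequence[float],
    q_plus: Sequence[float],
    p_minus: Sequence[float],
    p_plus: Sequence[float],
    inverse_mass_matrix: Sequence[float],
) -> bool:
    """True once either end of the trajectory moves back toward the other end.

    Velocities are M^{-1} p with a diagonal M^{-1} (an assumption of this sketch).
    """
    delta = [b - a for a, b in zip(q_minus, q_plus)]
    v_minus = [m * p for m, p in zip(inverse_mass_matrix, p_minus)]
    v_plus = [m * p for m, p in zip(inverse_mass_matrix, p_plus)]
    dot_minus = sum(d * v for d, v in zip(delta, v_minus))
    dot_plus = sum(d * v for d, v in zip(delta, v_plus))
    return dot_minus < 0.0 or dot_plus < 0.0


def steps_until_u_turn(
    grad_logprob_fn: Callable[[Sequence[float]], Vector],
    q0: Sequence[float],
    p0: Sequence[float],
    step_size: float,
    inverse_mass_matrix: Sequence[float],
    max_steps: int = 1000,
) -> int:
    """Leapfrog forward from (q0, p0); count steps until the U-turn test fires."""
    q, p = list(q0), list(p0)
    grad = grad_logprob_fn(q)
    for step in range(1, max_steps + 1):
        p = [pi + 0.5 * step_size * g for pi, g in zip(p, grad)]
        q = [qi + step_size * m * pi for qi, m, pi in zip(q, inverse_mass_matrix, p)]
        grad = grad_logprob_fn(q)
        p = [pi + 0.5 * step_size * g for pi, g in zip(p, grad)]
        if is_u_turn(q0, q, p0, p, inverse_mass_matrix):
            return step
    return max_steps


# Standard 1-D Gaussian: grad log pi(x) = -x. A larger initial momentum
# carries the trajectory further before it turns back.
n_short = steps_until_u_turn(lambda x: [-xi for xi in x], [1.0], [0.5], 0.1, [1.0])
n_long = steps_until_u_turn(lambda x: [-xi for xi in x], [1.0], [3.0], 0.1, [1.0])
```

The actual sampler grows the trajectory in both directions by repeated doubling and then draws the next state from the points it visited; this sketch omits both parts and only illustrates why the number of leapfrog steps need not be fixed in advance.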

mala

mala(
    logprob_fn,
    init_positions,
    num_samples,
    step_size,
    rng_key,
)

Wrapper around the Metropolis-adjusted Langevin algorithm (MALA) implemented in BlackJAX.

Parameters:

  • logprob_fn (LogProbFn) –

    Callable that returns the log-probability of the target distribution.

  • init_positions (JaxArray) –

    Initial guess for the MALA algorithm.

  • num_samples (int) –

    Number of iterations (burn-in included).

  • step_size (float) –

    Step size of the MALA algorithm.

  • rng_key (PRNGKeyArray) –

    A JAX PRNGKey.

Returns:

  • tuple[MALAInfo, MALAState]

    A tuple containing the chain states and the associated per-iteration information.
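For intuition about `step_size` in `mala`, the plain-Python sketch below implements one common MALA parameterization: propose x' = x + ε ∇log π(x) + √(2ε) ξ with ξ ~ N(0, I), then apply a Metropolis-Hastings correction for the asymmetric Gaussian proposal. That this matches the exact convention inherited from BlackJAX is an assumption worth checking against its source. `mala_sketch` and `grad_logprob_fn` are hypothetical names, not part of kernax.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def mala_sketch(
    logprob_fn: Callable[[Sequence[float]], float],
    grad_logprob_fn: Callable[[Sequence[float]], Vector],
    init_position: Sequence[float],
    num_samples: int,
    step_size: float,
    seed: int = 0,
) -> Tuple[List[Vector], float]:
    """Hypothetical plain-Python MALA; returns the visited positions and acceptance rate."""
    rng = random.Random(seed)
    x = list(init_position)
    samples: List[Vector] = []
    accepted = 0

    def log_proposal(to: Sequence[float], frm: Sequence[float],
                     grad_frm: Sequence[float]) -> float:
        # log N(to | frm + step_size * grad(frm), 2 * step_size * I), up to a constant.
        sq = sum((t - f - step_size * g) ** 2 for t, f, g in zip(to, frm, grad_frm))
        return -sq / (4.0 * step_size)

    for _ in range(num_samples):
        grad_x = grad_logprob_fn(x)
        # Langevin proposal: gradient drift plus scaled Gaussian noise.
        prop = [xi + step_size * g + math.sqrt(2.0 * step_size) * rng.gauss(0.0, 1.0)
                for xi, g in zip(x, grad_x)]
        grad_prop = grad_logprob_fn(prop)
        # Metropolis-Hastings correction for the asymmetric proposal.
        log_alpha = (logprob_fn(prop) - logprob_fn(x)
                     + log_proposal(x, prop, grad_prop) - log_proposal(prop, x, grad_x))
        if rng.random() < math.exp(min(0.0, log_alpha)):
            x = prop
            accepted += 1
        samples.append(list(x))
    return samples, accepted / num_samples


samples, acceptance = mala_sketch(
    lambda x: -0.5 * sum(xi * xi for xi in x),  # standard 1-D Gaussian, up to a constant
    lambda x: [-xi for xi in x],
    init_position=[2.0], num_samples=5000, step_size=0.2,
)
```

Larger step sizes move the chain further per iteration but lower the acceptance rate, so `step_size` is usually tuned against the observed acceptance.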