
kernax.utils

Module containing utility functions for Kernax.

median_heuristic

median_heuristic(x)

Function that computes the median heuristic for the lengthscale parameter, i.e. a lengthscale derived from the median of the pairwise distances between samples.

Parameters:

  • x (NDArray[int32 | int64]) –

    Sample matrix of size (n, d)

Returns:

  • Median heuristic value.
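For reference, here is a minimal sketch of one common form of the median heuristic: the median of the Euclidean distances between all distinct pairs of rows of x. This is an assumption about what `median_heuristic` computes. Other conventions are also in use (for example the median of squared distances, or a rescaling by a constant or by log n). `median_heuristic_sketch` is a hypothetical name, not part of Kernax.

```python
import jax.numpy as jnp


def median_heuristic_sketch(x):
    """Hypothetical sketch: median pairwise Euclidean distance of rows of x (n, d)."""
    # Cast so integer sample matrices are handled too.
    x = jnp.asarray(x, dtype=jnp.float32)
    # Pairwise differences of shape (n, n, d) via broadcasting.
    diffs = x[:, None, :] - x[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs ** 2, axis=-1))
    # Keep the strict upper triangle only: each pair counts once and the
    # zero self-distances on the diagonal are excluded.
    i, j = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(dists[i, j])


samples = jnp.array([[0, 0], [3, 4], [0, 8]])  # pairwise distances 5, 8, 5
lengthscale = median_heuristic_sketch(samples)  # median is 5.0
```

The broadcasting approach materialises an (n, n, d) array, which is simple but memory-hungry for large n. A subsample of rows is often used in that case.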

laplace_log_p_hardplus

laplace_log_p_hardplus(x, logprob_fn)

Function that computes the Laplacian of a log-probability function for the provided sample matrix, clipped to be non-negative.

Parameters:

  • x (ArrayLike) –

    Sample matrix of size (n, d)

  • logprob_fn (LogProbFn) –

    Callable log-probability function

Returns:

  • JaxArray

    Values of the Laplacian of the log-probability, clipped to be non-negative.
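A minimal sketch of the clipped Laplacian follows. It assumes `logprob_fn` maps a single point of shape (d,) to a scalar log-density. The real `LogProbFn` signature may differ, for example it may be batched. The Laplacian of log p at a point is the trace of its Hessian, the sum over k of ∂² log p / ∂x_k². Here it is evaluated row by row with `jax.vmap` and clipped with max(·, 0). `laplace_log_p_hardplus_sketch` is a hypothetical name, not Kernax's implementation.

```python
import jax
import jax.numpy as jnp


def laplace_log_p_hardplus_sketch(x, logprob_fn):
    """Hypothetical sketch: max(Laplacian of log p, 0) at each row of x (n, d).

    Assumes logprob_fn maps a single point of shape (d,) to a scalar.
    """
    x = jnp.asarray(x, dtype=jnp.float32)

    def laplacian(xi):
        # Laplacian = trace of the (d, d) Hessian of log p at xi.
        return jnp.trace(jax.hessian(logprob_fn)(xi))

    # Evaluate per sample (shape (n,)), then clip negative values to zero.
    return jnp.maximum(jax.vmap(laplacian)(x), 0.0)


# Standard Gaussian: log p(v) = -|v|^2 / 2 has Laplacian -d everywhere.
def gaussian_log_p(v):
    return -0.5 * jnp.sum(v ** 2)


samples = jnp.array([[0.0, 0.0], [1.0, -1.0], [2.0, 3.0]])
print(laplace_log_p_hardplus_sketch(samples, gaussian_log_p))  # all zeros after clipping
```

Building the full Hessian costs O(d²) per sample. This keeps the sketch short but becomes expensive for large d, since only the Hessian diagonal is actually needed.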

laplace_log_p_softplus

laplace_log_p_softplus(x, logprob_fn)

Function that computes the sum of positive second-order derivatives of the log-probability for the provided sample matrix.

Parameters:

  • x (ArrayLike) –

    Sample matrix of size (n, d)

  • logprob_fn (LogProbFn) –

    Callable log-probability function

Returns:

  • JaxArray

    Values of the positive-part Laplacian of the log-probability, i.e. the sum of its positive second-order derivatives.
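A minimal sketch that follows the description above literally. At each sample, it takes the diagonal of the Hessian of log p (the unmixed second derivatives ∂² log p / ∂x_k²), keeps only the positive entries and sums them. It rests on two assumptions. First, `logprob_fn` maps a single (d,) point to a scalar. Second, "positive second-order derivatives" refers to these diagonal entries. The "softplus" in the function's name suggests Kernax may use a smooth softplus in place of the hard max(·, 0) used here, and this sketch does not reproduce that. `laplace_log_p_positive_part_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def laplace_log_p_positive_part_sketch(x, logprob_fn):
    """Hypothetical sketch: sum_k max(d^2 log p / dx_k^2, 0) at each row of x (n, d).

    Assumes logprob_fn maps a single point of shape (d,) to a scalar.
    """
    x = jnp.asarray(x, dtype=jnp.float32)

    def positive_second_derivatives(xi):
        # The Hessian diagonal holds the unmixed second derivatives.
        second = jnp.diag(jax.hessian(logprob_fn)(xi))
        # Keep only the positive ones before summing.
        return jnp.sum(jnp.maximum(second, 0.0))

    return jax.vmap(positive_second_derivatives)(x)


# log p(v) = v0^2 / 2 - 3 v1^2 / 2: second derivatives are +1 and -3.
def saddle_log_p(v):
    return 0.5 * v[0] ** 2 - 1.5 * v[1] ** 2


samples = jnp.array([[0.0, 0.0], [1.0, 2.0]])
print(laplace_log_p_positive_part_sketch(samples, saddle_log_p))  # 1.0 per row; the Laplacian itself is -2
```

Unlike the clipped Laplacian, this quantity can be positive even where the full Laplacian is negative. In the saddle example the clipped Laplacian would be 0 at every sample, while the positive-part sum is 1.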