kernax.utils
Module containing utility functions for Kernax.
median_heuristic
median_heuristic(x)
Function that computes the median heuristic for the lengthscale parameter.
Parameters:
- x (NDArray[int32 | int64]) – Sample matrix of size (n, d).
Returns:
- Median heuristic value.
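The median heuristic has several conventions in the literature, and this page does not say which one Kernax uses. The sketch below assumes one common choice: the median of the Euclidean distances over all distinct pairs of rows of `x`. The function name `median_heuristic_sketch` is hypothetical and is not the Kernax implementation.

```python
import jax.numpy as jnp


def median_heuristic_sketch(x):
    """Hypothetical sketch: median of pairwise Euclidean distances between rows of x."""
    # Cast to float so integer sample matrices (as in the documented type) work too.
    x = jnp.asarray(x, dtype=jnp.float32)
    n = x.shape[0]
    # Pairwise differences of shape (n, n, d), then distances of shape (n, n).
    diffs = x[:, None, :] - x[None, :, :]
    dists = jnp.sqrt(jnp.sum(diffs**2, axis=-1))
    # Keep only distinct pairs (strict upper triangle) to exclude zero self-distances.
    i, j = jnp.triu_indices(n, k=1)
    return jnp.median(dists[i, j])
```

For the one-dimensional sample `[[0], [1], [3]]`, the pairwise distances are 1, 3 and 2, so this convention returns 2.0. Other conventions (for example, rescaling the median of squared distances by `log(n)`) give different values.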
laplace_log_p_hardplus
laplace_log_p_hardplus(x, logprob_fn)
Function that computes the clipped Laplacian of a log-probability for the provided sample matrix.
Parameters:
- x (ArrayLike) – Sample matrix of size (n, d).
- logprob_fn (LogProbFn) – Callable log-probability function.
Returns:
- JaxArray – Values of the Laplacian of the log-probability, clipped to be non-negative.
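A minimal sketch of the computation described above, assuming that `logprob_fn` maps a single d-dimensional point to a scalar and that "clipped" means taking `max(Laplacian, 0)` of the full Laplacian at each sample. The Laplacian is computed as the trace of the Hessian and vectorized over rows with `jax.vmap`. The name `laplace_log_p_hardplus_sketch` is hypothetical; the actual Kernax code may differ (for instance, by clipping each second derivative before summing).

```python
import jax
import jax.numpy as jnp


def laplace_log_p_hardplus_sketch(x, logprob_fn):
    """Hypothetical sketch: Laplacian of logprob_fn at each row of x, clipped at zero."""
    x = jnp.asarray(x, dtype=jnp.float32)

    def laplacian(xi):
        # Laplacian = trace of the (d, d) Hessian of the scalar log-probability.
        return jnp.trace(jax.hessian(logprob_fn)(xi))

    # Evaluate the Laplacian for each of the n samples, giving shape (n,).
    lap = jax.vmap(laplacian)(x)
    # Hard clipping to non-negative values.
    return jnp.maximum(lap, 0.0)
```

For a standard Gaussian, `log p(x) = -0.5 * ||x||^2` has Laplacian `-d` everywhere, so the clipped value is 0 for every sample.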
laplace_log_p_softplus
laplace_log_p_softplus(x, logprob_fn)
Function that computes the sum of positive second-order derivatives of the log-probability for the provided sample matrix.
Parameters:
- x (ArrayLike) – Sample matrix of size (n, d).
- logprob_fn (LogProbFn) – Callable log-probability function.
Returns:
- JaxArray – Values of the Laplacian of the log-probability.
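A sketch of the "sum of positive second-order derivatives", assuming (from the function name only) that the positive part of each diagonal second derivative is smoothed with `jax.nn.softplus` before summing. As above, `logprob_fn` is assumed to map a single d-dimensional point to a scalar, and `laplace_log_p_softplus_sketch` is a hypothetical name, not the Kernax implementation.

```python
import jax
import jax.numpy as jnp


def laplace_log_p_softplus_sketch(x, logprob_fn):
    """Hypothetical sketch: sum_j softplus(d^2 log p / dx_j^2) at each row of x."""
    x = jnp.asarray(x, dtype=jnp.float32)

    def smooth_positive_laplacian(xi):
        # Diagonal second derivatives of the scalar log-probability, shape (d,).
        second_derivs = jnp.diagonal(jax.hessian(logprob_fn)(xi))
        # softplus(t) = log(1 + exp(t)) is a smooth stand-in for max(t, 0).
        return jnp.sum(jax.nn.softplus(second_derivs))

    # One value per sample, shape (n,).
    return jax.vmap(smooth_positive_laplacian)(x)
```

Unlike hard clipping, softplus is differentiable everywhere and strictly positive, so a standard Gaussian in d = 3 dimensions yields `3 * log(1 + e^-1)` (about 0.94) rather than 0.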