
kernax.kernels

Module containing kernel functions for use in Kernax.

IMQ

IMQ(x, y, lengthscale)

Inverse multiquadric (IMQ) kernel function.

Parameters:

  • x (JaxArray) –

    Vector of dimension d

  • y (JaxArray) –

    Vector of dimension d

  • lengthscale (float) –

    Scalar lengthscale / bandwidth

Returns:

  • Scalar

    Value of the kernel function k(x,y)
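The docstring does not state the formula. A common parameterisation of the IMQ kernel is k(x, y) = (1 + ‖x − y‖² / ℓ²)^(−1/2), where ℓ is the lengthscale. The sketch below assumes that form; Kernax's exact constants and exponent may differ, and `imq_sketch` is a hypothetical name, not part of the Kernax API.

```python
import jax.numpy as jnp


def imq_sketch(x, y, lengthscale):
    """Hypothetical IMQ kernel, assuming k(x, y) = (1 + ||x - y||^2 / l^2)^(-1/2)."""
    sq_dist = jnp.sum((x - y) ** 2)  # squared Euclidean distance ||x - y||^2
    return (1.0 + sq_dist / lengthscale**2) ** -0.5


x = jnp.array([0.0, 0.0])
y = jnp.array([3.0, 4.0])
print(imq_sketch(x, y, 5.0))  # ||x - y||^2 = 25, so (1 + 1)^(-1/2), about 0.7071
```

Under this form the kernel equals 1 at x = y and decays polynomially, rather than exponentially, as the distance grows.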

Gaussian

Gaussian(x, y, lengthscale)

Gaussian kernel function.

Parameters:

  • x (JaxArray) –

    Vector of dimension d

  • y (JaxArray) –

    Vector of dimension d

  • lengthscale (float) –

    Scalar lengthscale / bandwidth

Returns:

  • Scalar

    Value of the kernel function k(x,y)
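As a sketch, assuming the common form k(x, y) = exp(−‖x − y‖² / (2ℓ²)) (Kernax may scale the lengthscale differently): because each kernel acts on a single pair of d-dimensional vectors, a Gram matrix over two point sets can be built with nested `jax.vmap`. Both `gaussian_sketch` and `gram_matrix` are hypothetical helpers for illustration.

```python
import jax
import jax.numpy as jnp


def gaussian_sketch(x, y, lengthscale):
    """Hypothetical Gaussian kernel, assuming k(x, y) = exp(-||x - y||^2 / (2 l^2))."""
    sq_dist = jnp.sum((x - y) ** 2)
    return jnp.exp(-sq_dist / (2.0 * lengthscale**2))


def gram_matrix(kernel, xs, ys, lengthscale):
    """Evaluate kernel(x_i, y_j, lengthscale) for all pairs using nested vmap."""
    # Inner vmap maps over the rows of ys for a fixed x; outer vmap maps over xs.
    row = lambda x: jax.vmap(lambda y: kernel(x, y, lengthscale))(ys)
    return jax.vmap(row)(xs)


xs = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gram_matrix(gaussian_sketch, xs, xs, 1.0)
print(K.shape)  # (3, 3)
```

The same `gram_matrix` helper works for any kernel with the (x, y, lengthscale) signature documented here, such as IMQ.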

Energy

Energy(x, y)

Distance-induced kernel function.

Parameters:

  • x (JaxArray) –

    Vector of dimension d

  • y (JaxArray) –

    Vector of dimension d

Returns:

  • Scalar

    Value of the kernel function k(x,y)
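The docstring does not give the formula. One standard distance-induced kernel, built from the Euclidean distance with the origin as reference point, is k(x, y) = ½(‖x‖ + ‖y‖ − ‖x − y‖); another common choice is simply k(x, y) = −‖x − y‖. The sketch below assumes the first form and is hypothetical (`energy_sketch` is not part of Kernax). Matching the documented signature Energy(x, y), it takes no lengthscale.

```python
import jax.numpy as jnp


def energy_sketch(x, y):
    """Hypothetical distance-induced kernel, assuming
    k(x, y) = 0.5 * (||x|| + ||y|| - ||x - y||)
    (Euclidean distance with the origin as the reference point)."""
    return 0.5 * (jnp.linalg.norm(x) + jnp.linalg.norm(y) - jnp.linalg.norm(x - y))


x = jnp.array([3.0, 4.0])
y = jnp.array([0.0, 0.0])
print(energy_sketch(x, y))  # 0.5 * (5 + 0 - 5) = 0.0
print(energy_sketch(x, x))  # 0.5 * (5 + 5 - 0) = 5.0
```

Note that the Euclidean norm is not differentiable at zero, so gradients of this sketch at x = y or at the origin may be NaN under JAX autodiff.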

SteinIMQ

SteinIMQ(x, sx, y, sy, lengthscale)

Langevin Stein kernel with the IMQ kernel as the underlying kernel.

Parameters:

  • x (JaxArray) –

    Vector of dimension d

  • sx (JaxArray) –

    Score function evaluated at x, vector of dimension d

  • y (JaxArray) –

    Vector of dimension d

  • sy (JaxArray) –

    Score function evaluated at y, vector of dimension d

  • lengthscale (float) –

    Scalar lengthscale / bandwidth

Returns:

  • Scalar

    Value of the Stein IMQ kernel k_p(x,y)
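For reference, the Langevin Stein kernel built from a base kernel k and the score s = ∇ log p is k_p(x, y) = ∇ₓ·∇ᵧ k(x, y) + ∇ₓk(x, y)·s_y + ∇ᵧk(x, y)·s_x + k(x, y) s_x·s_y. Assuming the IMQ form k(x, y) = u^(−1/2) with r = x − y and u = 1 + ‖r‖²/ℓ² (an assumption; Kernax's base kernel may be parameterised differently), differentiating gives the closed form

    k_p(x, y) = u^(−1/2) s_x·s_y + u^(−3/2) [(s_x − s_y)·r + d] / ℓ² − 3 u^(−5/2) ‖r‖² / ℓ⁴.

The hypothetical `stein_imq_sketch` below implements that expression term by term.

```python
import jax.numpy as jnp


def stein_imq_sketch(x, sx, y, sy, lengthscale):
    """Hypothetical closed-form Langevin Stein kernel for the assumed IMQ base
    kernel k(x, y) = (1 + ||x - y||^2 / l^2)^(-1/2)."""
    d = x.shape[0]
    r = x - y
    l2 = lengthscale**2
    sq_dist = jnp.sum(r**2)
    u = 1.0 + sq_dist / l2
    # Trace of the mixed Hessian: sum_i d^2 k / (dx_i dy_i)
    trace_term = d / l2 * u**-1.5 - 3.0 * sq_dist / l2**2 * u**-2.5
    # grad_x k . sy + grad_y k . sx, which simplifies to (sx - sy) . r / l^2 * u^(-3/2)
    cross_term = jnp.dot(sx - sy, r) / l2 * u**-1.5
    # k(x, y) * sx . sy
    score_term = jnp.dot(sx, sy) * u**-0.5
    return trace_term + cross_term + score_term


# Example with a standard normal target, whose score is s(x) = -x
x = jnp.array([0.5, -1.0])
y = jnp.array([1.0, 0.0])
print(stein_imq_sketch(x, -x, y, -y, 1.0))
```

At x = y the expression reduces to s_x·s_y + d/ℓ², which is a convenient sanity check.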

SteinGaussian

SteinGaussian(x, sx, y, sy, lengthscale)

Langevin Stein kernel with the Gaussian kernel as the underlying kernel.

Parameters:

  • x (JaxArray) –

    Vector of dimension d

  • sx (JaxArray) –

    Score function evaluated at x, vector of dimension d

  • y (JaxArray) –

    Vector of dimension d

  • sy (JaxArray) –

    Score function evaluated at y, vector of dimension d

  • lengthscale (float) –

    Scalar lengthscale / bandwidth

Returns:

  • Scalar

    Value of the Stein Gaussian kernel k_p(x,y)
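Assuming the Gaussian base kernel k(x, y) = exp(−‖r‖²/(2ℓ²)) with r = x − y (Kernax's lengthscale convention may differ), the Langevin Stein kernel k_p(x, y) = ∇ₓ·∇ᵧ k + ∇ₓk·s_y + ∇ᵧk·s_x + k s_x·s_y has the closed form

    k_p(x, y) = k(x, y) [ s_x·s_y + (s_x − s_y)·r / ℓ² + d / ℓ² − ‖r‖² / ℓ⁴ ].

A minimal sketch of that expression, with the hypothetical name `stein_gaussian_sketch`:

```python
import jax.numpy as jnp


def stein_gaussian_sketch(x, sx, y, sy, lengthscale):
    """Hypothetical closed-form Langevin Stein kernel for the assumed Gaussian
    base kernel k(x, y) = exp(-||x - y||^2 / (2 l^2))."""
    d = x.shape[0]
    r = x - y
    l2 = lengthscale**2
    sq_dist = jnp.sum(r**2)
    k = jnp.exp(-sq_dist / (2.0 * l2))
    # Bracket collects: score term, cross terms, and the trace of the mixed Hessian
    bracket = jnp.dot(sx, sy) + jnp.dot(sx - sy, r) / l2 + d / l2 - sq_dist / l2**2
    return k * bracket


# Example with a standard normal target, whose score is s(x) = -x
x = jnp.array([0.5, -1.0])
y = jnp.array([1.0, 0.0])
print(stein_gaussian_sketch(x, -x, y, -y, 1.0))
```

Because every term carries the factor k(x, y), the Stein Gaussian kernel decays exponentially with distance, whereas the Stein IMQ kernel decays polynomially.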

GetSteinFn

GetSteinFn(kernel_fn)

Helper that builds the Stein kernel function k_p(x,y) given an arbitrary underlying kernel k(x,y).

The returned function has the signature kp_fn(x, sx, y, sy), where x and y are vectors of dimension d, and sx and sy are the score functions evaluated at x and y, respectively, both of dimension d.

Parameters:

  • kernel_fn (KernelFn) –

    Callable kernel function of the form (x, y) ↦ k(x, y). Any hyperparameters, such as the lengthscale, should be fixed with jax.tree_util.Partial.

Returns:

  • SteinKernelFn

    A function that computes the Stein kernel k_p(x,y) given two vectors x and y, and their corresponding score functions sx and sy.
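One way such a helper can work is with automatic differentiation: take ∇ₓk and ∇ᵧk with `jax.grad`, form the mixed second derivative ∂²k/∂x∂y with `jax.jacfwd`, and assemble k_p(x, y) = tr(∂²k/∂x∂y) + ∇ₓk·s_y + ∇ᵧk·s_x + k s_x·s_y. The sketch below is hypothetical (`get_stein_fn_sketch` and the Gaussian form used for illustration are assumptions, not Kernax's implementation), but it follows the documented convention of fixing the lengthscale with `jax.tree_util.Partial`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def get_stein_fn_sketch(kernel_fn):
    """Hypothetical autodiff construction of the Langevin Stein kernel
    k_p(x, y) = tr(d^2 k / dx dy) + grad_x k . sy + grad_y k . sx + k(x, y) sx . sy
    from a two-argument kernel_fn(x, y)."""
    grad_x = jax.grad(kernel_fn, argnums=0)
    grad_y = jax.grad(kernel_fn, argnums=1)
    # Mixed second derivative: entry (i, j) is d^2 k / (dx_i dy_j)
    mixed = jax.jacfwd(grad_x, argnums=1)

    def kp_fn(x, sx, y, sy):
        return (
            jnp.trace(mixed(x, y))
            + jnp.dot(grad_x(x, y), sy)
            + jnp.dot(grad_y(x, y), sx)
            + kernel_fn(x, y) * jnp.dot(sx, sy)
        )

    return kp_fn


def gaussian(x, y, lengthscale):
    # Assumed Gaussian form, for illustration only
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


# Fix the lengthscale with Partial so kernel_fn has the (x, y) signature
kernel_fn = Partial(gaussian, lengthscale=1.0)
kp_fn = get_stein_fn_sketch(kernel_fn)

# Example with a standard normal target, whose score is s(x) = -x
x = jnp.array([0.5, -1.0])
y = jnp.array([1.0, 0.0])
print(kp_fn(x, -x, y, -y))
```

Using `Partial` rather than a lambda keeps the fixed hyperparameters visible to JAX as pytree leaves, so the resulting kernel can be passed through transformations such as `jax.jit`.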