kernax.kernels
Module containing kernel functions for use in Kernax.
IMQ
IMQ(x, y, lengthscale)
Inverse multiquadric (IMQ) kernel function.
Parameters:
- x (JaxArray) – Vector of dimension d
- y (JaxArray) – Vector of dimension d
- lengthscale (float) – Scalar lengthscale / bandwidth
Returns:
- Scalar – Value of the kernel function k(x,y)
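As an illustration of the signature above, here is a minimal sketch of an inverse multiquadric kernel in JAX. It assumes the common parameterisation (1 + ||x - y||² / (2 · lengthscale²))^(-1/2); the exact scaling used by kernax is not stated here and may differ. The name `imq_sketch` is hypothetical and not part of the module.

```python
import jax.numpy as jnp


def imq_sketch(x, y, lengthscale):
    """Hypothetical IMQ kernel; the scaling convention is an assumption."""
    diff = x - y
    sq_dist = jnp.dot(diff, diff)  # squared Euclidean distance
    return 1.0 / jnp.sqrt(1.0 + 0.5 * sq_dist / lengthscale**2)


x = jnp.array([0.0, 0.0])
y = jnp.array([3.0, 4.0])

k_short = imq_sketch(x, y, 1.0)   # small lengthscale: kernel decays quickly
k_long = imq_sketch(x, y, 10.0)   # large lengthscale: kernel stays close to 1
k_same = imq_sketch(x, x, 1.0)    # identical inputs give the maximum value
```

Under this parameterisation the kernel equals 1 when x and y coincide and decays polynomially, rather than exponentially, as they move apart.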
Gaussian
Gaussian(x, y, lengthscale)
Gaussian kernel function.
Parameters:
- x (JaxArray) – Vector of dimension d
- y (JaxArray) – Vector of dimension d
- lengthscale (float) – Scalar lengthscale / bandwidth
Returns:
- Scalar – Value of the kernel function k(x,y)
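Because the kernel takes a pair of d-dimensional vectors and returns a scalar, a Gram matrix over a batch of points can be built with `jax.vmap`. The sketch below assumes the usual form exp(-||x - y||² / (2 · lengthscale²)); `gaussian_sketch` and `gram_matrix` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp


def gaussian_sketch(x, y, lengthscale):
    """Hypothetical Gaussian kernel; the 1/(2 l^2) scaling is an assumption."""
    diff = x - y
    return jnp.exp(-0.5 * jnp.dot(diff, diff) / lengthscale**2)


def gram_matrix(kernel_fn, X, Y):
    """Evaluate kernel_fn on every pair of rows of X (n, d) and Y (m, d)."""
    # Inner vmap runs over rows of Y, outer vmap over rows of X.
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel_fn(x, y))(Y))(X)


X = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gram_matrix(lambda x, y: gaussian_sketch(x, y, 1.0), X, X)  # shape (3, 3)
```

The same pattern applies to any two-argument kernel once its hyperparameters have been fixed.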
Energy
Energy(x, y)
Distance induced kernel function.
Parameters:
- x (JaxArray) – Vector of dimension d
- y (JaxArray) – Vector of dimension d
Returns:
- Scalar – Value of the kernel function k(x,y)
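For orientation, a distance-induced kernel is often written as k(x,y) = ½ (||x - z0|| + ||y - z0|| - ||x - y||) for some base point z0. The sketch below assumes the Euclidean distance and z0 = 0; whether kernax uses this base point or the factor ½ is an assumption. The name `energy_sketch` is hypothetical.

```python
import jax.numpy as jnp


def energy_sketch(x, y):
    """Hypothetical distance-induced kernel with base point at the origin."""
    return 0.5 * (
        jnp.linalg.norm(x) + jnp.linalg.norm(y) - jnp.linalg.norm(x - y)
    )


x = jnp.array([3.0, 4.0])
origin = jnp.zeros(2)

k_xx = energy_sketch(x, x)        # equals ||x|| under these assumptions
k_x0 = energy_sketch(x, origin)   # vanishes when one argument is the base point
```

Note that this form has no lengthscale, which matches the two-argument signature above.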
SteinIMQ
SteinIMQ(x, sx, y, sy, lengthscale)
Langevin Stein kernel with the IMQ as the underlying kernel.
Parameters:
- x (JaxArray) – Vector of dimension d
- sx (JaxArray) – Score function evaluated at x, vector of dimension d
- y (JaxArray) – Vector of dimension d
- sy (JaxArray) – Score function evaluated at y, vector of dimension d
- lengthscale (float) – Scalar lengthscale / bandwidth
Returns:
- Scalar – Value of the Stein IMQ kernel k_p(x,y)
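For reference, the Langevin Stein kernel built from a base kernel k and the score function s_p of a density p is usually written as below. That kernax follows exactly this convention (term ordering and signs) is an assumption; here sx and sy play the roles of s_p(x) and s_p(y).

```latex
k_p(x, y) = \nabla_x \cdot \nabla_y k(x, y)
  + \nabla_x k(x, y)^\top s_p(y)
  + \nabla_y k(x, y)^\top s_p(x)
  + k(x, y)\, s_p(x)^\top s_p(y),
\qquad s_p(x) = \nabla_x \log p(x)
```

SteinIMQ corresponds to taking k to be the IMQ kernel with the given lengthscale.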
SteinGaussian
SteinGaussian(x, sx, y, sy, lengthscale)
Langevin Stein kernel with the Gaussian kernel as the underlying kernel.
Parameters:
- x (JaxArray) – Vector of dimension d
- sx (JaxArray) – Score function evaluated at x, vector of dimension d
- y (JaxArray) – Vector of dimension d
- sy (JaxArray) – Score function evaluated at y, vector of dimension d
- lengthscale (float) – Scalar lengthscale / bandwidth
Returns:
- Scalar – Value of the Stein Gaussian kernel k_p(x,y)
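To make the five-argument signature concrete, the sketch below evaluates a closed-form Stein kernel for a Gaussian base kernel. It is obtained by substituting k(x,y) = exp(-||x - y||² / (2 l²)) into the standard Langevin Stein kernel, which gives k · (d / l² - ||r||² / l⁴ + r·(sx - sy) / l² + sx·sy) with r = x - y. Both the base-kernel scaling and the name `stein_gaussian_sketch` are assumptions, not the kernax implementation.

```python
import jax.numpy as jnp


def stein_gaussian_sketch(x, sx, y, sy, lengthscale):
    """Hypothetical closed-form Langevin Stein kernel, Gaussian base kernel."""
    r = x - y
    d = x.shape[0]
    l2 = lengthscale**2
    sq_dist = jnp.dot(r, r)
    k = jnp.exp(-0.5 * sq_dist / l2)  # base Gaussian kernel value
    return k * (
        d / l2                        # divergence term, part 1
        - sq_dist / l2**2             # divergence term, part 2
        + jnp.dot(r, sx - sy) / l2    # gradient-score cross terms
        + jnp.dot(sx, sy)             # score-score term
    )


def score(z):
    # Score of a standard normal target: grad log p(z) = -z.
    return -z


x = jnp.array([1.0, 0.0])
y = jnp.array([0.0, 0.0])

kp_xy = stein_gaussian_sketch(x, score(x), y, score(y), 1.0)
kp_yy = stein_gaussian_sketch(y, score(y), y, score(y), 1.0)
```

The caller supplies the score values sx and sy; the kernel itself never evaluates the target density.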
GetSteinFn
GetSteinFn(kernel_fn)
Helper that builds the Stein kernel function k_p(x,y) given an arbitrary underlying kernel k(x,y).
The returned function has the signature kp_fn(x, sx, y, sy), where x and y are vectors of dimension d, and sx and sy are the values of the score function at x and y, respectively, both of dimension d.
Parameters:
- kernel_fn (KernelFn) – Callable kernel function of the form (x,y) ↦ k(x,y). Any hyperparameters such as the lengthscale should be fixed with jax.tree_util.Partial.
Returns:
- SteinKernelFn – A function that computes the Stein kernel k_p(x,y) given two vectors x and y, and their corresponding score functions sx and sy.
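The sketch below shows one way such a helper could be written with JAX automatic differentiation, together with the hyperparameter-fixing pattern described above. It is not the kernax implementation: `get_stein_fn_sketch` and `gaussian_sketch` are hypothetical names, and the standard Langevin Stein kernel formula is assumed. Only `jax.tree_util.Partial`, `jax.grad`, and `jax.jacfwd` are real JAX APIs.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def gaussian_sketch(x, y, lengthscale):
    """Hypothetical base kernel used only for this illustration."""
    diff = x - y
    return jnp.exp(-0.5 * jnp.dot(diff, diff) / lengthscale**2)


def get_stein_fn_sketch(kernel_fn):
    """Build kp_fn(x, sx, y, sy) from a two-argument kernel via autodiff."""
    grad_x = jax.grad(kernel_fn, argnums=0)    # gradient of k w.r.t. x
    grad_y = jax.grad(kernel_fn, argnums=1)    # gradient of k w.r.t. y
    mixed = jax.jacfwd(grad_y, argnums=0)      # matrix of d^2 k / (dy_i dx_j)

    def kp_fn(x, sx, y, sy):
        k = kernel_fn(x, y)
        div_term = jnp.trace(mixed(x, y))      # sum_i d^2 k / (dx_i dy_i)
        return (
            div_term
            + jnp.dot(grad_x(x, y), sy)
            + jnp.dot(grad_y(x, y), sx)
            + k * jnp.dot(sx, sy)
        )

    return kp_fn


# Fix the lengthscale so the kernel has the (x, y) form expected above.
kernel_fn = Partial(gaussian_sketch, lengthscale=1.0)
kp_fn = get_stein_fn_sketch(kernel_fn)

x = jnp.array([1.0, 0.0])
y = jnp.array([0.0, 0.0])
kp_xy = kp_fn(x, -x, y, -y)   # scores of a standard normal target: s(z) = -z
kp_yy = kp_fn(y, -y, y, -y)
```

Fixing hyperparameters with `Partial` rather than a plain lambda keeps the kernel a valid JAX pytree, so it can be passed through transformations such as `jax.jit` as an argument.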