kernax.discrepancies
Module for discrepancy measures.
MMD
MMD(x, y)
Implements the V-estimator of the maximum mean discrepancy (MMD) between the samples x and y.
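For reference, the textbook V-statistic for the squared MMD with a kernel k sums over all n rows x_i of x and all m rows y_j of y, diagonal terms included. Which kernel MMD uses, and whether it returns this quantity or its square root, is not specified here.

```latex
\widehat{\mathrm{MMD}}^2_V(x, y)
  = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k(x_i, x_j)
  + \frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(y_i, y_j)
  - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(x_i, y_j)
```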
Parameters:
- x (JaxArray) – Sample matrix of size (n, d).
- y (JaxArray) – Sample matrix of size (m, d).
Returns:
- JaxArray (Scalar) – The V-estimator of the maximum mean discrepancy.
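A minimal JAX sketch of the squared-MMD V-statistic, assuming a Gaussian (RBF) kernel with a fixed lengthscale of 1. The kernel choice and the helper names gaussian_gram and mmd2_v are illustrative assumptions, not part of kernax. Keeping the diagonal entries of K_xx and K_yy is what makes this a V-statistic rather than a U-statistic.

```python
import jax
import jax.numpy as jnp


def gaussian_gram(a, b, lengthscale=1.0):
    """Assumed Gaussian (RBF) kernel matrix between rows of a (p, d) and b (q, d)."""
    # Broadcast to (p, q, d), then sum over features for squared distances.
    sq_dists = jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-sq_dists / (2.0 * lengthscale**2))


@jax.jit
def mmd2_v(x, y):
    """Hypothetical helper: V-statistic of the squared MMD (diagonals kept)."""
    k_xx = gaussian_gram(x, x)  # (n, n)
    k_yy = gaussian_gram(y, y)  # (m, m)
    k_xy = gaussian_gram(x, y)  # (n, m)
    return k_xx.mean() + k_yy.mean() - 2.0 * k_xy.mean()


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (200, 3))
y = jax.random.normal(key_y, (150, 3)) + 0.5  # shifted sample; n != m is fine
print(mmd2_v(x, y))
```

Broadcasting builds the full (p, q, d) difference array in memory, which keeps the sketch short but costs O(p·q·d) memory; for large samples a chunked or matrix-product formulation would be preferable.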
KSD
KSD(x, sx, kernel_fn=None)
Implements the V-estimator of the kernelized Stein discrepancy (KSD).
This function is not yet jittable, i.e. it cannot be wrapped in jax.jit.
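For reference, the standard V-statistic of the squared KSD averages the Langevin Stein kernel k_p over all pairs of rows of x, where k is the base kernel and s(x_i) is row i of sx. This is the usual textbook form; the exact expression kernax evaluates is not given here.

```latex
k_p(x, y) = s(x)^\top s(y)\, k(x, y)
          + s(x)^\top \nabla_y k(x, y)
          + \nabla_x k(x, y)^\top s(y)
          + \operatorname{tr}\!\big(\nabla_x \nabla_y k(x, y)\big),
\qquad
\widehat{\mathrm{KSD}}^2_V = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k_p(x_i, x_j)
```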
Parameters:
- x (JaxArray) – Sample matrix of size (n, d).
- sx (JaxArray) – Score function of the target distribution (the gradient of its log-density) evaluated at each row of x, a matrix of size (n, d).
- kernel_fn (Optional[Callable], default: None) – Kernel function that takes two inputs and returns a scalar. If None, the inverse multiquadric (IMQ) kernel is used, with a lengthscale set by the median heuristic.
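A sketch of how an IMQ kernel with a median-heuristic lengthscale could be built as a kernel_fn of the shape described above (two single points in, one scalar out). The IMQ form (c² + ||x - y||²/ℓ²)^β with c = 1 and β = -1/2, the median convention (square root of the median squared pairwise distance), and the names median_lengthscale and make_imq are all assumptions; median-heuristic variants differ between libraries.

```python
import jax.numpy as jnp


def median_lengthscale(x):
    """Hypothetical helper: sqrt of the median squared pairwise distance of rows of x."""
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    # Keep each unordered pair once and drop the zero diagonal.
    rows, cols = jnp.triu_indices(x.shape[0], k=1)
    return jnp.sqrt(jnp.median(sq_dists[rows, cols]))


def make_imq(lengthscale, c=1.0, beta=-0.5):
    """Hypothetical factory for an assumed IMQ kernel on two single points."""
    def kernel_fn(x, y):
        return (c**2 + jnp.sum((x - y) ** 2) / lengthscale**2) ** beta
    return kernel_fn


x = jnp.array([[0.0], [1.0], [3.0]])
ell = median_lengthscale(x)  # pairwise distances 1, 2, 3 -> median 2
kernel_fn = make_imq(ell)
print(kernel_fn(jnp.array([0.0]), jnp.array([2.0])))  # (1 + 4/4) ** -0.5 ≈ 0.7071
```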
Returns:
- JaxArray (Scalar) – The V-estimator of the kernelized Stein discrepancy.
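A minimal JAX sketch of the KSD V-statistic using the Langevin Stein kernel, with base-kernel gradients from jax.grad and the cross term from jax.jacfwd. It assumes an IMQ base kernel with a fixed lengthscale of 1 (the median heuristic is omitted for brevity) and returns the squared discrepancy; imq_kernel, stein_kernel, and ksd2_v are hypothetical names, not kernax's API.

```python
import jax
import jax.numpy as jnp


def imq_kernel(x, y, lengthscale=1.0, c=1.0, beta=-0.5):
    # Assumed IMQ base kernel on single points: (c^2 + ||x - y||^2 / l^2) ** beta
    return (c**2 + jnp.sum((x - y) ** 2) / lengthscale**2) ** beta


def stein_kernel(kernel_fn):
    """Build the Langevin Stein kernel k_p(x, sx, y, sy) from a base kernel k(x, y)."""
    grad_x = jax.grad(kernel_fn, argnums=0)  # gradient of k w.r.t. x
    grad_y = jax.grad(kernel_fn, argnums=1)  # gradient of k w.r.t. y
    cross = jax.jacfwd(grad_y, argnums=0)    # (d, d) mixed second derivatives

    def k_p(x, sx, y, sy):
        return (
            jnp.dot(sx, sy) * kernel_fn(x, y)  # s(x)^T s(y) k(x, y)
            + jnp.dot(sx, grad_y(x, y))        # s(x)^T grad_y k
            + jnp.dot(grad_x(x, y), sy)        # grad_x k^T s(y)
            + jnp.trace(cross(x, y))           # tr(grad_x grad_y k)
        )

    return k_p


def ksd2_v(x, sx, kernel_fn=imq_kernel):
    """Hypothetical helper: mean of the (n, n) Stein Gram matrix (V-statistic)."""
    k_p = stein_kernel(kernel_fn)
    # Inner vmap runs over the second pair (y, sy), outer over the first (x, sx).
    pairwise = jax.vmap(
        jax.vmap(k_p, in_axes=(None, None, 0, 0)), in_axes=(0, 0, None, None)
    )
    return jnp.mean(pairwise(x, sx, x, sx))


# Target: standard normal in 2-D, whose score is s(x) = -x.
z = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
print(ksd2_v(z, -z))                # samples from the target
print(ksd2_v(z + 2.0, -(z + 2.0)))  # shifted samples should score higher
```

Because the Stein kernel only needs the score values sx, the target density never has to be normalized, which is the main reason KSD is used for unnormalized models.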