
kernax.discrepancies

Module for discrepancy measures.

MMD

MMD(x, y)

Implements the V-estimator of the maximum mean discrepancy.

Parameters:

  • x (JaxArray) –

    Sample matrix of size (n, d).

  • y (JaxArray) –

    Sample matrix of size (m, d).

Returns:

  • JaxArray (Scalar) –

    The V-estimator of the maximum mean discrepancy.
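For reference, the V-estimator of the squared MMD averages the kernel over all pairs, diagonal terms included: MMD²_V = (1/n²) Σ_ij k(x_i, x_j) + (1/m²) Σ_ij k(y_i, y_j) − (2/(nm)) Σ_ij k(x_i, y_j). The sketch below computes that quantity in JAX. It is illustrative only: the documentation does not say which kernel MMD uses or whether it returns MMD or its square, so the Gaussian kernel with unit lengthscale and the squared output are assumptions, and rbf_kernel, gram and mmd_v_statistic are hypothetical helpers rather than kernax functions.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    # Assumed kernel for this sketch: Gaussian (RBF) kernel on two d-vectors.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


def gram(kernel_fn, a, b):
    # Pairwise kernel matrix K[i, j] = kernel_fn(a[i], b[j]) via nested vmap.
    return jax.vmap(lambda ai: jax.vmap(lambda bj: kernel_fn(ai, bj))(b))(a)


def mmd_v_statistic(x, y, kernel_fn=rbf_kernel):
    # V-statistic: the i == j terms are kept, so each mean runs over all pairs.
    kxx = gram(kernel_fn, x, x).mean()
    kyy = gram(kernel_fn, y, y).mean()
    kxy = gram(kernel_fn, x, y).mean()
    # Squared MMD; non-negative for a positive-definite kernel.
    return kxx + kyy - 2.0 * kxy
```

Because the V-statistic keeps the diagonal terms, it is biased upward, but it equals a squared RKHS norm and so is never negative, which is convenient when it is used as a loss.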

KSD

KSD(x, sx, kernel_fn=None)

Implements the V-estimator of the kernelized Stein discrepancy.

This function is not yet jittable: it cannot be wrapped in jax.jit.

Parameters:

  • x (JaxArray) –

    Sample matrix of size (n, d).

  • sx (JaxArray) –

    Score function (the gradient of the log target density) evaluated at each row of x; matrix of size (n, d).

  • kernel_fn (Optional[Callable], default: None) –

    Kernel function that takes two inputs and returns a scalar. If None, the inverse multiquadric (IMQ) kernel is used, with its lengthscale set by the median heuristic.
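A kernel_fn with the expected shape (two vectors in, one scalar out) can be built as a closure. The sketch below approximates the documented default under stated assumptions: the IMQ form (c² + ‖a − b‖²/ℓ²)^β with c = 1 and β = −1/2, and the median heuristic taken as the median pairwise Euclidean distance, are common choices, but the documentation does not say which variant kernax uses. median_lengthscale and make_imq_kernel are hypothetical helpers, not part of kernax.

```python
import jax.numpy as jnp


def median_lengthscale(x):
    # Assumed median-heuristic variant: median Euclidean distance over
    # distinct pairs of rows of x (diagonal excluded).
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    i, j = jnp.triu_indices(x.shape[0], k=1)
    return jnp.median(jnp.sqrt(sq_dists[i, j]))


def make_imq_kernel(lengthscale, c=1.0, beta=-0.5):
    # IMQ kernel k(a, b) = (c^2 + ||a - b||^2 / lengthscale^2) ** beta.
    # c = 1 and beta = -1/2 are assumed defaults, not taken from kernax.
    def kernel_fn(a, b):
        return (c**2 + jnp.sum((a - b) ** 2) / lengthscale**2) ** beta

    return kernel_fn


x = jnp.array([[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]])
ell = median_lengthscale(x)  # pairwise distances 5, 10, 5 -> median 5.0
k = make_imq_kernel(ell)
print(k(x[0], x[1]))  # (1 + 25 / 25) ** -0.5, about 0.7071
```

Returning a closure keeps the lengthscale fixed once it has been estimated from the sample, so the kernel itself only ever sees two points at a time.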

Returns:

  • JaxArray (Scalar) –

    The V-estimator of the kernelized Stein discrepancy.
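To illustrate what the V-estimator computes, the sketch below assumes the Langevin Stein kernel, the standard choice for score-based KSD: u(x, y) = s(x)ᵀs(y) k(x, y) + s(x)ᵀ∇_y k(x, y) + s(y)ᵀ∇_x k(x, y) + tr(∇_x ∇_yᵀ k(x, y)), averaged over all n² pairs of rows. It returns the squared KSD; whether kernax returns the square or its root is not stated here. stein_kernel, ksd_v_statistic and the Gaussian gaussian_kernel used in the demo are hypothetical names, and kernel_fn must be differentiable with jax.grad.

```python
import jax
import jax.numpy as jnp


def stein_kernel(kernel_fn, x, y, sx, sy):
    # Langevin Stein kernel built from base kernel k and scores sx, sy.
    k = kernel_fn(x, y)
    grad_x = jax.grad(kernel_fn, argnums=0)(x, y)  # d k / d x
    grad_y = jax.grad(kernel_fn, argnums=1)(x, y)  # d k / d y
    # Mixed second derivatives J[i, j] = d^2 k / (dx_i dy_j); we need the trace.
    mixed = jax.jacfwd(jax.grad(kernel_fn, argnums=0), argnums=1)(x, y)
    return (
        jnp.dot(sx, sy) * k
        + jnp.dot(sx, grad_y)
        + jnp.dot(sy, grad_x)
        + jnp.trace(mixed)
    )


def ksd_v_statistic(x, sx, kernel_fn):
    # V-statistic: mean of u(x_i, x_j) over all pairs, diagonal included.
    def row(xi, si):
        return jax.vmap(lambda xj, sj: stein_kernel(kernel_fn, xi, xj, si, sj))(x, sx)

    return jax.vmap(row)(x, sx).mean()


def gaussian_kernel(a, b):
    # Demo base kernel (an assumption): Gaussian with unit lengthscale.
    return jnp.exp(-0.5 * jnp.sum((a - b) ** 2))


# Samples from N(0, I); compare the correct score with that of N(2, I).
x = jax.random.normal(jax.random.PRNGKey(0), (100, 2))
ksd_good = ksd_v_statistic(x, -x, gaussian_kernel)
ksd_bad = ksd_v_statistic(x, -(x - 2.0), gaussian_kernel)
```

With the correct score the estimate stays close to zero (only the positive diagonal terms contribute on average), while the mismatched score produces a clearly larger value.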