Maximum mean discrepancy quantization

This notebook illustrates how to use the KernelHerding class to compress a sample [1,2]. Let's generate a Gaussian toy sample.

import jax.random as jr
import jax.numpy as jnp

rng_key = jr.PRNGKey(42)
# 10,000 draws from a standard bivariate Gaussian, shape (10000, 2)
X = jr.multivariate_normal(rng_key, mean=jnp.zeros(2), cov=jnp.eye(2), shape=(10_000,))

The KernelHerding class builds a callable that performs quantization with a chosen kernel function.

Here, we use the energy kernel, for which the maximum mean discrepancy reduces to the energy distance [3].
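The reduction of MMD to the energy distance can be checked numerically. The sketch below is plain JAX and assumes the distance-induced kernel k(x, y) = ||x|| + ||y|| - ||x - y||; this normalization is an assumption, and kernax's Energy may differ by a constant factor (for example 1/2), which would rescale the squared MMD by the same factor. The names energy_kernel, gram, mmd2 and energy_distance are hypothetical helpers, not part of kernax.

```python
import jax
import jax.numpy as jnp


def energy_kernel(x, y):
    # Assumed distance-induced kernel centred at the origin:
    # k(x, y) = ||x|| + ||y|| - ||x - y||
    return jnp.linalg.norm(x) + jnp.linalg.norm(y) - jnp.linalg.norm(x - y)


def gram(kernel, X, Y):
    # kernel(x_i, y_j) for all pairs via nested vmap, shape (n, m)
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(Y))(X)


def mmd2(kernel, X, Y):
    # Biased (V-statistic) estimate of the squared MMD
    return (
        gram(kernel, X, X).mean()
        + gram(kernel, Y, Y).mean()
        - 2.0 * gram(kernel, X, Y).mean()
    )


def pairwise_dist(A, B):
    # Euclidean distances ||a_i - b_j|| by broadcasting, shape (n, m)
    return jnp.linalg.norm(A[:, None, :] - B[None, :, :], axis=-1)


def energy_distance(X, Y):
    # V-statistic energy distance: 2 E||X - Y|| - E||X - X'|| - E||Y - Y'||
    return (
        2.0 * pairwise_dist(X, Y).mean()
        - pairwise_dist(X, X).mean()
        - pairwise_dist(Y, Y).mean()
    )


key_p, key_q = jax.random.split(jax.random.PRNGKey(0))
sample_p = jax.random.normal(key_p, (200, 2))
sample_q = jax.random.normal(key_q, (150, 2)) + 0.5  # shifted Gaussian

mmd_value = mmd2(energy_kernel, sample_p, sample_q)
ed_value = energy_distance(sample_p, sample_q)
print(mmd_value, ed_value)  # equal up to float32 rounding
```

The agreement is an algebraic identity: the ||x|| and ||y|| terms cancel between the three Gram-matrix means, leaving exactly the energy distance between the two empirical distributions.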

from kernax.kernels import Energy
from kernax import KernelHerding

quantization_fn = KernelHerding(X, kernel_fn=Energy)
idx = quantization_fn(m=1_000)  # m: size of the compressed sample
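For intuition about what a herding-based quantizer computes, here is a minimal sketch of the greedy herding update from [1] in plain JAX. It is not kernax's implementation: herding_indices, gaussian_kernel and gram are hypothetical names, the target distribution is taken to be the empirical distribution of the candidate set, and the full n-by-n Gram matrix is materialized, so it only suits small n. At step t it picks the candidate x maximizing (1/n) sum_j k(x, x_j) - (1/(t+1)) sum_{s<=t} k(x, x_s), i.e. a point that resembles the whole sample but not the points already chosen.

```python
import jax
import jax.numpy as jnp


def gaussian_kernel(x, y, sigma=1.0):
    # Gaussian (RBF) kernel, used here only as an example kernel
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * sigma**2))


def gram(kernel, X, Y):
    # kernel(x_i, y_j) for all pairs via nested vmap, shape (n, m)
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(Y))(X)


def herding_indices(X, kernel, m):
    """Greedy kernel herding over the candidate set X (hypothetical helper)."""
    K = gram(kernel, X, X)  # (n, n) Gram matrix, O(n^2) memory
    mu = K.mean(axis=1)     # empirical mean embedding evaluated at each x_i

    def step(ksum, t):
        # ksum[i] = sum of k(x_i, x_s) over the t points selected so far
        i = jnp.argmax(mu - ksum / (t + 1))
        return ksum + K[:, i], i

    _, idx = jax.lax.scan(step, jnp.zeros(X.shape[0]), jnp.arange(m))
    return idx


X_demo = jax.random.normal(jax.random.PRNGKey(0), (500, 2))
idx_demo = herding_indices(X_demo, gaussian_kernel, m=50)
X_demo_selected = X_demo[idx_demo]  # the compressed sample
```

As written, the update may select the same candidate more than once, as in the original herding scheme; setting the scores of already selected candidates to -inf before the argmax would enforce distinct indices. Which behaviour KernelHerding follows is not covered by this sketch.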

The output idx contains the indices of the selected points in X. Let's plot the result.

import matplotlib.pyplot as plt

# Full sample in gray, selected points in red
fig = plt.figure()
ax = fig.add_subplot(111)
ax.plot(X[:, 0], X[:, 1], ls="", marker="o", color="gray", label="Initial sample")
ax.plot(X[idx, 0], X[idx, 1], ls="", marker="o", color="r", label="Selected sample")
ax.legend(fontsize=14)
plt.show()

Any custom kernel function can be used instead. An example is shown below; see the kernax.kernels module for more examples.

import jax.numpy as jnp
from jaxtyping import Scalar
from kernax.types import JaxArray

def custom_kernel(x: JaxArray, y: JaxArray) -> Scalar:
    # Implement your own kernel here.
    # Example: the linear kernel k(x, y) = <x, y>
    return jnp.dot(x, y)
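Kernels with hyperparameters can be written as closures that return a function with the same (x, y) -> scalar signature as custom_kernel. The sketch below uses a hypothetical make_gaussian_kernel factory (not part of kernax) and checks that the resulting Gram matrix is symmetric and positive semidefinite, a property every valid kernel must satisfy. Passing the returned function as kernel_fn assumes that KernelHerding accepts any such plain callable.

```python
import jax
import jax.numpy as jnp


def make_gaussian_kernel(sigma: float):
    # Hypothetical factory: returns k(x, y) = exp(-||x - y||^2 / (2 sigma^2))
    def kernel(x, y):
        return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * sigma**2))

    return kernel


def gram(kernel, X, Y):
    # kernel(x_i, y_j) for all pairs via nested vmap, shape (n, m)
    return jax.vmap(lambda a: jax.vmap(lambda b: kernel(a, b))(Y))(X)


gaussian_k = make_gaussian_kernel(sigma=0.5)
X_demo = jax.random.normal(jax.random.PRNGKey(1), (100, 2))
K_demo = gram(gaussian_k, X_demo, X_demo)
eigvals = jnp.linalg.eigvalsh(K_demo)  # real eigenvalues of the symmetric Gram matrix
```

A clearly negative eigenvalue (beyond float32 rounding) would indicate that a hand-written function is not a valid positive semidefinite kernel.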

References

[1] Chen, Y., Welling, M., & Smola, A. (2012). Super-samples from kernel herding. arXiv preprint arXiv:1203.3472.

[2] Teymur, O., Gorham, J., Riabiz, M., & Oates, C. (2021, March). Optimal quantisation of probability measures using maximum mean discrepancy. In International Conference on Artificial Intelligence and Statistics (pp. 1027-1035). PMLR.

[3] Sejdinovic, D., Sriperumbudur, B., Gretton, A., & Fukumizu, K. (2013). Equivalence of distance-based and RKHS-based statistics in hypothesis testing. The Annals of Statistics, 2263-2291.