kernax.quantization
Module for kernel quantization using maximum mean discrepancy.
KernelHerding
Bases: Module
Greedy MMD-based kernel quantization (herding-style thinning).
Once instantiated, the module can be called with an integer m to
select m representative samples from the input dataset.
Parameters:
- x – The dataset of shape (N, d) to be subsampled.
- kernel_fn – Kernel function of the form k(x: (d,), y: (d,)) -> scalar.
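As an illustration of the kernel_fn signature, the sketch below defines a Gaussian (RBF) kernel on single points of shape (d,) and lifts it to whole datasets with jax.vmap. The names rbf_kernel, gram_matrix, and mmd_squared are hypothetical helpers written for this example and are not part of kernax. The biased (V-statistic) squared-MMD estimate is one common way to measure how well a subset represents the data; the library may use a different estimator.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, lengthscale=1.0):
    """Gaussian kernel between two points of shape (d,); returns a scalar."""
    diff = x - y
    return jnp.exp(-jnp.dot(diff, diff) / (2.0 * lengthscale**2))


def gram_matrix(kernel_fn, a, b):
    """Evaluate kernel_fn on every pair of rows: a is (N, d), b is (M, d), result is (N, M)."""
    return jax.vmap(lambda p: jax.vmap(lambda q: kernel_fn(p, q))(b))(a)


def mmd_squared(kernel_fn, x, subset):
    """Biased (V-statistic) estimate of the squared MMD between x and subset."""
    return (
        gram_matrix(kernel_fn, x, x).mean()
        - 2.0 * gram_matrix(kernel_fn, x, subset).mean()
        + gram_matrix(kernel_fn, subset, subset).mean()
    )
```

Any function with the same point-wise signature can be substituted for rbf_kernel, since the pairwise evaluation is handled entirely by the nested vmap.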
__call__
__call__(m)
Return the indices of a subset of size m, chosen greedily to minimize the MMD.
Parameters:
- m (int) – Number of points to select (must be <= N).
Returns:
- JaxArray – Indices of the selected points into x, array of shape (m,).
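The greedy selection described above can be sketched as follows. This is a minimal illustration, not the kernax implementation: it assumes uniform weights on the selected points, selection without replacement, and a precomputed (N, N) Gram matrix, none of which the documentation confirms. The function names greedy_mmd_select and rbf_kernel are hypothetical. At step t, adding candidate i changes the squared MMD by a term proportional to -(mu_i - (s_i + k(x_i, x_i) / 2) / (t + 1)), where mu_i is the mean kernel value between x_i and the full dataset and s_i is the sum of kernel values between x_i and the points already chosen, so the sketch picks the candidate that maximizes the bracketed score.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y):
    """Gaussian kernel with unit lengthscale on points of shape (d,)."""
    diff = x - y
    return jnp.exp(-0.5 * jnp.dot(diff, diff))


def greedy_mmd_select(x, kernel_fn, m):
    """Greedily pick m row indices of x that reduce the squared MMD to x."""
    # Full Gram matrix; O(N^2) memory, acceptable for a small illustration.
    K = jax.vmap(lambda p: jax.vmap(lambda q: kernel_fn(p, q))(x))(x)
    mean_embedding = K.mean(axis=1)          # mu_i = (1/N) * sum_j k(x_i, x_j)
    diag = jnp.diag(K)                       # k(x_i, x_i)
    running_sum = jnp.zeros(x.shape[0])      # s_i = sum over chosen j of k(x_i, x_j)
    taken = jnp.zeros(x.shape[0], dtype=bool)
    selected = []
    for t in range(m):
        # Larger score means a larger decrease in squared MMD.
        score = mean_embedding - (running_sum + 0.5 * diag) / (t + 1)
        # Assumption: each point may be chosen at most once.
        score = jnp.where(taken, -jnp.inf, score)
        i = int(jnp.argmax(score))
        selected.append(i)
        taken = taken.at[i].set(True)
        running_sum = running_sum + K[:, i]
    return jnp.array(selected)


key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (50, 2))
indices = greedy_mmd_select(data, rbf_kernel, 5)
subset = data[indices]                       # the 5 representative points
```

The Python loop keeps the sketch readable; a version intended for large m would more likely use jax.lax.fori_loop or jax.lax.scan so that the whole selection can be jit-compiled.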