Commit
Allow `vmap`ping over the batch axis in `empirical_ntk_fn`.

This follows from the observation that `d(vmap_x(f))/dp (p, x) == vmap_x(df/dp)(p, x)`, and that most common neural networks are effectively `vmap`s over their batch axis. In experiments this seems to give a ~2-260X speedup, notably by allowing larger batches in the direct method. For small batch sizes this should have no effect.

Further, fuse `nt.empirical_implicit_ntk_fn` and `nt.empirical_direct_ntk_fn` into a single `nt.empirical_ntk_fn` that now accepts the `implementation=1/2` argument. `nt.empirical_kernel_fn` and `nt.monte_carlo_kernel_fn` now also accept this argument. This is a breaking change if you were using `nt.empirical_direct_ntk_fn` (this is now `nt.empirical_ntk_fn(..., implementation=1)`).

Implementation-wise, I believe this gives the following speedups:

1) In `nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`), an O(batch_size_1) time/memory improvement when constructing the Jacobian (followed by the contraction, which is unchanged). I believe the most notable benefit here is the increased batch size when constructing the Jacobian.

2) In `nt.empirical_implicit_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=2)`), the same O(batch_size_1) time/memory improvement, BUT in practice it seems to give only about a 2X speedup, since this method does not gain any memory efficiency and remains O(batch_size_1 * batch_size_2 * #params).

This is inspired by a discussion with schsam@ and #30, but I'm not entirely sure how this relates to the layer-wise Jacobians idea.

Also:
- make the direct method the default (`implementation=1`); add a suggestion on when to use each.
- make stax layers preserve exact input PyTrees (e.g. tuples vs lists etc).
- small fix to `nt.empirical_direct_ntk_fn` to work with `x2=None`, and activate the respective tests.
- do not raise an error (only warn) if elements of an input pytree have mismatching batch or channel axes, since this case still works in the finite case.
- fix some typos in stax tests.

Co-authored-by: Sam Schoenholz <[email protected]>
PiperOrigin-RevId: 342982475
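
The identity `d(vmap_x(f))/dp (p, x) == vmap_x(df/dp)(p, x)` used above can be checked numerically in plain JAX. The following is a minimal sketch, not the code from this commit: `f_single`, its shapes, and the `flatten` helper are hypothetical names chosen for illustration, and the final contraction only mimics the idea of the direct method (build the Jacobian, then contract over parameters).

```python
import jax
import jax.numpy as jnp


def f_single(params, x):
    """Hypothetical per-example network: maps a (3,) input to a (2,) output."""
    w1, w2 = params
    return jnp.tanh(x @ w1) @ w2


# A network without cross-example interaction is a vmap over its batch axis.
f_batched = jax.vmap(f_single, in_axes=(None, 0))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (3, 8)), jax.random.normal(k2, (8, 2)))
x = jax.random.normal(k3, (5, 3))  # batch of 5 examples

# Left-hand side: Jacobian of the batched function w.r.t. the parameters.
jac_of_vmap = jax.jacobian(f_batched)(params, x)

# Right-hand side: per-example Jacobian, vmapped over the batch axis.
vmap_of_jac = jax.vmap(jax.jacobian(f_single), in_axes=(None, 0))(params, x)


def flatten(jac):
    """Concatenate all parameter axes: leaves (n, k, ...) -> (n, k, n_params)."""
    leaves = jax.tree_util.tree_leaves(jac)
    return jnp.concatenate(
        [leaf.reshape(leaf.shape[0], leaf.shape[1], -1) for leaf in leaves],
        axis=-1,
    )


j_lhs = flatten(jac_of_vmap)
j_rhs = flatten(vmap_of_jac)
identity_holds = bool(jnp.allclose(j_lhs, j_rhs, atol=1e-5))

# Contract the Jacobian with itself over the parameter axis to get an
# empirical NTK of shape (batch, batch, out, out).
ntk = jnp.einsum("iap,jbp->ijab", j_rhs, j_rhs)
```

Both sides produce the same Jacobian. The difference is in how it is built: differentiating `f_batched` directly treats all `batch * out` outputs as separate rows, whereas the vmapped form only differentiates the `out` outputs of a single example and batches that computation, which is presumably where the O(batch_size_1) saving described above comes from.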