Empirical NTK speedup.
Allow `vmap`ping over the batch axis in `empirical_ntk_fn`. This follows from the observation that `d(vmap_x(f))/dp (p, x) == vmap_x(df/dp)(p, x)`, and most common neural networks are effectively `vmap`s over their batch axis. In experiments this seems to give a ~2-260X speedup, notably by allowing larger batches in the direct method. For small batch sizes it should have no effect.
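
For intuition, a minimal JAX sketch of that identity on a toy function (the function, parameters and shapes here are illustrative, not part of this change):

```python
import jax
import jax.numpy as jnp

# A toy per-example function of parameters `p` and a single input `x`.
def f(p, x):
  return jnp.tanh(p @ x)

p = jnp.ones((4, 3))
xs = jnp.arange(15.).reshape(5, 3)  # a batch of 5 examples

# Jacobian of the batched (vmapped) function == batched per-example Jacobians.
lhs = jax.jacobian(jax.vmap(f, in_axes=(None, 0)))(p, xs)
rhs = jax.vmap(jax.jacobian(f), in_axes=(None, 0))(p, xs)
assert jnp.allclose(lhs, rhs)
```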

Further, fuse `nt.empirical_implicit_ntk_fn` and `nt.empirical_direct_ntk_fn` into a single `nt.empirical_ntk_fn` that now accepts an `implementation=1/2` argument. `nt.empirical_kernel_fn` and `nt.monte_carlo_kernel_fn` now also accept this argument. This is a breaking change if you were using `nt.empirical_direct_ntk_fn` (it is now `nt.empirical_ntk_fn(..., implementation=1)`).
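
A rough migration sketch (the model below is a placeholder; only the `nt` calls matter):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

# Placeholder model; architecture and sizes are arbitrary.
init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))
_, params = init_fn(jax.random.PRNGKey(0), (-1, 16))
x = jax.random.normal(jax.random.PRNGKey(1), (8, 16))

# Before: nt.empirical_direct_ntk_fn(apply_fn) / nt.empirical_implicit_ntk_fn(apply_fn).
# After: a single entry point with an `implementation` switch plus optional `vmap_axes`.
ntk_direct = nt.empirical_ntk_fn(apply_fn, vmap_axes=0, implementation=1)
ntk_implicit = nt.empirical_ntk_fn(apply_fn, vmap_axes=0, implementation=2)

k_direct = ntk_direct(x, None, params)      # (8, 8) NTK via Jacobian contraction
k_implicit = ntk_implicit(x, None, params)  # same kernel via implicit differentiation
```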

Implementation-wise, I believe this gives the following speedups (both methods are sketched after this list):

1) In `nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`), an O(batch_size_1) time/memory improvement when constructing the Jacobian (followed by the contraction, which is unchanged). I believe the most notable benefit here is the ability to use a larger batch size when constructing the Jacobian.

2) In `nt.empirical_implicit_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=2)`), the same O(batch_size_1) time/memory improvement, BUT in practice it seems to give only about a 2X speedup, since this method gains no memory efficiency and remains O(batch_size_1 * batch_size_2 * #params).
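
A toy sketch of what the two implementations compute, for a scalar-output model (an illustration of the idea only, not the library's actual internals; all names and shapes are made up):

```python
import jax
import jax.numpy as jnp

# Toy model with scalar outputs per example.
def f(params, x):
  w1, w2 = params
  return jnp.tanh(x @ w1) @ w2

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(k1, (3, 4)), jax.random.normal(k2, (4,)))
x1 = jax.random.normal(k3, (5, 3))
x2 = jax.random.normal(k4, (7, 3))

# Direct method (implementation=1), sketched: materialize per-example Jacobians
# df/dp, then contract them over parameters. Differentiating per example
# (rather than the whole batched function) is where the O(batch_size_1)
# saving in the Jacobian step comes from.
grad_single = jax.grad(lambda p, x: f(p, x[None])[0])
j1 = jax.vmap(grad_single, in_axes=(None, 0))(params, x1)  # leaves: (5, ...)
j2 = jax.vmap(grad_single, in_axes=(None, 0))(params, x2)  # leaves: (7, ...)
ntk_direct = sum(
    jnp.tensordot(a, b, axes=(list(range(1, a.ndim)), list(range(1, b.ndim))))
    for a, b in zip(jax.tree_util.tree_leaves(j1), jax.tree_util.tree_leaves(j2)))

# Implicit method (implementation=2), sketched: never build Jacobians; realize
# J1 @ J2^T as the linear map delta -> J1 (J2^T delta) via a VJP followed by a
# JVP, and read the matrix off with `jacobian`. The overall cost remains
# O(batch_size_1 * batch_size_2 * #params), hence the smaller gain from vmapping.
def j1_j2t(delta):
  _, vjp = jax.vjp(lambda p: f(p, x2), params)
  _, tangent_out = jax.jvp(lambda p: f(p, x1), (params,), vjp(delta))
  return tangent_out

ntk_implicit = jax.jacobian(j1_j2t)(jnp.zeros(x2.shape[0]))
assert jnp.allclose(ntk_direct, ntk_implicit, atol=1e-4)  # both equal J1 @ J2^T
```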

This is inspired by a discussion with schsam@ and #30, but I'm not entirely sure how this relates to the layer-wise Jacobians idea.

Also:
- make the direct method the default (`implementation=1`); add a suggestion on when to use each.
- make stax layers preserve the exact input PyTree type (e.g. tuples vs. lists); see the sketch after this list.
- small fix to `nt.empirical_direct_ntk_fn` to work with `x2=None`, and activate the respective tests.
- do not raise an error (only warn) if elements of an input pytree have mismatching batch or channel axes, since this case still works in the finite-width setting.
- fix some typos in stax tests.
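
A minimal sketch of the PyTree-preservation change in `stax.parallel` (assuming a standalone `parallel` over a tuple of two inputs; widths and shapes are arbitrary):

```python
import jax
from neural_tangents import stax

# `parallel` acts elementwise on a tuple/list of inputs; the container type of
# shapes/inputs is now preserved instead of being collapsed to a list.
init_fn, apply_fn, kernel_fn = stax.parallel(stax.Dense(8), stax.Dense(8))

key = jax.random.PRNGKey(0)
_, params = init_fn(key, ((-1, 4), (-1, 4)))      # tuple of input shapes
xs = (jax.random.normal(key, (3, 4)),
      jax.random.normal(key, (3, 4)))             # tuple of inputs

ys = apply_fn(params, xs)
assert isinstance(ys, tuple)  # previously came back as a list
```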

Co-authored-by: Sam Schoenholz <[email protected]>
PiperOrigin-RevId: 342982475
romanngg and sschoenholz committed Nov 18, 2020
1 parent 9a456ba commit f15b652
Showing 11 changed files with 839 additions and 344 deletions.
examples/function_space.py: 3 changes (2 additions, 1 deletion)
@@ -69,7 +69,8 @@ def main(unused_argv):
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

# Create an MSE predictor to solve the NTK equation in function space.
-ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0)
+ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
+               batch_size=4, device_count=0)
g_dd = ntk(x_train, None, params)
g_td = ntk(x_test, x_train, params)
predictor = nt.predict.gradient_descent_mse(g_dd, y_train)
neural_tangents/__init__.py: 3 changes (1 addition, 2 deletions)
@@ -16,13 +16,12 @@
"""Public Neural Tangents modules and functions."""


-__version__ = '0.3.4'
+__version__ = '0.3.5'


from neural_tangents import predict
from neural_tangents import stax
from neural_tangents.utils.batch import batch
-from neural_tangents.utils.empirical import empirical_direct_ntk_fn
from neural_tangents.utils.empirical import empirical_kernel_fn
from neural_tangents.utils.empirical import empirical_nngp_fn
from neural_tangents.utils.empirical import empirical_ntk_fn
neural_tangents/stax.py: 29 changes (17 additions, 12 deletions)
@@ -84,7 +84,7 @@
from jax.tree_util import tree_map
from neural_tangents.utils import utils, dataclasses
from neural_tangents.utils.kernel import Kernel
-from neural_tangents.utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, PyTree, NTTree
+from neural_tangents.utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, PyTree, NTTree, Kernels
import numpy as onp
import scipy as osp

@@ -342,14 +342,17 @@ def parallel(*layers: Layer) -> InternalLayer:
sequence of outputs with the same length as the argument `layers`.
"""
init_fns, apply_fns, kernel_fns = zip(*layers)
-init_fn_stax, apply_fn = ostax.parallel(*zip(init_fns, apply_fns))
+init_fn_stax, apply_fn_stax = ostax.parallel(*zip(init_fns, apply_fns))

def init_fn(rng, input_shape):
-return list(init_fn_stax(rng, input_shape))
+return type(input_shape)(init_fn_stax(rng, input_shape))

+def apply_fn(params, inputs, **kwargs):
+return type(inputs)(apply_fn_stax(params, inputs, **kwargs))

@_requires(**_get_input_req_attr(kernel_fns, fold=op.and_))
-def kernel_fn(ks: List[Kernel], **kwargs) -> List[Kernel]:
-return [f(k, **kwargs) for k, f in zip(ks, kernel_fns)]
+def kernel_fn(ks: NTTree[Kernel], **kwargs) -> NTTree[Kernel]:
+return type(ks)(f(k, **kwargs) for k, f in zip(ks, kernel_fns))

return init_fn, apply_fn, kernel_fn

@@ -1288,7 +1291,7 @@ def FanInSum() -> InternalLayer:
"""
init_fn, apply_fn = ostax.FanInSum

-def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel:
+def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
if not all([k.shape1 == ks[0].shape1 and
k.shape2 == ks[0].shape2 for k in ks[1:]]):
@@ -1350,7 +1353,7 @@ def FanInProd() -> InternalLayer:
def apply_fn(params, inputs, **kwargs):
return functools.reduce(np.multiply, inputs)

-def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel:
+def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)
if not all([k.shape1 == ks[0].shape1 and
k.shape2 == ks[0].shape2 for k in ks[1:]]):
@@ -1415,7 +1418,7 @@ def FanInConcat(axis: int = -1) -> InternalLayer:
"""
init_fn, apply_fn = ostax.FanInConcat(axis)

-def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel:
+def kernel_fn(ks: Kernels, **kwargs) -> Kernel:
ks, is_reversed = _proprocess_kernels_for_fan_in(ks)

diagonal_batch = ks[0].diagonal_batch
@@ -3250,8 +3253,10 @@ def _get_input_req_attr(
if fold is op.and_:
if k in req and req[k] != v:
if (req[k] >= 0 and v >= 0) or (req[k] < 0 and v < 0):
-raise ValueError(f'`{k}` parameters must match in all '
-f'parallel branches, got {req[k]} and {v}.')
+warnings.warn(f'For `kernel_fn`, `{k}` parameters must match in'
+f' all parallel branches, got {req[k]} and {v}. '
+f'This WILL lead to [silent] errors if '
+f'`kernel_fn` is called.')
else:
warnings.warn(f'Got potentially mismatching `{k}` values in '
f'parallel branches: {req[k]} and {v}.')
@@ -3892,8 +3897,7 @@ def _affine(
return W_std**2 * mat + b_std**2


-def _proprocess_kernels_for_fan_in(
-ks: List[Kernel]) -> Tuple[List[Kernel], bool]:
+def _proprocess_kernels_for_fan_in(ks: Kernels) -> Tuple[List[Kernel], bool]:
# Check diagonal requirements.
if not all(k.diagonal_batch == ks[0].diagonal_batch and
k.diagonal_spatial == ks[0].diagonal_spatial and
@@ -3909,6 +3913,7 @@ def _proprocess_kernels_for_fan_in(
# If kernels have different spatial axes order, transpose some of them.
n_kernels = len(ks)
n_reversed = sum(ker.is_reversed for ker in ks)
+ks = list(ks)

if n_reversed > n_kernels / 2:
is_reversed = True
(diffs for the remaining changed files not shown)