From f15b6528a47a73b1940f069309e69111b5235e13 Mon Sep 17 00:00:00 2001
From: Roman Novak
Date: Tue, 17 Nov 2020 17:38:45 -0800
Subject: [PATCH] Empirical NTK speedup.

Allow `vmap`-ing over the batch axis in `empirical_ntk_fn`. This follows from
the observation that `d(vmap_x(f))/dp (p, x) == vmap_x(df/dp)(p, x)`, and that
most common neural networks are effectively `vmap`s over their batch axis. In
experiments this seems to give a ~2-260X speedup, notably by allowing larger
batches in the direct method. For small batch sizes this should have no effect.

Further, fuse `nt.empirical_implicit_ntk_fn` and `nt.empirical_direct_ntk_fn`
into a single `nt.empirical_ntk_fn` that now accepts the `implementation=1/2`
argument. `nt.empirical_kernel_fn` and `nt.monte_carlo_kernel_fn` now also
accept this argument. This is a breaking change if you were using
`nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`).

Implementation-wise, I believe this gives the following speedups:

1) In `nt.empirical_direct_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=1)`),
an O(batch_size_1) time/memory improvement when constructing the Jacobian
(followed by the contraction, which is unchanged). I believe the most notable
benefit here is the increased batch size when constructing the Jacobian.

2) In `nt.empirical_implicit_ntk_fn` (now `nt.empirical_ntk_fn(..., implementation=2)`),
the same O(batch_size_1) time/memory improvement, BUT in practice it seems to
give only about a 2X speedup, since this method gains no memory efficiency and
remains O(batch_size_1 * batch_size_2 * #params).

This is inspired by a discussion with schsam@ and
https://github.com/google/neural-tangents/issues/30, but I'm not entirely sure
how this relates to the layer-wise Jacobians idea.

Also:
- make the direct method the default (`implementation=1`); add a suggestion on
  when to use each.
- make stax layers preserve exact input PyTrees (e.g. tuples vs lists etc.).
- small fix to `nt.empirical_direct_ntk_fn` to work with `x2=None`, and
  activate the respective tests.
- do not raise an error (only warn) if elements of an input pytree have
  mismatching batch or channel axes, since this case still works in the
  finite-width setting.
- fix some typos in stax tests.

Co-authored-by: Sam Schoenholz
PiperOrigin-RevId: 342982475
---
 examples/function_space.py                   |   3 +-
 neural_tangents/__init__.py                  |   3 +-
 neural_tangents/stax.py                      |  29 +-
 neural_tangents/utils/empirical.py           | 759 +++++++++++++------
 neural_tangents/utils/monte_carlo.py         |  49 +-
 neural_tangents/utils/typing.py              |  16 +-
 notebooks/function_space_linearization.ipynb |   3 +-
 tests/batch_test.py                          |   6 +-
 tests/empirical_test.py                      | 185 ++++-
 tests/monte_carlo_test.py                    |  12 +-
 tests/stax_test.py                           | 118 ++-
 11 files changed, 839 insertions(+), 344 deletions(-)

diff --git a/examples/function_space.py b/examples/function_space.py
index cfb7f156..335f3826 100644
--- a/examples/function_space.py
+++ b/examples/function_space.py
@@ -69,7 +69,8 @@ def main(unused_argv):
   grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

   # Create an MSE predictor to solve the NTK equation in function space.
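As an aside, the observation quoted in the commit message (the parameter Jacobian of a `vmap`-ed function equals the `vmap` of per-example Jacobians) can be checked directly in JAX. The toy per-example function and shapes below are purely illustrative and not part of the patch:

import jax
import jax.numpy as np

def f(params, x):                     # per-example function: `x` has no batch axis
  w, b = params
  return np.tanh(w @ x + b)

params = (np.ones((3, 2)), np.zeros((3,)))
xs = np.arange(8.).reshape(4, 2)      # a batch of 4 inputs

f_batched = jax.vmap(f, in_axes=(None, 0))
j_of_vmap = jax.jacobian(f_batched)(params, xs)                       # d(vmap_x(f))/dp
vmap_of_j = jax.vmap(jax.jacobian(f), in_axes=(None, 0))(params, xs)  # vmap_x(df/dp)

for a, b in zip(jax.tree_util.tree_leaves(j_of_vmap),
                jax.tree_util.tree_leaves(vmap_of_j)):
  assert np.allclose(a, b)

Since the per-example Jacobians are independent, providing `vmap_axes` lets the empirical NTK be evaluated one batch slice at a time under `vmap`, which is where the reported speedups come from.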
- ntk = nt.batch(nt.empirical_ntk_fn(apply_fn), batch_size=4, device_count=0) + ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0), + batch_size=4, device_count=0) g_dd = ntk(x_train, None, params) g_td = ntk(x_test, x_train, params) predictor = nt.predict.gradient_descent_mse(g_dd, y_train) diff --git a/neural_tangents/__init__.py b/neural_tangents/__init__.py index eedd2b04..4a841e4a 100644 --- a/neural_tangents/__init__.py +++ b/neural_tangents/__init__.py @@ -16,13 +16,12 @@ """Public Neural Tangents modules and functions.""" -__version__ = '0.3.4' +__version__ = '0.3.5' from neural_tangents import predict from neural_tangents import stax from neural_tangents.utils.batch import batch -from neural_tangents.utils.empirical import empirical_direct_ntk_fn from neural_tangents.utils.empirical import empirical_kernel_fn from neural_tangents.utils.empirical import empirical_nngp_fn from neural_tangents.utils.empirical import empirical_ntk_fn diff --git a/neural_tangents/stax.py b/neural_tangents/stax.py index 51f0bf8d..f92445d2 100644 --- a/neural_tangents/stax.py +++ b/neural_tangents/stax.py @@ -84,7 +84,7 @@ from jax.tree_util import tree_map from neural_tangents.utils import utils, dataclasses from neural_tangents.utils.kernel import Kernel -from neural_tangents.utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, PyTree, NTTree +from neural_tangents.utils.typing import AnalyticKernelFn, Axes, Get, InitFn, ApplyFn, InternalLayer, Layer, LayerKernelFn, PyTree, NTTree, Kernels import numpy as onp import scipy as osp @@ -342,14 +342,17 @@ def parallel(*layers: Layer) -> InternalLayer: sequence of outputs with the same length as the argument `layers`. """ init_fns, apply_fns, kernel_fns = zip(*layers) - init_fn_stax, apply_fn = ostax.parallel(*zip(init_fns, apply_fns)) + init_fn_stax, apply_fn_stax = ostax.parallel(*zip(init_fns, apply_fns)) def init_fn(rng, input_shape): - return list(init_fn_stax(rng, input_shape)) + return type(input_shape)(init_fn_stax(rng, input_shape)) + + def apply_fn(params, inputs, **kwargs): + return type(inputs)(apply_fn_stax(params, inputs, **kwargs)) @_requires(**_get_input_req_attr(kernel_fns, fold=op.and_)) - def kernel_fn(ks: List[Kernel], **kwargs) -> List[Kernel]: - return [f(k, **kwargs) for k, f in zip(ks, kernel_fns)] + def kernel_fn(ks: NTTree[Kernel], **kwargs) -> NTTree[Kernel]: + return type(ks)(f(k, **kwargs) for k, f in zip(ks, kernel_fns)) return init_fn, apply_fn, kernel_fn @@ -1288,7 +1291,7 @@ def FanInSum() -> InternalLayer: """ init_fn, apply_fn = ostax.FanInSum - def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel: + def kernel_fn(ks: Kernels, **kwargs) -> Kernel: ks, is_reversed = _proprocess_kernels_for_fan_in(ks) if not all([k.shape1 == ks[0].shape1 and k.shape2 == ks[0].shape2 for k in ks[1:]]): @@ -1350,7 +1353,7 @@ def FanInProd() -> InternalLayer: def apply_fn(params, inputs, **kwargs): return functools.reduce(np.multiply, inputs) - def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel: + def kernel_fn(ks: Kernels, **kwargs) -> Kernel: ks, is_reversed = _proprocess_kernels_for_fan_in(ks) if not all([k.shape1 == ks[0].shape1 and k.shape2 == ks[0].shape2 for k in ks[1:]]): @@ -1415,7 +1418,7 @@ def FanInConcat(axis: int = -1) -> InternalLayer: """ init_fn, apply_fn = ostax.FanInConcat(axis) - def kernel_fn(ks: List[Kernel], **kwargs) -> Kernel: + def kernel_fn(ks: Kernels, **kwargs) -> Kernel: ks, is_reversed = _proprocess_kernels_for_fan_in(ks) diagonal_batch = ks[0].diagonal_batch 
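As a side note on the "preserve exact input PyTrees" item, the `parallel` changes above wrap the stax outputs in `type(inputs)(...)`, so the output container now mirrors the input container (tuple in, tuple out). A minimal sketch with illustrative widths and shapes:

from jax import random
import jax.numpy as np
from neural_tangents import stax

init_fn, apply_fn, _ = stax.parallel(stax.Dense(3), stax.Dense(4))

x = (np.ones((2, 5)), np.ones((2, 7)))        # a *tuple* of two inputs
_, params = init_fn(random.PRNGKey(0), (x[0].shape, x[1].shape))

out = apply_fn(params, x)
assert isinstance(out, tuple)                 # mirrors the input container type
assert out[0].shape == (2, 3) and out[1].shape == (2, 4)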
@@ -3250,8 +3253,10 @@ def _get_input_req_attr( if fold is op.and_: if k in req and req[k] != v: if (req[k] >= 0 and v >= 0) or (req[k] < 0 and v < 0): - raise ValueError(f'`{k}` parameters must match in all ' - f'parallel branches, got {req[k]} and {v}.') + warnings.warn(f'For `kernel_fn`, `{k}` parameters must match in' + f' all parallel branches, got {req[k]} and {v}. ' + f'This WILL lead to [silent] errors if ' + f'`kernel_fn` is called.') else: warnings.warn(f'Got potentially mismatching `{k}` values in ' f'parallel branches: {req[k]} and {v}.') @@ -3892,8 +3897,7 @@ def _affine( return W_std**2 * mat + b_std**2 -def _proprocess_kernels_for_fan_in( - ks: List[Kernel]) -> Tuple[List[Kernel], bool]: +def _proprocess_kernels_for_fan_in(ks: Kernels) -> Tuple[List[Kernel], bool]: # Check diagonal requirements. if not all(k.diagonal_batch == ks[0].diagonal_batch and k.diagonal_spatial == ks[0].diagonal_spatial and @@ -3909,6 +3913,7 @@ def _proprocess_kernels_for_fan_in( # If kernels have different spatial axes order, transpose some of them. n_kernels = len(ks) n_reversed = sum(ker.is_reversed for ker in ks) + ks = list(ks) if n_reversed > n_kernels / 2: is_reversed = True diff --git a/neural_tangents/utils/empirical.py b/neural_tangents/utils/empirical.py index b439ac93..fbec432b 100644 --- a/neural_tangents/utils/empirical.py +++ b/neural_tangents/utils/empirical.py @@ -14,34 +14,87 @@ """Compute empirical NNGP and NTK; approximate functions via Taylor series. -NNGP and NTK are computed using `empirical_nngp_fn`, `empirical_ntk_fn` (or -`empirical_direct_ntk_fn`), or `empirical_kernel_fn` (for both). - -For networks with multiple outputs, in principal the empirical kernels will have -terms measuring the covariance between the outputs. Here, we ignore these -cross-terms and consider each output separately. Please raise an issue if this -feature is important to you. - -WARNING: resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` -subject to `trace_axes` and `diagonal_axes` parameters, which make certain -assumptions about the outputs `f(x)` that may only be true in the infinite width -/ infinite number of samples limit, or may not apply to your architecture. For -most precise results in the context of linearized training dynamics of a -specific finite-width network, set both `trace_axes=()` and `diagonal_axes=()` -to obtain the kernel exactly of shape `zip(f(x1).shape, f(x2).shape)`. Please -refer to individual functions' docstrings for details. +All functions in this module are applicable to any JAX functions of proper +signatures (not only those from `nt.stax`). + +NNGP and NTK are computed using `empirical_nngp_fn`, `empirical_ntk_fn`, or + `empirical_kernel_fn` (for both). The kernels have a very specific output +shape convention that may be unexpected. Further, NTK has multiple +implementations that may perform differently depending on the task. +Please read individual functions' docstrings. + +Example: + >>> from jax import random + >>> import neural_tangents as nt + >>> from neural_tangents import stax + >>> + >>> key1, key2, key3 = random.split(random.PRNGKey(1), 3) + >>> x_train = random.normal(key1, (20, 32, 32, 3)) + >>> y_train = random.uniform(key1, (20, 10)) + >>> x_test = random.normal(key2, (5, 32, 32, 3)) + >>> + >>> # A narrow CNN. 
+ >>> init_fn, f, _ = stax.serial( + >>> stax.Conv(32, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(32, (3, 3)), + >>> stax.Relu(), + >>> stax.Conv(32, (3, 3)), + >>> stax.Flatten(), + >>> stax.Dense(10) + >>> ) + >>> + >>> _, params = init_fn(key3, x_train.shape) + >>> + >>> # Default setting: reducing over logits; pass `vmap_axes=0` because the + >>> # network is iid along the batch axis, no BatchNorm. Use default + >>> # `implementation=1` since the network has few trainable parameters. + >>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(-1,), + >>> vmap_axes=0, implementation=1) + >>> + >>> # (5, 20) np.ndarray test-train NNGP/NTK + >>> nngp_test_train = kernel_fn(x_test, x_train, 'nngp', params) + >>> ntk_test_train = kernel_fn(x_test, x_train, 'ntk', params) + >>> + >>> # Full kernel: not reducing over logits. + >>> kernel_fn = nt.empirical_kernel_fn(f, trace_axes=(), vmap_axes=0) + >>> + >>> # (5, 20, 10, 10) np.ndarray test-train NNGP/NTK namedtuple. + >>> k_test_train = kernel_fn(x_test, x_train, params) + >>> + >>> # A wide FCN with lots of parameters + >>> init_fn, f, _ = stax.serial( + >>> stax.Flatten(), + >>> stax.Dense(1024), + >>> stax.Relu(), + >>> stax.Dense(1024), + >>> stax.Relu(), + >>> stax.Dense(10) + >>> ) + >>> + >>> _, params = init_fn(key3, x_train.shape) + >>> + >>> # Use implicit differentiation in NTK: `implementation=2` to reduce + >>> # memory cost, since the network has many trainable parameters. + >>> ntk_fn = nt.empirical_ntk_fn(f, vmap_axes=0, implementation=2) + >>> + >>> # (5, 5) np.ndarray test-test NTK + >>> ntk_test_train = ntk_fn(x_test, None, params) + >>> + >>> # Compute only output variances: + >>> nngp_fn = nt.empirical_nngp_fn(f, diagonal_axes=(0,)) + >>> + >>> # (20,) np.ndarray train-train diagonal NNGP + >>> nngp_train_train_diag = nngp_fn(x_train, None, params) """ import operator from typing import Union, Callable, Optional, Tuple, Dict -from jax.api import eval_shape -from jax.api import jacobian -from jax.api import jvp -from jax.api import vjp +from jax.api import eval_shape, jacobian, jvp, vjp, vmap, _std_basis, _unravel_array_into_pytree, linear_transpose import jax.numpy as np -from jax.tree_util import tree_flatten, tree_unflatten, tree_structure, tree_multimap, tree_reduce, tree_map +from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_map from neural_tangents.utils import utils -from neural_tangents.utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes +from neural_tangents.utils.typing import ApplyFn, EmpiricalKernelFn, NTTree, PyTree, Axes, VMapAxes def linearize(f: Callable[..., PyTree], @@ -69,10 +122,10 @@ def linearize(f: Callable[..., PyTree], `params`. 
""" def f_lin(p, *args, **kwargs): - dparams = tree_multimap(operator.sub, p, params) + dparams = _sub(p, params) f_params_x, proj = jvp(lambda param: f(param, *args, **kwargs), (params,), (dparams,)) - return tree_multimap(operator.add, f_params_x, proj) + return _add(f_params_x, proj) return f_lin @@ -113,11 +166,10 @@ def f_jvp(p): return val_jvp df = taylorize_r(f_jvp, params, dparams, degree, current_degree + 1) - return tree_multimap(operator.add, f(params), - tree_map(lambda x: x / (current_degree + 1), df)) + return _add(f(params), _div(df, (current_degree + 1))) def f_tayl(p, *args, **kwargs): - dparams = tree_multimap(operator.sub, p, params) + dparams = _sub(p, params) return taylorize_r(lambda param: f(param, *args, **kwargs), params, dparams, degree, 0) @@ -127,151 +179,33 @@ def f_tayl(p, *args, **kwargs): # Empirical Kernel -def empirical_implicit_ntk_fn(f: ApplyFn, - trace_axes: Axes = (-1,), - diagonal_axes: Axes = () - ) -> Callable[[NTTree[np.ndarray], - Optional[NTTree[np.ndarray]], - PyTree], - NTTree[np.ndarray]]: - - r"""Returns a function to draw a single sample the NTK of a given network `f`. - - The Neural Tangent Kernel is defined as :math:`J(X_1) J(X_2)^T` where - :math:`J` is the Jacobian :math:`df / dparams^T`. Computing the NTK directly - involves instantiating the jacobian which takes - `O(dataset_size * output_dim * parameters)` memory. It turns out it is - substantially more efficient (especially as the number of parameters grows) - to compute the NTK implicitly. - - The implicit kernel is derived by observing that: - :math:`\Theta = J(X_1) J(X_2)^T = d[J(X_1) J(X_2)^T v] / d[v^T]`, - for a vector :math:`v`. This allows the computation of the NTK to be phrased - as: :math:`a(v) = J(X_2)^T v`, which is computed by a vector-Jacobian product; - :math:`b(v) = J(X_1) a(v)` which is computed by a Jacobian-vector product; and - :math:`\Theta = d[b(v)] / d[v^T]` which is computed by taking the Jacobian of - :math:`b(v)`. - - Args: - f: - the function whose NTK we are computing. `f` should have the signature - `f(params, inputs[, rng])` and should return `np.ndarray` of outputs. - trace_axes: - output axes to trace the output kernel over, i.e. compute only the trace - of the covariance along the respective pair of axes (one pair for each - axis in `trace_axes`). This allows to save space and compute if you are - only interested in the respective trace, but also improve approximation - accuracy if you know that covariance along these pairs of axes converges - to a `constant * identity matrix` in the limit of interest (e.g. - infinite width or infinite `n_samples`). A common use case is the channel - / feature / logit axis, since activation slices along such axis are i.i.d. - and the respective covariance along the respective pair of axes indeed - converges to a constant-diagonal matrix in the infinite width or infinite - `n_samples` limit. - Also related to "contracting dimensions" in XLA terms. - (https://www.tensorflow.org/xla/operation_semantics#dotgeneral) - diagonal_axes: - output axes to diagonalize the output kernel over, i.e. compute only the - diagonal of the covariance along the respective pair of axes (one pair for - each axis in `diagonal_axes`). This allows to save space and compute, if - off-diagonal values along these axes are not needed, but also improve - approximation accuracy if their limiting value is known theoretically, - e.g. if they vanish in the limit of interest (e.g. infinite - width or infinite `n_samples`). 
If you further know that on-diagonal - values converge to the same constant in your limit of interest, you should - specify these axes in `trace_axes` instead, to save even more compute and - gain even more accuracy. A common use case is computing the variance - (instead of covariance) along certain axes. - Also related to "batch dimensions" in XLA terms. - (https://www.tensorflow.org/xla/operation_semantics#dotgeneral) - - Returns: - A function `ntk_fn` that computes the empirical ntk. - """ - - def _flatten(f): - return lambda p: tree_flatten(f(p))[0] - - def ntk_fn(x1: NTTree[np.ndarray], - x2: Optional[NTTree[np.ndarray]], - params: PyTree, - **apply_fn_kwargs) -> np.ndarray: - """Computes a single sample of the empirical NTK (implicit differentiation). - - Args: - x1: - first batch of inputs. - x2: - second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a - matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. - params: - A `PyTree` of parameters about which we would like to compute the - neural tangent kernel. - **apply_fn_kwargs: - keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split - into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` - function which will be passed to `apply_fn`. In particular, the rng key - in `apply_fn_kwargs`, will be split into two different (if `x1 != x2`) - or same (if `x1 == x2`) rng keys. See the `_read_key` function for more - details. - - Returns: - A single sample of the empirical NTK. The shape of the kernel is "almost" - `zip(f(x1).shape, f(x2).shape)` except for: - 1) `trace_axes` are absent as they are contracted over. - 2) `diagonal_axes` are present only once. - All other axes are present twice. - """ - kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) - f1 = _flatten(_get_f_params(f, x1, **kwargs1)) - f2 = (f1 if utils.all_none(x2) else - _flatten(_get_f_params(f, x2, **kwargs2))) - - def delta_vjp_jvp(delta): - def delta_vjp(delta): - return vjp(f2, params)[1](delta) - return jvp(f1, (params,), delta_vjp(delta))[1] - - # Since we are taking the Jacobian of a linear function (which does not - # depend on its coefficients), it is more efficient to substitute fx_dummy - # for the outputs of the network. fx_dummy has the same shape as the output - # of the network on a single piece of input data. - fx2_struct = eval_shape(f2, params) - - @utils.nt_tree_fn() - def dummy_output(fx_struct): - return np.ones(fx_struct.shape, fx_struct.dtype) - fx_dummy = dummy_output(fx2_struct) - - ntk = jacobian(delta_vjp_jvp)(fx_dummy) - if utils.is_list_or_tuple(fx_dummy): - fx_treedef = tree_structure(eval_shape( - _get_f_params(f, x1, **kwargs1), - params)) - ntk = [ntk[i][i] for i in range(len(fx_dummy))] - ntk = tree_unflatten(fx_treedef, ntk) - - return _trace_and_diagonal(ntk, trace_axes, diagonal_axes) - return ntk_fn - - -def empirical_direct_ntk_fn(f: ApplyFn, - trace_axes: Axes = (-1,), - diagonal_axes: Axes = () - ) -> Callable[[NTTree[np.ndarray], - Optional[NTTree[np.ndarray]], - PyTree], - NTTree[np.ndarray]]: - """Returns a function to draw a single sample the NTK of a given network `f`. - - The Neural Tangent Kernel is defined as :math:`J(X_1) J(X_2)^T` where - :math:`J` is the Jacobian :math:`df/dparams`. This function instantiates the - Jacobians directly and computes their outer product. 
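For users of the removed public functions, the mapping onto the fused API introduced below is direct: the old direct method becomes `implementation=1` and the old implicit method becomes `implementation=2`. A hedged migration sketch, with an illustrative network and shapes:

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, f, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(2))
x = random.normal(random.PRNGKey(0), (5, 7))
_, params = init_fn(random.PRNGKey(1), x.shape)

# Was: nt.empirical_direct_ntk_fn(f)
ntk_fn_direct = nt.empirical_ntk_fn(f, implementation=1)
# Was: nt.empirical_implicit_ntk_fn(f) (the old `nt.empirical_ntk_fn` default)
ntk_fn_implicit = nt.empirical_ntk_fn(f, implementation=2)

theta_1 = ntk_fn_direct(x, None, params)      # (5, 5) with default trace_axes=(-1,)
theta_2 = ntk_fn_implicit(x, None, params)    # should agree up to numerical precision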
+def empirical_kernel_fn( + f: ApplyFn, + trace_axes: Axes = (-1,), + diagonal_axes: Axes = (), + vmap_axes: VMapAxes = None, + implementation: int = 1 +) -> EmpiricalKernelFn: + r"""Returns a function that computes single draws from NNGP and NT kernels. + + WARNING: resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. + + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + the empirical kernels will have terms measuring the covariance between the + outputs. Here, we ignore these cross-terms and consider each output + separately. Please raise an issue if this feature is important to you. Args: f: the function whose NTK we are computing. `f` should have the signature - `f(params, inputs[, rng])` and should return an `np.ndarray` outputs. + `f(params, inputs, **kwargs)` and should return an `np.ndarray` outputs. trace_axes: output axes to trace the output kernel over, i.e. compute only the trace of the covariance along the respective pair of axes (one pair for each @@ -300,30 +234,80 @@ def empirical_direct_ntk_fn(f: ApplyFn, (instead of covariance) along certain axes. Also related to "batch dimensions" in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral) + vmap_axes: + applicable only to NTK. + + A triple of `(in_axes, out_axes, kwargs_axes)` + passed to `vmap` to evaluate the empirical NTK in parallel ove these axes. + Precisely, providing this argument implies that `f(params, x, **kwargs)` + equals to a concatenation along `out_axes` of `f` applied to slices of + `x` and `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it + certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes` + over `x` (along `in_axes`) and those arguments in `**kwargs` that are + present in `kwargs_axes.keys()` (along `kwargs_axes.values()`). + + For example if `_, f, _ = nt.stax.Aggregate()`, `f` is called via + `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns + `pattern`, and outputs of `f` are all batched along the leading `0` + dimension, and each output `f(params, x, pattern=pattern)[i]` only + depends on the inputs `x[i]` and `pattern[i]`. In this case, we can + pass `vmap_axes=(0, 0, dict(pattern=0)` to specify along which dimensions + inputs, outputs, and keyword arguments are batched respectively. + + This allows us to evaluate Jacobians much more + efficiently. If `vmap_axes` is not a triple, it is interpreted as + `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example a very + common use case is `vmap_axes=0` for a neural network with leading (`0`) + batch dimension, both for inputs and outputs, and no interactions between + different elements of the batch (e.g. no BatchNorm, and, in the case of + `nt.stax`, also no Dropout). However, if there is interaction between + batch elements or no concept of a batch axis at all, `vmap_axes` must be + set to `None`, to avoid wrong (and potentially silent) results. + implementation: + applicable only to NTK. + + `1` or `2`. 
+ + `1` directly instantiates Jacobians and computes their outer + product. + + `2` uses implicit differentiation to avoid instantiating whole + Jacobians at once. The implicit kernel is derived by observing that: + :math:`\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)`, + i.e. a linear function :math:`[J(X_1) J(X_2)^T]` applied to an identity + matrix :math:`I`. This allows the computation of the NTK to be + phrased as: :math:`a(v) = J(X_2)^T v`, which is computed by a + vector-Jacobian product; :math:`b(v) = J(X_1) a(v)` which is computed by + a Jacobian-vector product; and :math:`\Theta = [b(v)] / d[v^T](I)` which + is computed via a `vmap` of :math:`b(v)` over columns of the identity + matrix :math:`I`. + + It is best to benchmark each method on your specific task. We suggest + using `1` unless you get OOMs due to large number of trainable parameters, + otherwise - `2`. Returns: - A function `ntk_fn` that computes the empirical ntk. + A function to draw a single sample the NNGP and NTK empirical kernels of a + given network `f`. """ - @utils.nt_tree_fn(tree_structure_argnum=0) - def sum_and_contract(fx, j1, j2): - ndim = fx.ndim - size = utils.size_at(fx, trace_axes) - - _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim) - _trace_axes = utils.canonicalize_axis(trace_axes, ndim) - - def contract(x, y): - param_axes = list(range(x.ndim))[ndim:] - contract_axes = _trace_axes + param_axes - return utils.dot_general(x, y, contract_axes, _diagonal_axes) / size + kwargs = dict(f=f, + trace_axes=trace_axes, + diagonal_axes=diagonal_axes) - return tree_reduce(operator.add, tree_multimap(contract, j1, j2)) + kernel_fns = { + 'nngp': empirical_nngp_fn(**kwargs), + 'ntk': empirical_ntk_fn(**kwargs, + vmap_axes=vmap_axes, + implementation=implementation) + } - def ntk_fn(x1: NTTree[np.ndarray], - x2: Optional[NTTree[np.ndarray]], - params: PyTree, - **apply_fn_kwargs) -> np.ndarray: - """Computes a single sample of the empirical NTK (jacobian outer product). + @utils.get_namedtuple('EmpiricalKernel') + def kernel_fn(x1: NTTree[np.ndarray], + x2: Optional[NTTree[np.ndarray]], + get: Union[None, str, Tuple[str, ...]], + params: PyTree, + **apply_fn_kwargs) -> NTTree[Dict[str, np.ndarray]]: + """Computes a single sample of the empirical kernel of type `get`. Args: x1: @@ -331,6 +315,9 @@ def ntk_fn(x1: NTTree[np.ndarray], x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. + get: + type of the empirical kernel. `get=None` means `get=("nngp", "ntk")`. + Can be a string (`"nngp"`) or a tuple of strings (`("ntk", "nngp")`). params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. @@ -343,33 +330,26 @@ def ntk_fn(x1: NTTree[np.ndarray], details. Returns: - A single sample of the empirical NTK. The shape of the kernel is "almost" + A single sample of the empirical kernel. The shape is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. 
- """ - - kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) - f1 = _get_f_params(f, x1, **kwargs1) - jac_fn1 = jacobian(f1) - j1 = jac_fn1(params) - if x2 is None: - j2 = j1 - else: - f2 = _get_f_params(f, x2, **kwargs2) - jac_fn2 = jacobian(f2) - j2 = jac_fn2(params) - - fx1 = eval_shape(f1, params) - ntk = sum_and_contract(fx1, j1, j2) - return ntk + If `get` is a string, returns the requested `np.ndarray`. If `get` is a + tuple, returns an `EmpiricalKernel` namedtuple containing the + requested information. + """ + if get is None: + get = ('nngp', 'ntk') - return ntk_fn + out_dict = {g: kernel_fns[g](x1, x2, params, **apply_fn_kwargs) + for g in get} + out_dict = _dict_of_tree_to_tree_of_dict(out_dict, get) + return out_dict -empirical_ntk_fn = empirical_implicit_ntk_fn + return kernel_fn def empirical_nngp_fn(f: ApplyFn, @@ -381,6 +361,23 @@ def empirical_nngp_fn(f: ApplyFn, NTTree[np.ndarray]]: """Returns a function to draw a single sample the NNGP of a given network `f`. + The Neural Network Gaussian Process (NNGP) kernel is defined as + :math:`f(X_1) f(X_2)^T`, i.e. the outer product of the function outputs. + + WARNING: resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. + + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + the empirical kernels will have terms measuring the covariance between the + outputs. Here, we ignore these cross-terms and consider each output + separately. Please raise an issue if this feature is important to you. + Args: f: the function whose NNGP we are computing. `f` should have the signature @@ -456,10 +453,7 @@ def output(x, **kwargs): kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) out1 = output(x1, **kwargs1) - if utils.all_none(x2): - out2 = out1 - else: - out2 = output(x2, **kwargs2) + out2 = output(x2, **kwargs2) if not utils.all_none(x2) else out1 @utils.nt_tree_fn() def contract(out1, out2): @@ -471,11 +465,40 @@ def contract(out1, out2): return nngp_fn -def empirical_kernel_fn(f: ApplyFn, - trace_axes: Axes = (-1,), - diagonal_axes: Axes = () - ) -> EmpiricalKernelFn: - """Returns a function that computes single draws from NNGP and NT kernels. +def empirical_ntk_fn(f: ApplyFn, + trace_axes: Axes = (-1,), + diagonal_axes: Axes = (), + vmap_axes: VMapAxes = None, + implementation: int = 1 + ) -> Callable[[NTTree[np.ndarray], + Optional[NTTree[np.ndarray]], + PyTree], + NTTree[np.ndarray]]: + r"""Returns a function to draw a single sample the NTK of a given network `f`. + + The Neural Tangent Kernel is defined as :math:`J(X_1) J(X_2)^T` where + :math:`J` is the Jacobian :math:`df/dparams` of shape + `full_output_shape + params.shape`. + + For best performance: + 1) pass `x2=None` if `x1 == x2; + 2) prefer square batches (i.e `x1.shape == x2.shape`); + 3) make sure to set `vmap_axes` correctly. + 4) try different `implementation` values. 
+ + WARNING: Resulting kernel shape is *nearly* `zip(f(x1).shape, f(x2).shape)` + subject to `trace_axes` and `diagonal_axes` parameters, which make certain + assumptions about the outputs `f(x)` that may only be true in the infinite + width / infinite number of samples limit, or may not apply to your + architecture. For most precise results in the context of linearized training + dynamics of a specific finite-width network, set both `trace_axes=()` and + `diagonal_axes=()` to obtain the kernel exactly of shape + `zip(f(x1).shape, f(x2).shape)`. + + For networks with multiple (i.e. lists, tuples, PyTrees) outputs, in principal + the empirical kernels will have terms measuring the covariance between the + outputs. Here, we ignore these cross-terms and consider each output + separately. Please raise an issue if this feature is important to you. Args: f: @@ -509,23 +532,190 @@ def empirical_kernel_fn(f: ApplyFn, (instead of covariance) along certain axes. Also related to "batch dimensions" in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral) + vmap_axes: + A triple of `(in_axes, out_axes, kwargs_axes)` + passed to `vmap` to evaluate the empirical NTK in parallel ove these axes. + Precisely, providing this argument implies that `f(params, x, **kwargs)` + equals to a concatenation along `out_axes` of `f` applied to slices of + `x` and `**kwargs` along `in_axes` and `kwargs_axes`. In other words, it + certifies that `f` can be evaluated as a `vmap` with `out_axes=out_axes` + over `x` (along `in_axes`) and those arguments in `**kwargs` that are + present in `kwargs_axes.keys()` (along `kwargs_axes.values()`). + + For example if `_, f, _ = nt.stax.Aggregate()`, `f` is called via + `f(params, x, pattern=pattern)`. By default, inputs `x`, patterns + `pattern`, and outputs of `f` are all batched along the leading `0` + dimension, and each output `f(params, x, pattern=pattern)[i]` only + depends on the inputs `x[i]` and `pattern[i]`. In this case, we can + pass `vmap_axes=(0, 0, dict(pattern=0)` to specify along which dimensions + inputs, outputs, and keyword arguments are batched respectively. + + This allows us to evaluate Jacobians much more + efficiently. If `vmap_axes` is not a triple, it is interpreted as + `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example a very + common use case is `vmap_axes=0` for a neural network with leading (`0`) + batch dimension, both for inputs and outputs, and no interactions between + different elements of the batch (e.g. no BatchNorm, and, in the case of + `nt.stax`, also no Dropout). However, if there is interaction between + batch elements or no concept of a batch axis at all, `vmap_axes` must be + set to `None`, to avoid wrong (and potentially silent) results. + implementation: + `1` or `2`. + + `1` directly instantiates Jacobians and computes their outer + product. + + `2` uses implicit differentiation to avoid instantiating whole + Jacobians at once. The implicit kernel is derived by observing that: + :math:`\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)`, + i.e. a linear function :math:`[J(X_1) J(X_2)^T]` applied to an identity + matrix :math:`I`. This allows the computation of the NTK to be + phrased as: :math:`a(v) = J(X_2)^T v`, which is computed by a + vector-Jacobian product; :math:`b(v) = J(X_1) a(v)` which is computed by + a Jacobian-vector product; and :math:`\Theta = [b(v)] / d[v^T](I)` which + is computed via a `vmap` of :math:`b(v)` over columns of the identity + matrix :math:`I`. 
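For intuition, the identity above (evaluating :math:`J(X_1) J(X_2)^T` as a linear map applied to an identity matrix, without materializing the Jacobians) can be checked on a tiny toy model. The sketch below uses raw `jax` rather than the library helpers, and all names and shapes are illustrative:

import jax
import jax.numpy as np

def f(params, x):
  w1, w2 = params
  return np.tanh(x @ w1) @ w2                    # outputs of shape (batch, 2)

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = (jax.random.normal(k1, (5, 16)), jax.random.normal(k2, (16, 2)))
x1 = jax.random.normal(k3, (3, 5))
x2 = jax.random.normal(k4, (4, 5))

f1 = lambda p: f(p, x1)                          # f(x1): shape (3, 2)
f2 = lambda p: f(p, x2)                          # f(x2): shape (4, 2)

def jjt_v(v):                                    # v has the shape of f(x2)
  (delta_params,) = jax.vjp(f2, params)[1](v)    # a(v) = J(x2)^T v  (vector-Jacobian product)
  return jax.jvp(f1, (params,), (delta_params,))[1]   # b(v) = J(x1) a(v)  (Jacobian-vector product)

# Apply b(.) to a basis of one-hot f(x2)-shaped arrays to obtain the full kernel.
eye = np.eye(4 * 2).reshape(4 * 2, 4, 2)
theta_implicit = jax.vmap(jjt_v)(eye)            # (8, 3, 2)
theta_implicit = np.moveaxis(theta_implicit.reshape(4, 2, 3, 2), (0, 1), (2, 3))

# Reference: explicit Jacobian contraction (the `implementation=1` flavor).
j1, j2 = jax.jacobian(f1)(params), jax.jacobian(f2)(params)
theta_direct = sum(np.tensordot(a, b, ((2, 3), (2, 3)))
                   for a, b in zip(jax.tree_util.tree_leaves(j1),
                                   jax.tree_util.tree_leaves(j2)))
assert np.allclose(theta_implicit, theta_direct, rtol=1e-3, atol=1e-3)

With `vmap_axes` additionally set, the patch further `vmap`s this computation over the batch axes of `x1` and `x2`.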
+ + It is best to benchmark each method on your specific task. We suggest + using `1` unless you get OOMs due to large number of trainable parameters, + otherwise - `2`. Returns: - A function to draw a single sample the NNGP and NTK empirical kernels of a - given network `f`. + A function `ntk_fn` that computes the empirical ntk. """ - kernel_fns = { - 'nngp': empirical_nngp_fn(f, trace_axes, diagonal_axes), - 'ntk': empirical_ntk_fn(f, trace_axes, diagonal_axes) - } + kwargs = dict(f=f, + trace_axes=trace_axes, + diagonal_axes=diagonal_axes, + vmap_axes=vmap_axes) - @utils.get_namedtuple('EmpiricalKernel') - def kernel_fn(x1: NTTree[np.ndarray], - x2: Optional[NTTree[np.ndarray]], - get: Union[None, str, Tuple[str, ...]], - params: PyTree, - **apply_fn_kwargs) -> NTTree[Dict[str, np.ndarray]]: - """Computes a single sample of the empirical kernel of type `get`. + if implementation == 1: + return _empirical_direct_ntk_fn(**kwargs) + + if implementation == 2: + return _empirical_implicit_ntk_fn(**kwargs) + + raise ValueError(implementation) + + +def _empirical_implicit_ntk_fn(f: ApplyFn, + trace_axes: Axes = (-1,), + diagonal_axes: Axes = (), + vmap_axes: VMapAxes = None + ) -> Callable[[NTTree[np.ndarray], + Optional[NTTree[np.ndarray]], + PyTree], + NTTree[np.ndarray]]: + """Compute NTK implicitly without instantiating full Jacobians.""" + + def ntk_fn(x1: NTTree[np.ndarray], + x2: Optional[NTTree[np.ndarray]], + params: PyTree, + **apply_fn_kwargs) -> np.ndarray: + """Computes a single sample of the empirical NTK (implicit differentiation). + + Args: + x1: + first batch of inputs. + x2: + second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a + matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. + params: + A `PyTree` of parameters about which we would like to compute the + neural tangent kernel. + **apply_fn_kwargs: + keyword arguments passed to `apply_fn`. `apply_fn_kwargs` will be split + into `apply_fn_kwargs1` and `apply_fn_kwargs2` by the `split_kwargs` + function which will be passed to `apply_fn`. In particular, the rng key + in `apply_fn_kwargs`, will be split into two different (if `x1 != x2`) + or same (if `x1 == x2`) rng keys. See the `_read_key` function for more + details. + + Returns: + A single sample of the empirical NTK. The shape of the kernel is "almost" + `zip(f(x1).shape, f(x2).shape)` except for: + 1) `trace_axes` are absent as they are contracted over. + 2) `diagonal_axes` are present only once. + All other axes are present twice. 
+ """ + kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) + fx1 = eval_shape(f, params, x1, **kwargs1) + x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1) + + keys = apply_fn_kwargs.keys() + args1 = (kwargs1[k] for k in keys) + args2 = (kwargs1[k] if k in kw_axes and kwargs2[k] is None else kwargs2[k] + for k in keys) + + def get_ntk(x1, x2, *args): + args1, args2 = args[:len(args) // 2], args[len(args) // 2 :] + _kwargs1 = {k: v for k, v in zip(keys, args1)} + _kwargs2 = {k: v for k, v in zip(keys, args2)} + + f1 = _get_f_params(f, x1, x_axis, fx_axis, kw_axes, **_kwargs1) + f2 = f1 if utils.all_none(x2) else _get_f_params( + f, x2, x_axis, fx_axis, kw_axes, **_kwargs2) + + def delta_vjp_jvp(delta): + def delta_vjp(delta): + return vjp(f2, params)[1](delta) + return jvp(f1, (params,), delta_vjp(delta))[1] + + fx1, fx2 = eval_shape(f1, params), eval_shape(f2, params) + eye = _std_basis(fx1) + ntk = vmap(linear_transpose(delta_vjp_jvp, fx2))(eye) + ntk = tree_map(lambda fx12: _unravel_array_into_pytree(fx1, 0, fx12), ntk) + ntk = _diagonal(ntk, fx1) + return ntk + + if x_axis is not None or kw_axes: + x2 = x1 if utils.all_none(x2) else x2 + + kw_in_axes = [kw_axes[k] if k in kw_axes else None for k in keys] + in_axes1 = [x_axis, None] + kw_in_axes + [None] * len(kw_in_axes) + in_axes2 = [None, x_axis] + [None] * len(kw_in_axes) + kw_in_axes + + get_ntk = vmap(vmap(get_ntk, + in_axes1, + fx_axis), + in_axes2, + _add(fx_axis, _ndim(fx1))) + + return _trace_and_diagonal(get_ntk(x1, x2, *args1, *args2), + trace_axes, diagonal_axes) + + return ntk_fn + + +def _empirical_direct_ntk_fn(f: ApplyFn, + trace_axes: Axes = (-1,), + diagonal_axes: Axes = (), + vmap_axes: VMapAxes = None + ) -> Callable[[NTTree[np.ndarray], + Optional[NTTree[np.ndarray]], + PyTree], + NTTree[np.ndarray]]: + """Compute NTK by directly instantiating Jacobians and contracting.""" + + @utils.nt_tree_fn(tree_structure_argnum=0) + def sum_and_contract(fx, j1, j2): + ndim = fx.ndim + size = utils.size_at(fx, trace_axes) + + _diagonal_axes = utils.canonicalize_axis(diagonal_axes, ndim) + _trace_axes = utils.canonicalize_axis(trace_axes, ndim) + + def contract(x, y): + param_axes = list(range(x.ndim))[ndim:] + contract_axes = _trace_axes + param_axes + return utils.dot_general(x, y, contract_axes, _diagonal_axes) / size + + return tree_reduce(operator.add, tree_multimap(contract, j1, j2)) + + def ntk_fn(x1: NTTree[np.ndarray], + x2: Optional[NTTree[np.ndarray]], + params: PyTree, + **apply_fn_kwargs) -> np.ndarray: + """Computes a single sample of the empirical NTK (jacobian outer product). Args: x1: @@ -533,9 +723,6 @@ def kernel_fn(x1: NTTree[np.ndarray], x2: second batch of inputs. `x2=None` means `x2=x1`. `f(x2)` must have a matching shape with `f(x1)` on `trace_axes` and `diagonal_axes`. - get: - type of the empirical kernel. `get=None` means `get=("nngp", "ntk")`. - Can be a string (`"nngp"`) or a tuple of strings (`("ntk", "nngp")`). params: A `PyTree` of parameters about which we would like to compute the neural tangent kernel. @@ -548,26 +735,35 @@ def kernel_fn(x1: NTTree[np.ndarray], details. Returns: - A single sample of the empirical kernel. The shape is "almost" + A single sample of the empirical NTK. The shape of the kernel is "almost" `zip(f(x1).shape, f(x2).shape)` except for: 1) `trace_axes` are absent as they are contracted over. 2) `diagonal_axes` are present only once. All other axes are present twice. - - If `get` is a string, returns the requested `np.ndarray`. 
If `get` is a - tuple, returns an `EmpiricalKernel` namedtuple containing the - requested information. """ - if get is None: - get = ('nngp', 'ntk') + kwargs1, kwargs2 = utils.split_kwargs(apply_fn_kwargs, x1, x2) + fx1 = eval_shape(f, params, x1, **kwargs1) + x_axis, fx_axis, kw_axes = _canonicalize_axes(vmap_axes, x1, fx1, **kwargs1) - out_dict = {g: kernel_fns[g](x1, x2, params, **apply_fn_kwargs) - for g in get} - out_dict = _dict_of_tree_to_tree_of_dict(out_dict, get) + keys = apply_fn_kwargs.keys() + args1, args2 = (kwargs1[k] for k in keys), (kwargs2[k] for k in keys) - return out_dict + def j_fn(x, *args): + _kwargs = {k: v for k, v in zip(keys, args)} + fx = _get_f_params(f, x, x_axis, fx_axis, kw_axes, **_kwargs) + jx = jacobian(fx)(params) + return jx - return kernel_fn + if x_axis is not None or kw_axes: + in_axes = [x_axis] + [kw_axes[k] if k in kw_axes else None for k in keys] + j_fn = vmap(j_fn, in_axes=in_axes, out_axes=fx_axis) + + j1 = j_fn(x1, *args1) + j2 = j_fn(x2, *args2) if not utils.all_none(x2) else j1 + ntk = sum_and_contract(fx1, j1, j2) + return ntk + + return ntk_fn # INTERNAL UTILITIES @@ -634,12 +830,81 @@ def _dict_of_tree_to_tree_of_dict(out_dict, get): *[out_dict[g] for g in get]) -def _get_f_params(f, x, **apply_fn_kwargs): +def _get_f_params(f, x, x_axis, fx_axis, kw_axes, **apply_fn_kwargs): + x = _expand_dims(x, x_axis) + + apply_fn_kwargs = { + k: _expand_dims(v, kw_axes[k]) if k in kw_axes else v + for k, v in apply_fn_kwargs.items() + } + def _f(p): - out = f(p, x, **apply_fn_kwargs) - out = utils.get_masked_array(out) + fx = f(p, x, **apply_fn_kwargs) + fx = utils.get_masked_array(fx) # TODO(romann): normalize properly if output is masked. get_masked = utils.nt_tree_fn()(lambda o: o.masked_value) - return get_masked(out) + fx = get_masked(fx) + return _squeeze(fx, fx_axis) + return _f + + +def _expand_dims(x, axis): + if axis is None or x is None: + return x + return tree_multimap(np.expand_dims, x, axis) + + +def _add(x, y): + if x is None or y is None: + return None + return tree_multimap(operator.add, x, y) + + +def _sub(x, y): + return tree_multimap(operator.sub, x, y) + + +def _div(x, y): + return tree_map(lambda x: x / y, x) + + +def _squeeze(x, axis, take=False): + if axis is None: + return x + if take: + return tree_multimap(lambda x, axis: np.take(x, 0, axis), x, axis) + return tree_multimap(np.squeeze, x, axis) + + +@utils.nt_tree_fn() +def _ndim(x): + return x.ndim + + +def _mod(x, y): + return tree_multimap(operator.mod, x, y) + + +def _diagonal(ntk, fx): + ntk_flat, _ = tree_flatten(ntk) + fx_flat, fx_tree = tree_flatten(fx) + n = len(fx_flat) + diag = [ntk_flat[i * (n + 1)] for i in range(n)] + return tree_unflatten(fx_tree, diag) + + +def _canonicalize_axes(vmap_axes: Optional[VMapAxes], + x: NTTree[np.ndarray], + fx: NTTree[np.ndarray], + **kwargs) -> VMapAxes: + if isinstance(vmap_axes, tuple) and len(vmap_axes) == 3: + x_axis, fx_axis, kw_axes = vmap_axes + else: + x_axis, fx_axis, kw_axes = vmap_axes, vmap_axes, {} + + x_axis = _mod(x_axis, _ndim(x)) if x_axis is not None else None + fx_axis = _mod(fx_axis, _ndim(fx)) if fx_axis is not None else None + kw_axes = _mod(kw_axes, {k: _ndim(kwargs[k]) for k in kw_axes}) + return x_axis, fx_axis, kw_axes diff --git a/neural_tangents/utils/monte_carlo.py b/neural_tangents/utils/monte_carlo.py index 322d16bd..dfe80f67 100644 --- a/neural_tangents/utils/monte_carlo.py +++ b/neural_tangents/utils/monte_carlo.py @@ -37,7 +37,7 @@ from neural_tangents.utils import batch from 
neural_tangents.utils import empirical from neural_tangents.utils import utils -from neural_tangents.utils.typing import PRNGKey, InitFn, ApplyFn, MonteCarloKernelFn, Axes, Get, EmpiricalKernelFn, PyTree, NTTree +from neural_tangents.utils.typing import PRNGKey, InitFn, ApplyFn, MonteCarloKernelFn, Axes, Get, EmpiricalKernelFn, PyTree, NTTree, VMapAxes def _sample_once_kernel_fn(kernel_fn: EmpiricalKernelFn, @@ -121,9 +121,11 @@ def monte_carlo_kernel_fn( device_count: int = -1, store_on_device: bool = True, trace_axes: Axes = (-1,), - diagonal_axes: Axes = () + diagonal_axes: Axes = (), + vmap_axes: VMapAxes = None, + implementation: int = 1 ) -> MonteCarloKernelFn: - """Return a Monte Carlo sampler of NTK and NNGP kernels of a given function. + r"""Return a Monte Carlo sampler of NTK and NNGP kernels of a given function. Note that the returned function is appropriately batched / parallelized. You don't need to apply the `nt.batch` or `jax.jit` decorators to it. Further, @@ -187,6 +189,43 @@ def monte_carlo_kernel_fn( (instead of covariance) along certain axes. Also related to "batch dimensions" in XLA terms. (https://www.tensorflow.org/xla/operation_semantics#dotgeneral) + vmap_axes: + applicable only to NTK. A triple of `(in_axes, out_axes, kwargs_axes)` + passed to `vmap` to evaluate the empirical NTK in parallel ove these axes. + Precisely, providing this argument implies that `f(params, x, **kwargs)` + equals to a concatenation along `out_axes` of `f` applied to slices of + `x` and `**kwargs` along `in_axes` and `kwargs_axes`, i.e. `f` can be + evaluated as a `vmap`. This allows to evaluate Jacobians much more + efficiently. If `vmap_axes` is not a triple, it is interpreted as + `in_axes = out_axes = vmap_axes, kwargs_axes = {}`. For example a very + common usecase is `vmap_axes=0` for a neural network with leading (`0`) + batch dimension, both for inputs and outputs, and no interactions between + different elements of the batch (e.g. no BatchNorm, and, in the case of + `nt.stax`, also no Dropout). However, if there is interaction between + batch elements or no concept of a batch axis at all, `vmap_axes` must be + set to `None`, to avoid wrong (and potentially silent) results. + implementation: + applicable only to NTK. + + `1` or `2`. + + `1` directly instantiates Jacobians and computes their outer + product. + + `2` uses implicit differentiation to avoid instantiating whole + Jacobians at once. The implicit kernel is derived by observing that: + :math:`\Theta = J(X_1) J(X_2)^T = [J(X_1) J(X_2)^T](I)`, + i.e. a linear function :math:`[J(X_1) J(X_2)^T]` applied to an identity + matrix :math:`I`. This allows the computation of the NTK to be + phrased as: :math:`a(v) = J(X_2)^T v`, which is computed by a + vector-Jacobian product; :math:`b(v) = J(X_1) a(v)` which is computed by + a Jacobian-vector product; and :math:`\Theta = [b(v)] / d[v^T](I)` which + is computed via a `vmap` of :math:`b(v)` over columns of the identity + matrix :math:`I`. + + It is best to benchmark each method on your specific task. We suggest + using `1` unless you get OOMs due to large number of trainable parameters, + otherwise - `2`. 
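A hedged usage sketch of the new `vmap_axes` and `implementation` arguments documented above (layer widths, sample count, and shapes are illustrative):

from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(10))
x1 = random.normal(random.PRNGKey(1), (6, 8))
x2 = random.normal(random.PRNGKey(2), (4, 8))

kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, random.PRNGKey(3), 32,
                                     vmap_axes=0, implementation=2)
nngp = kernel_fn(x1, x2, 'nngp')   # (6, 4) Monte Carlo NNGP estimate
ntk = kernel_fn(x1, x2, 'ntk')     # (6, 4) Monte Carlo NTK estimate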
Returns: If `n_samples` is an integer, returns a function of signature @@ -231,7 +270,9 @@ def monte_carlo_kernel_fn( """ kernel_fn = empirical.empirical_kernel_fn(apply_fn, trace_axes=trace_axes, - diagonal_axes=diagonal_axes) + diagonal_axes=diagonal_axes, + vmap_axes=vmap_axes, + implementation=implementation) kernel_fn_sample_once = _sample_once_kernel_fn(kernel_fn, init_fn, diff --git a/neural_tangents/utils/typing.py b/neural_tangents/utils/typing.py index 5697951c..56ec6402 100644 --- a/neural_tangents/utils/typing.py +++ b/neural_tangents/utils/typing.py @@ -14,8 +14,7 @@ """Common Type Definitions.""" -from typing import Tuple, Callable, Union, List, Any, Optional, Sequence, \ - Generator, TypeVar +from typing import Tuple, Callable, Union, List, Any, Optional, Sequence, Generator, TypeVar, Dict import jax.numpy as np from neural_tangents.utils.kernel import Kernel @@ -50,7 +49,7 @@ array or kernel objects. """ T = TypeVar('T') -NTTree = Union[List[T], Tuple[T], T] +NTTree = Union[List[T], Tuple[T, ...], T] Shapes = NTTree[Tuple[int, ...]] @@ -131,3 +130,14 @@ Layer = Tuple[InitFn, ApplyFn, AnalyticKernelFn] + + +"""A type alias for kernel inputs/outputs of `FanOut`, `FanInSum`, etc. +""" +Kernels = Union[List[Kernel], Tuple[Kernel, ...]] + + +"""Specifies `(input, output, kwargs)` axes for `vmap` in empirical NTK. +""" +_VMapAxis = Optional[NTTree[int]] +VMapAxes = Tuple[_VMapAxis, _VMapAxis, Dict[str, _VMapAxis]] diff --git a/notebooks/function_space_linearization.ipynb b/notebooks/function_space_linearization.ipynb index 428c49b7..826f92da 100644 --- a/notebooks/function_space_linearization.ipynb +++ b/notebooks/function_space_linearization.ipynb @@ -312,7 +312,8 @@ }, "outputs": [], "source": [ - "ntk = nt.batch(nt.empirical_ntk_fn(f), batch_size=16, device_count=0)\n", + "ntk = nt.batch(nt.empirical_ntk_fn(f, vmap_axes=0),\n", + " batch_size=16, device_count=0)\n", "\n", "g_dd = ntk(train['image'], None, params)\n", "g_td = ntk(test['image'], train['image'], params)" diff --git a/tests/batch_test.py b/tests/batch_test.py index 0446cd96..5c53cc56 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -45,7 +45,7 @@ NETWORK = [FLAT, FLAT, FLAT, FLAT, INTERMEDIATE_CONV] OUTPUT_LOGITS = [1, 2, 3] CONVOLUTION_CHANNELS = 4 -WIDTH = 8 +WIDTH = 4 RTOL = 1e-2 test_utils.update_test_tolerance(f64_tol=5e-5) @@ -437,7 +437,7 @@ def broadcast(arg): def test_parallel_in_out(self, same_inputs): test_utils.stub_out_pmap(batch, 2) rng = random.PRNGKey(0) - input_key1, input_key2, mc_key = random.split(rng, 3) + input_key1, input_key2 = random.split(rng, 2) x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 4, 1)) x2_1, x2_2, x2_3 = random.normal(input_key2, (3, 8, 1)) @@ -503,7 +503,7 @@ def test_parallel_in_out_empirical(self, same_inputs): def net(N_out): return stax.parallel(stax.Dense(N_out), stax.parallel(stax.Dense(N_out + 1), - stax.Dense(N_out + 2))) + stax.Dense(N_out + 2))) # Check NNGP. 
init_fn, apply_fn, _ = net(WIDTH) diff --git a/tests/empirical_test.py b/tests/empirical_test.py index 40a3c249..5e749212 100644 --- a/tests/empirical_test.py +++ b/tests/empirical_test.py @@ -18,7 +18,7 @@ import operator from absl.testing import absltest from jax import test_util as jtu -from jax.api import jit, tree_multimap +from jax.api import jit, tree_map, tree_multimap from jax.config import config import jax.numpy as np import jax.random as random @@ -80,18 +80,22 @@ def _kernel_fns(key, network, out_logits, diagonal_axes, - trace_axes): + trace_axes, + vmap_axes=None): init_fn, f, _ = _build_network(input_shape, network, out_logits) _, params = init_fn(key, (-1,) + input_shape) - implicit_kernel_fn = empirical.empirical_implicit_ntk_fn(f, trace_axes, - diagonal_axes) - direct_kernel_fn = empirical.empirical_direct_ntk_fn(f, trace_axes, - diagonal_axes) - nngp_kernel_fn = empirical.empirical_nngp_fn(f, trace_axes, diagonal_axes) - - implicit_kernel_fn = jit(implicit_kernel_fn) - direct_kernel_fn = jit(direct_kernel_fn) - nngp_kernel_fn = jit(nngp_kernel_fn) + implicit_kernel_fn = jit(empirical._empirical_implicit_ntk_fn(f, + trace_axes, + diagonal_axes, + vmap_axes)) + direct_kernel_fn = jit(empirical._empirical_direct_ntk_fn(f, + trace_axes, + diagonal_axes, + vmap_axes)) + + nngp_kernel_fn = jit(empirical.empirical_nngp_fn(f, + trace_axes, + diagonal_axes)) return (partial(implicit_kernel_fn, params=params), partial(direct_kernel_fn, params=params), @@ -239,13 +243,27 @@ def testNTKAgainstDirect( implicit, direct, _ = kernel_fn(key, train_shape[1:], network, diagonal_axes=(), trace_axes=()) + implicit_batched, direct_batched, _ = kernel_fn(key, train_shape[1:], + network, + diagonal_axes=(), + trace_axes=(), + vmap_axes=0) + g = implicit(data_self, None) g_direct = direct(data_self, None) + g_batched = implicit_batched(data_self, None) + g_direct_batched = direct_batched(data_self, None) self.assertAllClose(g, g_direct) + self.assertAllClose(g, g_batched) + self.assertAllClose(g, g_direct_batched) g = implicit(data_other, data_self) g_direct = direct(data_other, data_self) + g_batched = implicit_batched(data_other, data_self) + g_direct_batched = direct_batched(data_other, data_self) self.assertAllClose(g, g_direct) + self.assertAllClose(g, g_batched) + self.assertAllClose(g, g_direct_batched) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -292,12 +310,17 @@ def testAxes(self, diagonal_axes, trace_axes): raise absltest.SkipTest( 'diagonal axes must be different from channel axes.') - implicit, direct, nngp = KERNELS['empirical_logits_3']( - key, - (5, 6, 3), - CONV, + get_kernel = KERNELS['empirical_logits_3'] + kwargs = dict( + key=key, + input_shape=(5, 6, 3), + network=CONV, diagonal_axes=diagonal_axes, - trace_axes=trace_axes) + trace_axes=trace_axes + ) + + implicit, direct, nngp = get_kernel(**kwargs) + implicit_batched, direct_batched, _ = get_kernel(**kwargs, vmap_axes=0) n_marg = len(_diagonal_axes) n_chan = len(_trace_axes) @@ -308,8 +331,13 @@ def testAxes(self, diagonal_axes, trace_axes): g_direct = direct(data_self, None) self.assertEqual(g_nngp.shape, g_direct.shape) + g_direct_batched = direct_batched(data_self, None) g = implicit(data_self, None) + g_batched = implicit_batched(data_self, None) + self.assertAllClose(g_direct, g) + self.assertAllClose(g_direct, g_direct_batched) + self.assertAllClose(g_direct, g_batched) if 0 not in _trace_axes and 0 not in _diagonal_axes: g_nngp = nngp(data_other, data_self) @@ -318,8 +346,13 @@ def testAxes(self, 
diagonal_axes, trace_axes): g_direct = direct(data_other, data_self) self.assertEqual(g_nngp.shape, g_direct.shape) + g_direct_batched = direct_batched(data_other, data_self) g = implicit(data_other, data_self) + g_batched = implicit_batched(data_other, data_self) + self.assertAllClose(g_direct, g) + self.assertAllClose(g_direct, g_direct_batched) + self.assertAllClose(g_direct, g_batched) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -334,7 +367,7 @@ def test_parallel_in_out(self, same_inputs): x2_1, x2_2 = np.split(random.normal(input_key2, (4, 21)), (10,), axis=1) x1 = (x1_1, x1_2) - x2 = (x2_1, x2_2) + x2 = (x2_1, x2_2) if not same_inputs else None def layer(N_out): return stax.parallel(stax.Dense(N_out), stax.Dense(N_out + 1)) @@ -342,17 +375,25 @@ def layer(N_out): init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1)) _, params = init_fn(net_key, (x1_1.shape, x1_2.shape)) - implicit_kernel_fn = empirical.empirical_implicit_ntk_fn(apply_fn) - direct_kernel_fn = empirical.empirical_direct_ntk_fn(apply_fn) - nngp_kernel_fn = empirical.empirical_nngp_fn(apply_fn) - self.assertAllClose(direct_kernel_fn(x1, x2, params), - implicit_kernel_fn(x1, x2, params)) + implicit_kernel_fn = jit(empirical._empirical_implicit_ntk_fn(apply_fn)) + direct_kernel_fn = jit(empirical._empirical_direct_ntk_fn(apply_fn)) + implicit_batched_kernel_fn = jit(empirical._empirical_implicit_ntk_fn( + apply_fn, vmap_axes=(0, 0))) + direct_batched_kernel_fn = jit(empirical._empirical_direct_ntk_fn( + apply_fn, vmap_axes=(0, 0))) + + k_direct = direct_kernel_fn(x1, x2, params) + self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, params)) + self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params)) + self.assertAllClose(k_direct, implicit_batched_kernel_fn(x1, x2, params)) + + nngp_kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn)) nngp = nngp_kernel_fn(x1, x2, params) self.assertEqual(len(nngp), 2) - self.assertEqual(nngp[0].shape, (3, 4)) - self.assertEqual(nngp[1].shape, (3, 4)) + self.assertEqual(nngp[0].shape, (3, 3 if same_inputs else 4)) + self.assertEqual(nngp[1].shape, (3, 3 if same_inputs else 4)) @jtu.parameterized.named_parameters( jtu.cases_from_list({ @@ -368,8 +409,8 @@ def test_parallel_nested(self, same_inputs): x2_1, x2_2, x2_3 = np.split(random.normal(input_key2, (4, 33)), (10, 21), axis=1) - x1 = ((x1_1, x1_2), x1_3) - x2 = ((x2_1, x2_2), x2_3) + x1 = ([x1_1, x1_2], x1_3) + x2 = ([x2_1, x2_2], x2_3) if not same_inputs else None def layer(N_out): return stax.parallel(stax.parallel(stax.Dense(N_out), @@ -378,19 +419,93 @@ def layer(N_out): init_fn, apply_fn, _ = stax.serial(layer(1024), layer(1)) - _, params = init_fn(net_key, ((x1_1.shape, x1_2.shape), x1_3.shape)) - implicit_kernel_fn = empirical.empirical_implicit_ntk_fn(apply_fn) - direct_kernel_fn = empirical.empirical_direct_ntk_fn(apply_fn) - nngp_kernel_fn = empirical.empirical_nngp_fn(apply_fn) + _, params = init_fn(net_key, tree_map(np.shape, x1)) + implicit_kernel_fn = jit(empirical._empirical_implicit_ntk_fn(apply_fn)) + direct_kernel_fn = jit(empirical._empirical_direct_ntk_fn(apply_fn)) + + implicit_batched_kernel_fn = jit(empirical._empirical_implicit_ntk_fn( + apply_fn, vmap_axes=([0, 0], 0))) + direct_batched_kernel_fn = jit(empirical._empirical_direct_ntk_fn( + apply_fn, vmap_axes=([0, 0], 0))) - self.assertAllClose(direct_kernel_fn(x1, x2, params), - implicit_kernel_fn(x1, x2, params)) + k_direct = direct_kernel_fn(x1, x2, params) + self.assertAllClose(k_direct, implicit_kernel_fn(x1, x2, 
params)) + self.assertAllClose(k_direct, direct_batched_kernel_fn(x1, x2, params)) + self.assertAllClose(k_direct, implicit_batched_kernel_fn(x1, x2, params)) + + nngp_kernel_fn = jit(empirical.empirical_nngp_fn(apply_fn)) nngp = nngp_kernel_fn(x1, x2, params) + self.assertEqual(len(nngp), 2) - self.assertEqual(nngp[0][0].shape, (3, 4)) - self.assertEqual(nngp[0][1].shape, (3, 4)) - self.assertEqual(nngp[1].shape, (3, 4)) + nngp_shape = (3, 3 if same_inputs else 4) + self.assertEqual(nngp[0][0].shape, nngp_shape) + self.assertEqual(nngp[0][1].shape, nngp_shape) + self.assertEqual(nngp[1].shape, nngp_shape) + + @jtu.parameterized.named_parameters( + jtu.cases_from_list({ + 'testcase_name': '_same_inputs={}'.format(same_inputs), + 'same_inputs': same_inputs + } for same_inputs in [True, False])) + def test_vmap_axes(self, same_inputs): + n1, n2 = 3, 4 + c1, c2, c3 = 9, 5, 7 + h2, h3, w3 = 6, 8, 2 + + def get_x(n, k): + k1, k2, k3 = random.split(k, 3) + x1 = random.normal(k1, (n, c1)) + x2 = random.normal(k2, (h2, n, c2)) + x3 = random.normal(k3, (c3, w3, n, h3)) + x = [(x1, x2), x3] + return x + + x1 = get_x(n1, random.PRNGKey(1)) + x2 = get_x(n2, random.PRNGKey(2)) if not same_inputs else None + + p1 = random.normal(random.PRNGKey(5), (n1, h2, h2)) + p2 = None if same_inputs else random.normal(random.PRNGKey(6), (n2, h2, h2)) + + init_fn, apply_fn, _ = stax.serial( + stax.parallel( + stax.parallel( + stax.serial(stax.Dense(4, 2., 0.1), + stax.Relu(), + stax.Dense(3, 1., 0.15)), # 1 + stax.serial(stax.Conv(7, (2,), padding='SAME', + dimension_numbers=('HNC', 'OIH', 'NHC')), + stax.Erf(), + stax.Aggregate(1, 0, -1), + stax.GlobalAvgPool(), + stax.Dense(3, 0.5, 0.2)), # 2 + ), + stax.serial( + stax.Conv(5, (2, 3), padding='SAME', + dimension_numbers=('CWNH', 'IOHW', 'HWCN')), + stax.Sin(), + ) # 3 + ), + stax.parallel( + stax.FanInSum(), + stax.Conv(2, (2, 1), dimension_numbers=('HWCN', 'OIHW', 'HNWC')) + ) + ) + + _, params = init_fn(random.PRNGKey(3), tree_map(np.shape, x1)) + implicit = jit(empirical._empirical_implicit_ntk_fn(apply_fn)) + direct = jit(empirical._empirical_direct_ntk_fn(apply_fn)) + + implicit_batched = jit(empirical._empirical_implicit_ntk_fn( + apply_fn, vmap_axes=([(0, 1), 2], [-2, -3], dict(pattern=0)))) + direct_batched = jit(empirical._empirical_direct_ntk_fn( + apply_fn, vmap_axes=([(-2, -2), -2], [0, 1], dict(pattern=-3)))) + + k = direct(x1, x2, params, pattern=(p1, p2)) + + self.assertAllClose(k, implicit(x1, x2, params, pattern=(p1, p2))) + self.assertAllClose(k, direct_batched(x1, x2, params, pattern=(p1, p2))) + self.assertAllClose(k, implicit_batched(x1, x2, params, pattern=(p1, p2))) if __name__ == '__main__': diff --git a/tests/monte_carlo_test.py b/tests/monte_carlo_test.py index 89f15831..8aedd132 100644 --- a/tests/monte_carlo_test.py +++ b/tests/monte_carlo_test.py @@ -177,7 +177,8 @@ def test_monte_carlo_vs_analytic_ntk(self, batch_size, device_count, sample = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, 100, batch_size, device_count, - store_on_device) + store_on_device, + vmap_axes=0) ker_empirical = sample(x1, x2, 'ntk') ker_analytic = stax_kernel_fn(x1, x2, 'ntk') @@ -210,7 +211,7 @@ def test_monte_carlo_generator(self, batch_size, device_count, n_samples = [2**k for k in range(log_n_max)] sample_generator = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, batch_size, device_count, - store_on_device) + store_on_device, vmap_axes=0) if get is None: samples_12 = sample_generator(x1, x2) @@ -221,7 +222,8 @@ def 
test_monte_carlo_generator(self, batch_size, device_count, sample_fn = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, key, n, batch_size, device_count, - store_on_device) + store_on_device, + vmap_axes=0) sample_12 = sample_fn(x1, x2) sample_34 = sample_fn(x3, x4) self.assertAllClose(s_12, sample_12) @@ -242,7 +244,7 @@ def test_monte_carlo_generator(self, batch_size, device_count, for n, s_12, s_34 in zip(n_samples, samples_12, samples_34): sample_fn = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n, batch_size, - device_count, store_on_device) + device_count, store_on_device, vmap_axes=0) sample_12 = sample_fn(x1, x2, get) sample_34 = sample_fn(x3, x4, get) self.assertAllClose(s_12, sample_12) @@ -270,7 +272,7 @@ def test_parallel_in_out_mc(self, same_inputs, batch_size): input_key1, input_key2, net_key = random.split(rng, 3) x1_1, x1_2, x1_3 = random.normal(input_key1, (3, 2, 5)) - x1 = (x1_1, (x1_2, x1_2)) + x1 = (x1_1, (x1_2, x1_3)) if same_inputs: x2 = None diff --git a/tests/stax_test.py b/tests/stax_test.py index 7af50dd4..a6f11f01 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -691,7 +691,10 @@ def test_sparse_inputs(self, act, kernel, do_stabilize): exact = kernel_fn(x_sparse, None, kernel) mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, random.split(key, 2)[0], - samples)(x_sparse, None, kernel) + samples, + vmap_axes=0, + implementation=2)( + x_sparse, None, kernel) mc = np.reshape(mc, exact.shape) assert not np.any(np.isnan(exact)) @@ -785,7 +788,8 @@ def _check_agreement_with_empirical( def _get_empirical(n_samples, get): kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, device_count=device_count, - trace_axes=(channel_axis,), batch_size=batch_size + trace_axes=(channel_axis,), batch_size=batch_size, + implementation=2 ) if same_inputs: assert x2 is None @@ -855,6 +859,7 @@ def _test_activation(self, activation_fn, same_inputs, model, get, affine = stax.Dense(1024, W_std, b_std) readout = stax.Dense(output_dim) depth = 1 + else: rtol = 0.1 X0_1 = random.normal(key, (4, 8, 8, 3)) @@ -864,6 +869,7 @@ def _test_activation(self, activation_fn, same_inputs, model, get, stax.Flatten(), stax.Dense(output_dim)) depth = 2 + if platform == 'cpu': num_samplings = 200 rtol *= 2 @@ -875,7 +881,9 @@ def _test_activation(self, activation_fn, same_inputs, model, get, *[affine, activation_fn]*depth, readout) analytic_kernel = kernel_fn(X0_1, X0_2, get) mc_kernel_fn = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, split, num_samplings) + init_fn, apply_fn, split, num_samplings, implementation=2, + vmap_axes=0 + ) empirical_kernel = mc_kernel_fn(X0_1, X0_2, get) test_utils.assert_close_matrices(self, analytic_kernel, empirical_kernel, rtol) @@ -1177,13 +1185,17 @@ def test_flatten(self, same_inputs): stax.Dense(1024, 2., 0.5)) kernel_fc_mc = monte_carlo.monte_carlo_kernel_fn(init_fc, apply_fc, key, - 200) + 200, implementation=2, + vmap_axes=0) kernel_bot_mc = monte_carlo.monte_carlo_kernel_fn(init_bot, apply_bot, key, - 200) + 200, implementation=2, + vmap_axes=0) kernel_mid_mc = monte_carlo.monte_carlo_kernel_fn(init_mid, apply_mid, key, - 200) + 200, implementation=2, + vmap_axes=0) kernel_top_mc = monte_carlo.monte_carlo_kernel_fn(init_top, apply_top, key, - 200) + 200, implementation=2, + vmap_axes=0) K = kernel_fc(X0_1_flat, X0_2_flat) @@ -1345,7 +1357,10 @@ def test_fan_in_fc(self, same_inputs, axis, n_branches, get, branch_in, kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, 
key, n_samples, - device_count=0 if axis in (0, -2) else -1) + device_count=0 if axis in (0, -2) else -1, + implementation=2, + vmap_axes=None if axis in (0, -2) else 0, + ) exact = kernel_fn(X0_1, X0_2, get=get) empirical = kernel_fn_mc(X0_1, X0_2, get=get) @@ -1476,7 +1491,10 @@ def test_fan_in_conv(self, apply_fn, key, n_samples, - device_count=0 if axis in (0, -4) else -1) + device_count=0 if axis in (0, -4) else -1, + implementation=2, + vmap_axes=None if axis in (0, -4) else 0, + ) exact = kernel_fn(X0_1, X0_2, get=get) empirical = kernel_fn_mc(X0_1, X0_2, get=get) @@ -1608,7 +1626,7 @@ def test_conv_nd(self, same_inputs, n, get, proj, use_attn, channels_first, raise ValueError(get) kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, key, n_samples) + init_fn, apply_fn, key, n_samples, implementation=2, vmap_axes=0) exact = kernel_fn(X0_1, X0_2, get=get) empirical = kernel_fn_mc(X0_1, X0_2, get=get) @@ -1757,7 +1775,9 @@ def test_input_req(self, same_inputs): ) correct_conv_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, - key, 400) + key, 400, + implementation=2, + vmap_axes=0) K = correct_conv_fn(x1, x2, get='nngp') K_mc = correct_conv_fn_mc(x1, x2, get='nngp') self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05) @@ -1783,7 +1803,9 @@ def test_input_req(self, same_inputs): ) correct_conv_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, - key, 300) + key, 300, + implementation=2, + vmap_axes=0) K = correct_conv_fn(x1, x2, get='nngp') K_mc = correct_conv_fn_mc(x1, x2, get='nngp') self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05) @@ -1806,7 +1828,9 @@ def test_input_req(self, same_inputs): ) correct_conv_fn_mc = monte_carlo.monte_carlo_kernel_fn(init_fn, apply_fn, - key, 200) + key, 200, + implementation=2, + vmap_axes=0) K = correct_conv_fn(x1, x2, get='ntk') K_mc = correct_conv_fn_mc(x1, x2, get='ntk') self.assertAllClose(K, K_mc, atol=0.01, rtol=0.05) @@ -1897,7 +1921,10 @@ def test_mask_fc(self, same_inputs, get, concat, p, mask_axis, mask_constant): kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, - device_count=0 if concat in (0, -2) else -1) + device_count=0 if concat in (0, -2) else -1, + implementation=2, + vmap_axes=None if concat in (0, -2) else 0, + ) kernel_fn = jit(kernel_fn, static_argnums=(2,)) exact = kernel_fn(x1, x2, get, mask_constant=mask_constant) @@ -2067,7 +2094,9 @@ def get_attn(): kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, - device_count=0 if concat in (0, -n) else -1 + device_count=0 if concat in (0, -n) else -1, + implementation=2, + vmap_axes=None if concat in (0, -n) else 0, ) kernel_fn = jit(kernel_fn, static_argnums=(2,)) @@ -2110,7 +2139,10 @@ def net(logits): init_fn, apply_fn, kernel_fn = net(N if kernel_type == 'nngp' else 1) kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,)) + init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,), + implementation=2, + vmap_axes=((0, 0), 0, {}) + ) test_utils.assert_close_matrices(self, kernel_fn(x1, x2, kernel_type), kernel_fn_empirical(x1, x2, kernel_type), @@ -2128,7 +2160,7 @@ def test_parallel_out(self, same_inputs, kernel_type): rtol = RTOL if platform != 'tpu' else 0.05 rng = random.PRNGKey(0) - input_key1, input_key2, mc_key = random.split(rng, 3) + input_key1, mc_key = random.split(rng, 2) x1, x2 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 1)) @@ -2143,7 +2175,9 @@ def net(logits): init_fn, apply_fn, kernel_fn = 
net(N if kernel_type == 'nngp' else 1) kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,)) + init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,), + implementation=2, + vmap_axes=(0, [0, 0], {})) test_utils.assert_close_matrices(self, kernel_fn(x1, x2, kernel_type), @@ -2185,7 +2219,10 @@ def test_parallel_in_out(self, same_inputs, kernel_type): K_readout_fn = jit(functools.partial(readout[2], get=kernel_type)) kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,)) + init_fn, apply_fn, mc_key, N_SAMPLES, trace_axes=(-1,), + implementation=2, + vmap_axes=((0, 0), [0, 0, 0], {}) + ) test_utils.assert_close_matrices( self, @@ -2220,8 +2257,8 @@ def test_nested_parallel(self, same_inputs, kernel_type): x1_1, x2_1 = _get_inputs(input_key1, same_inputs, (BATCH_SIZE, 5)) x1_2, x2_2 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 2)) - x1_3, x2_3 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 2, 2, 3)) - x1_4, x2_4 = _get_inputs(input_key2, same_inputs, (BATCH_SIZE, 3, 4)) + x1_3, x2_3 = _get_inputs(input_key3, same_inputs, (BATCH_SIZE, 2, 2, 3)) + x1_4, x2_4 = _get_inputs(input_key4, same_inputs, (BATCH_SIZE, 3, 4)) m1_key, m2_key, m3_key, m4_key = random.split(mask_key, 4) @@ -2252,7 +2289,10 @@ def test_nested_parallel(self, same_inputs, kernel_type): stax.Conv(N_in + 3, (2,))) kernel_fn_empirical = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, mc_key, N_SAMPLES) + init_fn, apply_fn, mc_key, N_SAMPLES, implementation=2, + vmap_axes=(((((0, 0), 0), 0), (((0, 0), 0), 0), {}) + if platform == 'tpu' else None) + ) test_utils.assert_close_matrices( self, @@ -2403,7 +2443,9 @@ def get_attn(): kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, - device_count=-1 + device_count=-1, + implementation=2, + vmap_axes=0 ) kernel_fn = jit(kernel_fn, static_argnums=(2,)) @@ -2475,7 +2517,9 @@ def test_aggregate(self, get, readout, same_input, activation, test_mask, kernel_mc_fn = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, random.PRNGKey(10), 128, - batch_size=2 if xla_bridge.get_backend().platform == 'tpu' else 0) + batch_size=2 if xla_bridge.get_backend().platform == 'tpu' else 0, + implementation=2, + ) empirical = kernel_mc_fn(x1, x2, get, mask_constant=mask_constant if test_mask else None, pattern=(pattern1, pattern2)) @@ -2542,7 +2586,7 @@ def test_conv_transpose(self, same_inputs, padding, filter_shape, strides, kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key, n_samples, diagonal_axes=diagonal_axes, - device_count=0) + device_count=0, implementation=2, vmap_axes=0) k_mc = kernel_fn_mc(x1, None if diagonal_batch else x2, 'nngp') test_utils.assert_close_matrices(self, k_mc, k, tol) @@ -2804,7 +2848,9 @@ def batch_axes(): diagonal_axes=get_diagonal_axes(), device_count=-1 if (get == 'nngp' and batch_axis == out_b_axis == 0 and - 0 not in c_dims + b_dims) else 0) + 0 not in c_dims + b_dims) else 0, + implementation=2, + ) empirical = kernel_fn_mc(x1=x2 if get == 'cov2' else x1, x2=x2 if get == 'nngp' else None, @@ -2910,7 +2956,8 @@ def test_dot_general_nn(self, same_inputs, get, n, is_rhs, do_pool, kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key1, n_samples, trace_axes=(int(out_c_axis > out_b_axis) if do_pool else 1,), - device_count=0 + device_count=0, + implementation=2 ) empirical = kernel_fn_mc(x1, x2, get, mask_constant=mask_constant) @@ -3057,7 +3104,10 @@ def 
test_conv_local(self, same_inputs, padding, filter_shape, strides, kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( init_fn, apply_fn, key_mc, n_samples=512, diagonal_axes=diagonal_axes, - device_count=0) + device_count=0, + implementation=2, + vmap_axes=0 + ) k_mc = kernel_fn_mc(x1, None if get == 'cov1' else x2, 'nngp' if get == 'cov1' else get) test_utils.assert_close_matrices(self, k_mc, getattr(k, get), 0.011) @@ -3143,7 +3193,10 @@ def get_nn(conv): # Test against MC. kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, key_mc, n_samples=512, device_count=0) + init_fn, apply_fn, key_mc, n_samples=512, device_count=0, + implementation=2, + vmap_axes=0 + ) k_mc = kernel_fn_mc(x1, x2, get) test_utils.assert_close_matrices(self, k_mc, k_local, 0.015) @@ -3209,7 +3262,7 @@ def test_conv_local_conv(self): def test_double_pool(self): _skip_test() - key1, key2, key_mc = random.split(random.PRNGKey(1), 3) + key1, key2 = random.split(random.PRNGKey(1), 2) x1 = np.cos(random.normal(key1, (2, 4, 6, 3))) x2 = np.sin(random.normal(key2, (3, 4, 6, 3))) @@ -3230,7 +3283,10 @@ def test_double_pool(self): def _test_against_mc(self, apply_fn, init_fn, k, x1, x2, tol=0.01, n=256): kernel_fn_mc = monte_carlo.monte_carlo_kernel_fn( - init_fn, apply_fn, random.PRNGKey(2), n_samples=n, device_count=0) + init_fn, apply_fn, random.PRNGKey(2), n_samples=n, device_count=0, + implementation=2, + vmap_axes=0 + ) k_mc = kernel_fn_mc(x1, x2, 'nngp') test_utils.assert_close_matrices(self, k_mc, k, tol)
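# --- Not part of the patch: a minimal usage sketch -------------------------
# The tests above exercise the new `vmap_axes` / `implementation` arguments
# through the private `_empirical_implicit_ntk_fn` / `_empirical_direct_ntk_fn`
# helpers and through `monte_carlo_kernel_fn`. The sketch below shows the
# corresponding call pattern through the public `nt.empirical_ntk_fn`; the
# exact public signature is assumed from the test usage above and may differ
# in other versions of the library.
from jax import jit, random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

x1 = random.normal(random.PRNGKey(1), (8, 16))
x2 = random.normal(random.PRNGKey(2), (4, 16))
_, params = init_fn(random.PRNGKey(3), x1.shape)

# `vmap_axes=0` declares that `apply_fn` acts independently over the leading
# (batch) axis of its inputs, which is what allows the kernel computation to
# be vmapped over that axis; `implementation` selects between the two
# Jacobian strategies exercised above (1 = direct Jacobian contraction,
# 2 = implicit).
ntk_fn = jit(nt.empirical_ntk_fn(apply_fn, vmap_axes=0, implementation=2))
k_12 = ntk_fn(x1, x2, params)    # shape (8, 4)
k_11 = ntk_fn(x1, None, params)  # shape (8, 8); x2=None reuses x1, as in the tests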