Releases · google/neural-tangents
v0.6.5
Maintenance release:
- refactoring and minor improvements
- Support and require
v0.6.4
Improvements:
- Support Python 3.11
- Use modern generic type annotations
- Various bugfixes, compatibility and documentation improvements
Breaking changes:
v0.6.2
New features:
- `nt.stax.repeat`: a layer allowing fast compilation of very deep networks (see #168, and thanks @jglaser!). See the usage sketch after this list.
- Add a Colab notebook accompanying Precise Learning Curves and Higher-Order Scaling Limits for Dot Product Kernel Regression.
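A minimal sketch of the new layer, assuming a `repeat(layer, n)` signature (consult the docs for the exact arguments):

```python
from neural_tangents import stax

# A very deep network: repeating the block via `stax.repeat` (rather than
# unrolling 100 copies inside `stax.serial`) keeps compilation fast.
block = stax.serial(stax.Dense(512), stax.Relu())

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.repeat(block, 100),  # assumed signature: repeat(layer, n)
    stax.Dense(1),
)
```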
Improvements:
Breaking changes:
v0.6.1
New features:
- `nt.stax`:
- `nt.empirical`:
  - An efficient NTK-vector product function `nt.empirical_ntk_vp_fn`, computing NTK-vector products without instantiating the NTK. See the sketch after this list.
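A hedged sketch of how the NTK-vector product might be used; the exact signature of `nt.empirical_ntk_vp_fn` is assumed from the release description above, so check the docs before relying on it:

```python
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(256), stax.Relu(), stax.Dense(1))

key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (8, 32))
x2 = jax.random.normal(key, (4, 32))
_, params = init_fn(key, x1.shape)

# Assumed usage: build v |-> NTK(x1, x2) @ v without ever materializing
# the NTK matrix itself.
ntk_vp_fn = nt.empirical_ntk_vp_fn(apply_fn, x1, x2, params)
v = jnp.ones((4, 1))        # cotangent shaped like f(x2)
ntk_times_v = ntk_vp_fn(v)  # result shaped like f(x1)
```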
Improvements:
v0.6.0
New features:
- `nt.empirical`:
  - New `implementation=3` for `nt.empirical`, which often speeds up or reduces the memory cost of the empirical NTK by orders of magnitude. Please see our ICML 2022 paper Fast Finite Width Neural Tangent Kernel, the new empirical NTK examples, and visit us on Thursday at ICML in person! A usage sketch follows this list.
  - New experimental prototype of using our empirical NTK implementations in TensorFlow via `nt.experimental.empirical_ntk_fn_tf`.
  - Make `nt.empirical` work with arbitrary pytrees.
- `nt.stax`:
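A minimal sketch of selecting the new implementation, assuming `implementation` is accepted by `nt.empirical_ntk_fn` as the integer flag named in the release note:

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(10))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 64))
_, params = init_fn(key, x.shape)

# `implementation=3` selects the structured-derivatives algorithm from the
# "Fast Finite Width Neural Tangent Kernel" paper (flag value taken from the
# release note); it is often much faster / lighter on memory than the default.
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(), implementation=3)
ntk = ntk_fn(x, None, params)  # full NTK over all input pairs and output dims
```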
Improvements:
- Slightly lower memory usage in batching.
- Many improvements to documentation and type annotations.
- Simplify test specifications and avoid relying on JAX testing utilities.
Bugfixes:
Breaking changes:
v0.5.0
Potentially breaking changes:
- Significant internal refactoring, notably splitting `stax` into multiple sub-modules and moving implementations into a `_src` folder. This could break your code if you use internal functions like `nt.utils.typing`, `nt.utils.utils`, `nt.utils.Kernel`, etc. (the public API remains unchanged). This should be easily fixed by updating the imports, e.g. `nt.utils -> nt._src.utils`; see the sketch after this list.
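A hedged before/after sketch of the import update; the `_src` path is taken from the note above, and the exact module layout may differ:

```python
# Before v0.5.0, code reaching into internals might have done:
# from neural_tangents.utils import typing

# From v0.5.0 on, internals live under `_src` (path assumed from the note above):
from neural_tangents._src.utils import typing

# The public API is unchanged:
import neural_tangents as nt
from neural_tangents import stax
```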
New features:
v0.4.0
WARNING:
Our next major release (v0.5.0) will include significant refactoring and could break your code if you use internal functions like `nt.utils.typing`, `nt.utils.utils`, `nt.utils.Kernel`, etc. (the public API will remain unchanged). This should be easily fixed by updating the imports, e.g. `nt.utils -> nt._src.utils`.
This release (v0.4.0):
New feature:
Improvements:
- Various internal refactoring and tighter tests.
Bugfixes:
- Fix values and gradients of non-differentiable `kernel_fn` at zero inputs to be consistent with finite-width kernels and with how JAX defines gradients of non-differentiable functions as the mean sub-gradient; see also #123.
- Fix wrong treatment of `b_std=None` in the infinite-width limit with `parameterization='standard'`; see also #123.
- Fix a bug in `nt.batch` when `x2 = None` and inputs are PyTrees.
Breaking changes:
- Bump requirements to `jax==0.3` and `frozendict==2.3`.
v0.3.9
v0.3.8
New Features:
- `stax.Elementwise` - a layer for generic elementwise functions, requiring the user to specify only a scalar-valued `nngp_fn : (cov12, var1, var2) |-> E[fn(x_1) * fn(x_2)]`. The NTK computation (thanks to @SiuMath) and vectorization over the underlying `Kernel` happen automatically under the hood. If you can't derive the `nngp_fn` for your function, use `stax.ElementwiseNumerical`. See the docs for more details, and the sketch below.
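A hedged illustration: for `fn = sin` and zero-mean jointly Gaussian preactivations, `E[sin(x_1) * sin(x_2)]` has a closed form, so a custom `Elementwise` layer might look like the sketch below (neural-tangents already ships `stax.Sin`; this is only for illustration, and the argument names are assumed from the release note):

```python
import jax.numpy as jnp
from neural_tangents import stax

# For zero-mean jointly Gaussian (x1, x2) with covariance cov12 and variances
# var1, var2: E[sin(x1) * sin(x2)] = exp(-(var1 + var2) / 2) * sinh(cov12).
def sin_nngp_fn(cov12, var1, var2):
  return jnp.exp(-(var1 + var2) / 2) * jnp.sinh(cov12)

# The NTK map and vectorization over the Kernel are derived automatically.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512),
    stax.Elementwise(fn=jnp.sin, nngp_fn=sin_nngp_fn),  # argument names assumed
    stax.Dense(1),
)
```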
Bugfixes:
- Compatibility with JAX 0.2.21.
Full Changelog: v0.3.7...v0.3.8
v0.3.7
New Features:
- `nt.stax.Cos`
- `nt.stax.ImageResize`
- New implementation `implementation="SPARSE"` in `nt.stax.Aggregate` for efficient handling of sparse graphs (see #86, #9).
- Support `approximate=True` in `nt.stax.Gelu` (see the sketch after this list).
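A small hedged sketch combining two of the new features; layer names are taken from this release, all other arguments are defaults:

```python
from neural_tangents import stax

# Infinite-width network mixing the new Cos nonlinearity with the
# tanh-approximated Gelu (`approximate=True`, added in this release).
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Cos(),
    stax.Dense(512), stax.Gelu(approximate=True),
    stax.Dense(1),
)
```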
Bugfixes:
- Fix a bug that might alter `Kernel` requirements.
- Fix `nt.batch` handling of `diagonal_axes` (see #87).
- Remove the frequent but redundant warning about type conversion in `kernel_fn`.
- Minor fixes to documentation and code clean-up.
Breaking changes: