What's Changed
- Adding logging for the number of parameters and optimizer state. by @copybara-service in #125
- Adding automatic cross-device averaging of auxiliary loss/model stats to the optimizer. by @copybara-service in #139
- Add rel_grad_norm and rel_update_norm stats logging by @copybara-service in #147
- Fixing bug that would sometimes cause an exception for networks with scalar-valued parameters. by @copybara-service in #151
- [JAX] Migrate XlaBuilder users to emit direct stablehlo MLIR lowerings. by @copybara-service in #161
- Still fixing docs requirements dependencies. by @copybara-service in #166
- Still fixing docs requirements dependencies. by @copybara-service in #169
- Still fixing docs requirements dependencies. by @copybara-service in #171
- Still fixing docs requirements dependencies. by @copybara-service in #173
- Adding the capability to pass custom arguments to the registration functions, and to call them in a custom module, for standard losses in the example code. by @copybara-service in #175
- Fix or ignore some pytype errors. by @copybara-service in #177
- [LSC] Ignore incorrect type annotations related to jax.numpy APIs by @copybara-service in #176
- Adding a sum_of_objects. by @copybara-service in #190
- Adding a
- Adding Polyak averaging feature to example experiments codebase. by @copybara-service in #195
- Adding precon_damping_mult feature to the optimizer. by @copybara-service in #196
- Reland jax-ml/jax#10573. by @copybara-service in #199
- Minor refactoring by @copybara-service in #201
- Fixing issue where loss_registered_reldiff was not computed properly in multi-device settings. by @copybara-service in #202
- Adding a new schedule and applying some fixes to existing ones in the examples codebase. by @copybara-service in #204
- Remove gradient normalization from the preconditioning function by @copybara-service in #206
Full Changelog: v0.0.5...v0.0.6