From 6401ee610935288e26db55690d37ca272fd9f1a9 Mon Sep 17 00:00:00 2001 From: T5X Team Date: Thu, 31 Oct 2024 09:20:11 -0700 Subject: [PATCH] Fix a bug in Orbax checkpointing where None values PyTree are not handled correctly. This is caused by a recent update in jax where None values are no longer considered as a leaf node: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-34-october-4-2024. PiperOrigin-RevId: 691828986 --- t5x/checkpoints.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 315604c87..39118f050 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2165,7 +2165,12 @@ def _modify_orbax_param_info(info, value): ) return info - item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args) + item_ = jax.tree.map( + lambda x, y: None if x is None else _make_orbax_internal_metadata(x, y), + item_, + restore_args, + is_leaf=lambda x: x is None, + ) param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_) param_infos_ = jax.tree.map( _modify_orbax_param_info, param_infos_, state_dict_to_restore