Skip to content

Commit

Permalink
Fix a bug in Orbax checkpointing where None values PyTree are not han…
Browse files Browse the repository at this point in the history
…dled correctly. This is caused by a recent update in jax where None values are no longer considered as a leaf node: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-34-october-4-2024.

PiperOrigin-RevId: 691828986
  • Loading branch information
T5X Team authored and t5-copybara committed Oct 31, 2024
1 parent c723ab0 commit 6401ee6
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2165,7 +2165,12 @@ def _modify_orbax_param_info(info, value):
)
return info

item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args)
item_ = jax.tree.map(
lambda x, y: None if x is None else _make_orbax_internal_metadata(x, y),
item_,
restore_args,
is_leaf=lambda x: x is None,
)
param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_)
param_infos_ = jax.tree.map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
Expand Down

0 comments on commit 6401ee6

Please sign in to comment.