Skip to content

Commit

Permalink
fix serialization of partitioned states in optax (#852)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Jan 9, 2025
1 parent 93a8aa9 commit ab7dc65
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/levanter/tensorstore_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def path_from_key_path(key_path):

def _sharding_from_leaf(leaf, axis_mapping, mesh) -> Optional[jax.sharding.Sharding]:
if is_named_array(leaf):
if leaf.array is None:
if not is_jax_array_like(leaf.array):
return None
return hax.partitioning.sharding_for_axis(leaf.axes, axis_mapping, mesh)
elif hasattr(leaf, "sharding") and getattr(leaf, "sharding") is not None:
Expand Down Expand Up @@ -140,11 +140,11 @@ def tree_deserialize_leaves_tensorstore(
manager = array_ser.GlobalAsyncCheckpointManager()

shardings: PyTree[Optional[Sharding]] = jtu.tree_map(
partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=is_named_array
partial(_sharding_from_leaf, axis_mapping=axis_mapping, mesh=mesh), pytree, is_leaf=_is_named_or_none
)

# TODO: support ShapeDtypeStructs that are not NamedArrays
leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=is_named_array)
leaf_key_paths = jax_utils.leaf_key_paths(shardings, is_leaf=_is_named_or_none)
paths = _fs_paths_from_key_paths(checkpoint_dir, leaf_key_paths)
paths = jtu.tree_leaves(paths, is_leaf=lambda x: x is None)

Expand All @@ -157,6 +157,8 @@ def tree_deserialize_leaves_tensorstore(
real_leaves = [x for x in shardings_leaves if x is not None]
real_paths = [paths[i] for i in real_indices]

assert len(real_leaves) == len(real_paths), f"{len(real_leaves)} != {len(real_paths)}"

deser_leaves = manager.deserialize_with_paths(shardings=real_leaves, paths=real_paths)
# now we need to recreate the original structure

Expand Down

0 comments on commit ab7dc65

Please sign in to comment.