From 9d8bf89df386cb892369b738a176cc1d928d1614 Mon Sep 17 00:00:00 2001
From: epwalsh
Date: Fri, 10 May 2024 09:01:16 -0700
Subject: [PATCH] improve docs

---
 docs/source/distributed/tensors.rst           |  4 ++
 src/olmo_core/distributed/checkpoint.py       |  3 +-
 .../distributed/tensors/dtensor_utils.py      | 50 ++++++++++++-------
 3 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/docs/source/distributed/tensors.rst b/docs/source/distributed/tensors.rst
index ea2af14b..25fc5ac5 100644
--- a/docs/source/distributed/tensors.rst
+++ b/docs/source/distributed/tensors.rst
@@ -4,3 +4,7 @@
 .. automodule:: olmo_core.distributed.tensors
    :members:
    :member-order: bysource
+
+.. automodule:: olmo_core.distributed.tensors.dtensor_utils
+   :members:
+   :member-order: bysource
diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py
index d3432293..b98b622d 100644
--- a/src/olmo_core/distributed/checkpoint.py
+++ b/src/olmo_core/distributed/checkpoint.py
@@ -8,7 +8,8 @@
 --------
 
 - Sharded distributed models, such OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's
-  :class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box.
+  :class:`~torch.distributed.fsdp.FullyShardedDataParallel` (with ``use_orig_params=True``)
+  are supported out-of-the-box.
 - Utilizes `safetensors <https://huggingface.co/docs/safetensors/>`_ under the hood for fast, efficient, and
   safe serialization/deserialization.
 - Save with one distributed topology, seamlessly load with a different one. For example,
diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py
index f2404dda..d26dbcee 100644
--- a/src/olmo_core/distributed/tensors/dtensor_utils.py
+++ b/src/olmo_core/distributed/tensors/dtensor_utils.py
@@ -16,6 +16,16 @@
 def get_local_shape_and_global_offset(
     dtensor: DTensor, rank: Optional[int] = None
 ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """
+    Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor`
+    instance.
+
+    :param dtensor: A DTensor instance.
+    :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
+        defaults to the current rank.
+
+    :returns: The local shape and global offset.
+    """
     global_shape = dtensor.shape
     mesh = dtensor.device_mesh
     placements = dtensor.placements
@@ -41,9 +51,11 @@ def compute_local_shape_and_global_offset(
     :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
         defaults to the current rank.
 
-    Example (2 host with 4GPUs each):
+    :returns: The local shape and global offset.
+
+    Example (2 hosts with 4 GPUs each)::
-    # Below is a DeviceMesh with mesh_shape of ``(2, 4)``
+        # Below is a DeviceMesh with mesh_shape of (2, 4)
         mesh = DeviceMesh(device_type="cuda", mesh=[
             [0, 1, 2, 3],
             [4, 5, 6, 7]
         ])
@@ -53,27 +65,29 @@
     with a placements of ``[Shard(0), Shard(0)]``.
     The local shape and global offset will be as follows:
-    rank0 -- local_shape:[1, 4], global_offset:[0, 0]
-    rank1 -- local_shape:[1, 4], global_offset:[1, 0]
-    rank2 -- local_shape:[1, 4], global_offset:[2, 0]
-    rank5 -- local_shape:[1, 4], global_offset:[5, 0]
-    rank3 -- local_shape:[1, 4], global_offset:[3, 0]
-    rank4 -- local_shape:[1, 4], global_offset:[4, 0]
-    rank6 -- local_shape:[1, 4], global_offset:[6, 0]
-    rank7 -- local_shape:[1, 4], global_offset:[7, 0]
+
+    - ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]``
+    - ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]``
+    - ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]``
+    - ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]``
+    - ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]``
+    - ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]``
+    - ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]``
+    - ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]``
 
     Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with a
     placements of ``[Shard(0)]``. We will not have non-empty local tensor for all the ranks.
     The local shape and global offset will be as follows:
-    rank0 -- local_shape:[1,], global_offset:[0,]
-    rank1 -- local_shape:[1,], global_offset:[1,]
-    rank2 -- local_shape:[0,], global_offset:[2,]
-    rank5 -- local_shape:[0,], global_offset:[2,]
-    rank3 -- local_shape:[0,], global_offset:[2,]
-    rank4 -- local_shape:[0,], global_offset:[2,]
-    rank6 -- local_shape:[0,], global_offset:[2,]
-    rank7 -- local_shape:[0,], global_offset:[2,]
+
+    - ``rank0 -- local_shape:[1,], global_offset:[0,]``
+    - ``rank1 -- local_shape:[1,], global_offset:[1,]``
+    - ``rank2 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank5 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank3 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank4 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank6 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank7 -- local_shape:[0,], global_offset:[2,]``
     """
     my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank)
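
As a quick illustration of the helper documented above (not part of the patch itself), the sketch
below builds a small DTensor and queries its local shape and global offset per rank. Only the
module path and the ``get_local_shape_and_global_offset`` signature come from the diff; the mesh
layout, tensor shape, and the CPU/gloo process group are illustrative assumptions::

    # Minimal sketch; launch with e.g. `torchrun --nproc-per-node=4 example.py`.
    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import Shard, distribute_tensor
    from torch.distributed.device_mesh import init_device_mesh

    from olmo_core.distributed.tensors.dtensor_utils import get_local_shape_and_global_offset

    # Assumption: a CPU/gloo process group is enough for demonstration purposes.
    dist.init_process_group(backend="gloo")

    # Build a 1-D device mesh over all ranks and shard an (8, 4) tensor on dim 0.
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    global_tensor = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
    dtensor = distribute_tensor(global_tensor, mesh, placements=[Shard(0)])

    # Local shape and global offset for the current rank...
    local_shape, global_offset = get_local_shape_and_global_offset(dtensor)
    # ...or for any other global rank (here, rank 0).
    shape0, offset0 = get_local_shape_and_global_offset(dtensor, rank=0)

    print(f"rank{dist.get_rank()} -- local_shape:{list(local_shape)}, global_offset:{list(global_offset)}")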