Commit

improve docs
epwalsh committed May 10, 2024
1 parent de2ec70 commit 9d8bf89
Showing 3 changed files with 38 additions and 19 deletions.

docs/source/distributed/tensors.rst (4 additions, 0 deletions)

@@ -4,3 +4,7 @@
 .. automodule:: olmo_core.distributed.tensors
    :members:
    :member-order: bysource
+
+.. automodule:: olmo_core.distributed.tensors.dtensor_utils
+   :members:
+   :member-order: bysource

src/olmo_core/distributed/checkpoint.py (2 additions, 1 deletion)

@@ -8,7 +8,8 @@
 --------
 - Sharded distributed models, such as OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's
-  :class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box.
+  :class:`~torch.distributed.fsdp.FullyShardedDataParallel` (with ``use_orig_params=True``)
+  are supported out-of-the-box.
 - Utilizes `safetensors <https://huggingface.co/docs/safetensors/>`_ under the hood for fast, efficient, and
   safe serialization/deserialization.
 - Save with one distributed topology, seamlessly load with a different one. For example,
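
To make the checkpointing claims above concrete, here is a minimal sketch of the save/load round
trip the module docstring describes. Only the module path appears in this diff: the entry-point
names ``save_model_and_optim_state`` and ``load_model_and_optim_state`` are assumptions for
illustration, so check the ``olmo_core.distributed.checkpoint`` API for the actual functions.

    # Hypothetical sketch, run under torchrun. The two checkpoint function
    # names below are ASSUMED for illustration, not taken from this diff.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from olmo_core.distributed import checkpoint

    def round_trip(checkpoint_dir: str):
        dist.init_process_group("nccl")
        # Per the docstring change, PyTorch FSDP is supported out-of-the-box
        # only with ``use_orig_params=True``.
        model = FSDP(torch.nn.Linear(1024, 1024).cuda(), use_orig_params=True)
        optim = torch.optim.AdamW(model.parameters())

        # Save the sharded state (safetensors under the hood)...
        checkpoint.save_model_and_optim_state(checkpoint_dir, model, optim)  # assumed name
        # ...then load it back, possibly with a different topology/world size.
        checkpoint.load_model_and_optim_state(checkpoint_dir, model, optim)  # assumed name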

src/olmo_core/distributed/tensors/dtensor_utils.py (32 additions, 18 deletions)

@@ -16,6 +16,16 @@
 def get_local_shape_and_global_offset(
     dtensor: DTensor, rank: Optional[int] = None
 ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
+    """
+    Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor`
+    instance.
+
+    :param dtensor: A DTensor instance.
+    :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
+        defaults to the current rank.
+
+    :returns: The local shape and global offset.
+    """
     global_shape = dtensor.shape
     mesh = dtensor.device_mesh
     placements = dtensor.placements

@@ -41,9 +51,11 @@ def compute_local_shape_and_global_offset(
     :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
         defaults to the current rank.
 
-    Example (2 host with 4GPUs each):
+    :returns: The local shape and global offset.
+
+    Example (2 hosts with 4 GPUs each)::
 
-        # Below is a DeviceMesh with mesh_shape of ``(2, 4)``
+        # Below is a DeviceMesh with mesh_shape of (2, 4)
         mesh = DeviceMesh(device_type="cuda", mesh=[
             [0, 1, 2, 3],
             [4, 5, 6, 7]

@@ -53,27 +65,29 @@
     with placements of ``[Shard(0), Shard(0)]``.
 
     The local shape and global offset will be as follows:
-    rank0 -- local_shape:[1, 4], global_offset:[0, 0]
-    rank1 -- local_shape:[1, 4], global_offset:[1, 0]
-    rank2 -- local_shape:[1, 4], global_offset:[2, 0]
-    rank5 -- local_shape:[1, 4], global_offset:[5, 0]
-    rank3 -- local_shape:[1, 4], global_offset:[3, 0]
-    rank4 -- local_shape:[1, 4], global_offset:[4, 0]
-    rank6 -- local_shape:[1, 4], global_offset:[6, 0]
-    rank7 -- local_shape:[1, 4], global_offset:[7, 0]
+    - ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]``
+    - ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]``
+    - ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]``
+    - ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]``
+    - ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]``
+    - ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]``
+    - ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]``
+    - ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]``
 
     Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with
     placements of ``[Shard(0)]``. Not all ranks will have a non-empty local tensor.
 
     The local shape and global offset will be as follows:
-    rank0 -- local_shape:[1,], global_offset:[0,]
-    rank1 -- local_shape:[1,], global_offset:[1,]
-    rank2 -- local_shape:[0,], global_offset:[2,]
-    rank5 -- local_shape:[0,], global_offset:[2,]
-    rank3 -- local_shape:[0,], global_offset:[2,]
-    rank4 -- local_shape:[0,], global_offset:[2,]
-    rank6 -- local_shape:[0,], global_offset:[2,]
-    rank7 -- local_shape:[0,], global_offset:[2,]
+    - ``rank0 -- local_shape:[1,], global_offset:[0,]``
+    - ``rank1 -- local_shape:[1,], global_offset:[1,]``
+    - ``rank2 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank5 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank3 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank4 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank6 -- local_shape:[0,], global_offset:[2,]``
+    - ``rank7 -- local_shape:[0,], global_offset:[2,]``
     """
     my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank)

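As a usage note, the improved docstring above can be exercised end-to-end with the new
``get_local_shape_and_global_offset`` helper. A minimal sketch, assuming 8 ranks (e.g. launched
with ``torchrun --nproc-per-node 8``) and a global tensor of shape ``(8, 4)``, which is inferred
from the per-rank shapes listed in the docstring:

    # Sketch of the docstring example; assumes 8 GPUs across the 2x4 mesh.
    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import Shard, distribute_tensor
    from torch.distributed.device_mesh import DeviceMesh

    from olmo_core.distributed.tensors.dtensor_utils import get_local_shape_and_global_offset

    def main():
        dist.init_process_group("nccl")
        # The 2x4 mesh from the docstring ("2 hosts with 4 GPUs each").
        mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
        # Shape (8, 4) is an assumption inferred from the per-rank results above.
        global_tensor = torch.randn(8, 4)
        dtensor = distribute_tensor(global_tensor, mesh, placements=[Shard(0), Shard(0)])
        # Each rank recovers its shard's shape and where that shard sits in the
        # global tensor, e.g. rank 0 -> local_shape (1, 4), global_offset (0, 0).
        local_shape, global_offset = get_local_shape_and_global_offset(dtensor)
        print(dist.get_rank(), local_shape, global_offset)

    if __name__ == "__main__":
        main()

As the first hunk shows, the helper just reads ``shape``, ``device_mesh``, and ``placements`` off
the :class:`DTensor` and defers to :func:`compute_local_shape_and_global_offset`.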