Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checkpointing support for DTensors #17

Merged
merged 27 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,77 @@ jobs:
run: |
. .venv/bin/activate
pip uninstall -y ai2-olmo-core

gpu_checks:
name: ${{ matrix.task.name }}
runs-on: ubuntu-latest
timeout-minutes: 8
env:
BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
BEAKER_IMAGE: olmo-torch2-test
BEAKER_WORKSPACE: ai2/llm-testing
strategy:
fail-fast: false
matrix:
task:
- name: Test (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/fsdp*' --ignore-glob='src/test/distributed/checkpoint*'

- name: Test checkpoint (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint*

- name: Test FSDP (GPU)
run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/fsdp/
steps:
- name: Determine current commit SHA (pull request)
if: github.event_name == 'pull_request'
run: |
echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV

- name: Determine current commit SHA (push)
if: github.event_name != 'pull_request'
run: |
echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV

- name: GPU Tests
uses: allenai/[email protected]
if: env.BEAKER_TOKEN != ''
with:
spec: |
version: v2
description: OLMo-core ${{ matrix.task.name }}
budget: ai2/oe-training
tasks:
- name: tests
image:
beaker: ${{ env.BEAKER_IMAGE }}
context:
priority: normal
preemptible: true
resources:
gpuCount: 2
constraints:
cluster:
- ai2/general-cirrascale
- ai2/general-cirrascale-a100-80g-ib
- ai2/allennlp-cirrascale
- ai2/allennlp-elanding-a100-40g
- ai2/pluto-cirrascale
- ai2/jupiter-cirrascale
envVars:
- name: CUBLAS_WORKSPACE_CONFIG
value: ":16:8"
- name: TOKENIZERS_PARALLELISM
value: "false"
- name: AWS_ACCESS_KEY_ID
secret: AWS_ACCESS_KEY_ID
- name: AWS_SECRET_ACCESS_KEY
secret: AWS_SECRET_ACCESS_KEY
command:
- "bash"
- "-c"
- "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && ${{ matrix.task.run }}"
result:
path: /unused
token: ${{ env.BEAKER_TOKEN }}
workspace: ${{ env.BEAKER_WORKSPACE }}
2 changes: 1 addition & 1 deletion docs/source/distributed/checkpoint.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
==========================

.. automodule:: olmo_core.distributed.checkpoint
:members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata
:members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata, TensorShardSpec
:member-order: bysource
4 changes: 4 additions & 0 deletions docs/source/distributed/tensors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@
.. automodule:: olmo_core.distributed.tensors
:members:
:member-order: bysource

.. automodule:: olmo_core.distributed.tensors.dtensor_utils
:members:
:member-order: bysource
460 changes: 352 additions & 108 deletions src/olmo_core/distributed/checkpoint.py

Large diffs are not rendered by default.

127 changes: 127 additions & 0 deletions src/olmo_core/distributed/tensors/dtensor_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Helper functions for dealing with PyTorch's :class:`DTensor`.
"""

from typing import Optional, Sequence, Tuple

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Placement, Shard
from torch.distributed.device_mesh import DeviceMesh

from olmo_core.utils import ShapeType

from ..utils import get_mesh_coordinates


def get_local_shape_and_global_offset(
dtensor: DTensor, rank: Optional[int] = None
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""
Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor`
instance.

:param dtensor: A DTensor instance.
:param rank: The global rank to compute the local shape and global offsets for. If ``None``,
defaults to the current rank.

:returns: The local shape and global offset.
"""
global_shape = dtensor.shape
mesh = dtensor.device_mesh
placements = dtensor.placements
local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=rank)
return local_shape, global_offset


# Adapted from `torch.distributed._tensor._utils.py`.
def compute_local_shape_and_global_offset(
global_shape: ShapeType,
mesh: DeviceMesh,
placements: Sequence[Placement],
rank: Optional[int] = None,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""
Compute the local tensor shape and the global offsets into the original tensor
of a DTensor on its current global rank. This is useful for checkpointing purpose.

:param global_shape: The shape of the global unsharded tensor.
:param mesh: The device mesh.
:param placements: The placements of the :class:`DTensor`.
:param rank: The global rank to compute the local shape and global offsets for. If ``None``,
defaults to the current rank.

:returns: The local shape and global offset.

Example (2 host with 4GPUs each)::

# Below is a DeviceMesh with mesh_shape of (2, 4)
mesh = DeviceMesh(device_type="cuda", mesh=[
[0, 1, 2, 3],
[4, 5, 6, 7]
])

Let's say we distribute a global_tensor of shape ``(8,4)`` over the above DeviceMesh
with a placements of ``[Shard(0), Shard(0)]``.

The local shape and global offset will be as follows:

- ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]``
- ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]``
- ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]``
- ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]``
- ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]``
- ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]``
- ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]``
- ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]``

Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with
a placements of ``[Shard(0)]``. We will not have non-empty local tensor for all the ranks.

The local shape and global offset will be as follows:

- ``rank0 -- local_shape:[1,], global_offset:[0,]``
- ``rank1 -- local_shape:[1,], global_offset:[1,]``
- ``rank2 -- local_shape:[0,], global_offset:[2,]``
- ``rank5 -- local_shape:[0,], global_offset:[2,]``
- ``rank3 -- local_shape:[0,], global_offset:[2,]``
- ``rank4 -- local_shape:[0,], global_offset:[2,]``
- ``rank6 -- local_shape:[0,], global_offset:[2,]``
- ``rank7 -- local_shape:[0,], global_offset:[2,]``
"""
my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank)

if my_coordinate is None:
# if rank not in the mesh, return empty offset
return ((), ())
else:
local_shape = list(global_shape)
global_offset = [0] * len(global_shape)

for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if isinstance(placement, Shard):
shard_dim = placement.dim
local_offset = [0] * len(global_shape)
assert shard_dim < len(
local_shape
), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
shard_size, shard_offset = placement._local_shard_size_on_dim(
local_shape[shard_dim],
mesh_dim_size,
my_coordinate[idx],
return_offset=True,
)

local_shape[shard_dim] = shard_size
local_offset[shard_dim] = shard_offset

# On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim],
# it means that this dimension has been already sharded in previous placement.
# Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim].
# Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim].
if global_offset[shard_dim] <= local_offset[shard_dim]:
global_offset[shard_dim] = local_offset[shard_dim]
else:
global_offset[shard_dim] += local_offset[shard_dim]

return tuple(local_shape), tuple(global_offset)
16 changes: 16 additions & 0 deletions src/olmo_core/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def is_distributed() -> bool:
Expand Down Expand Up @@ -83,3 +84,18 @@ def get_gradient_divide_factor(world_size: int) -> float:
while world_size % factor == 0 and world_size / factor > factor:
factor *= 2
return float(factor)


def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = None) -> Optional[List[int]]:
"""
Calculate the coordinates of a global rank on a device mesh.

:param mesh: The device mesh.
:param rank: The global rank. If ``None``, the current global rank is used.

:return: The coordinates or ``None`` if the rank is not part of the mesh.
"""
rank = rank if rank is not None else get_rank()
rank_coords = (mesh.mesh == rank).nonzero()
assert rank_coords.size(0) in (0, 1)
return rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
4 changes: 3 additions & 1 deletion src/olmo_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import time
from enum import Enum
from typing import Any, Callable, Iterable
from typing import Any, Callable, Iterable, List, Tuple, Union

import numpy as np
import torch
Expand All @@ -28,6 +28,8 @@ def __repr__(self) -> str:
return f"'{str(self)}'"


ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]

# torch.float8 formats require 2.1; we do not support these dtypes on earlier versions
_float8_e4m3fn = getattr(torch, "float8_e4m3fn", None)
_float8_e5m2 = getattr(torch, "float8_e5m2", None)
Expand Down
Loading
Loading