Add checkpointing support for DTensors (#17)
Showing 8 changed files with 888 additions and 110 deletions.
@@ -101,3 +101,77 @@ jobs:
        run: |
          . .venv/bin/activate
          pip uninstall -y ai2-olmo-core

  gpu_checks:
    name: ${{ matrix.task.name }}
    runs-on: ubuntu-latest
    timeout-minutes: 8
    env:
      BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}
      BEAKER_IMAGE: olmo-torch2-test
      BEAKER_WORKSPACE: ai2/llm-testing
    strategy:
      fail-fast: false
      matrix:
        task:
          - name: Test (GPU)
            run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/fsdp*' --ignore-glob='src/test/distributed/checkpoint*'

          - name: Test checkpoint (GPU)
            run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint*

          - name: Test FSDP (GPU)
            run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/fsdp/
    steps:
      - name: Determine current commit SHA (pull request)
        if: github.event_name == 'pull_request'
        run: |
          echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV
      - name: Determine current commit SHA (push)
        if: github.event_name != 'pull_request'
        run: |
          echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV
      - name: GPU Tests
        uses: allenai/[email protected]
        if: env.BEAKER_TOKEN != ''
        with:
          spec: |
            version: v2
            description: OLMo-core ${{ matrix.task.name }}
            budget: ai2/oe-training
            tasks:
              - name: tests
                image:
                  beaker: ${{ env.BEAKER_IMAGE }}
                context:
                  priority: normal
                  preemptible: true
                resources:
                  gpuCount: 2
                constraints:
                  cluster:
                    - ai2/general-cirrascale
                    - ai2/general-cirrascale-a100-80g-ib
                    - ai2/allennlp-cirrascale
                    - ai2/allennlp-elanding-a100-40g
                    - ai2/pluto-cirrascale
                    - ai2/jupiter-cirrascale
                envVars:
                  - name: CUBLAS_WORKSPACE_CONFIG
                    value: ":16:8"
                  - name: TOKENIZERS_PARALLELISM
                    value: "false"
                  - name: AWS_ACCESS_KEY_ID
                    secret: AWS_ACCESS_KEY_ID
                  - name: AWS_SECRET_ACCESS_KEY
                    secret: AWS_SECRET_ACCESS_KEY
                command:
                  - "bash"
                  - "-c"
                  - "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && ${{ matrix.task.run }}"
                result:
                  path: /unused
          token: ${{ env.BEAKER_TOKEN }}
          workspace: ${{ env.BEAKER_WORKSPACE }}
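The three matrix entries above select tests by pytest's `gpu` marker. As a hedged sketch (the marker registration and test layout in OLMo-core are not shown in this diff, and the test name below is hypothetical), a test picked up by `-m gpu` and sized for the task's `gpuCount: 2` might look like:

import pytest
import torch


@pytest.mark.gpu  # selected by the `pytest -m gpu` commands in the matrix above
def test_runs_on_two_gpus():
    # The Beaker task requests gpuCount: 2, so skip rather than fail on smaller machines.
    if torch.cuda.device_count() < 2:
        pytest.skip("requires at least 2 GPUs")
    assert torch.cuda.is_available()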
Large diffs are not rendered by default.
@@ -0,0 +1,127 @@
"""
Helper functions for dealing with PyTorch's :class:`DTensor`.
"""

from typing import Optional, Sequence, Tuple

from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Placement, Shard
from torch.distributed.device_mesh import DeviceMesh

from olmo_core.utils import ShapeType

from ..utils import get_mesh_coordinates


def get_local_shape_and_global_offset(
    dtensor: DTensor, rank: Optional[int] = None
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor`
    instance.

    :param dtensor: A DTensor instance.
    :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
        defaults to the current rank.

    :returns: The local shape and global offset.
    """
    global_shape = dtensor.shape
    mesh = dtensor.device_mesh
    placements = dtensor.placements
    local_shape, global_offset = compute_local_shape_and_global_offset(
        global_shape, mesh, placements, rank=rank
    )
    return local_shape, global_offset


# Adapted from `torch.distributed._tensor._utils.py`.
def compute_local_shape_and_global_offset(
    global_shape: ShapeType,
    mesh: DeviceMesh,
    placements: Sequence[Placement],
    rank: Optional[int] = None,
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """
    Compute the local tensor shape and the global offsets into the original tensor
    of a DTensor on its current global rank. This is useful for checkpointing purposes.

    :param global_shape: The shape of the global unsharded tensor.
    :param mesh: The device mesh.
    :param placements: The placements of the :class:`DTensor`.
    :param rank: The global rank to compute the local shape and global offsets for. If ``None``,
        defaults to the current rank.

    :returns: The local shape and global offset.

    Example (2 hosts with 4 GPUs each)::

        # Below is a DeviceMesh with a mesh_shape of (2, 4).
        mesh = DeviceMesh(device_type="cuda", mesh=[
            [0, 1, 2, 3],
            [4, 5, 6, 7],
        ])

    Let's say we distribute a global tensor of shape ``(8, 4)`` over the above DeviceMesh
    with placements of ``[Shard(0), Shard(0)]``. The local shape and global offset will be
    as follows:

    - ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]``
    - ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]``
    - ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]``
    - ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]``
    - ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]``
    - ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]``
    - ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]``
    - ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]``

    Now let's say we distribute a global tensor of shape ``(2,)`` over the above DeviceMesh
    with placements of ``[Shard(0)]``. Not all ranks will have a non-empty local tensor.
    The local shape and global offset will be as follows:

    - ``rank0 -- local_shape:[1,], global_offset:[0,]``
    - ``rank1 -- local_shape:[1,], global_offset:[1,]``
    - ``rank2 -- local_shape:[0,], global_offset:[2,]``
    - ``rank3 -- local_shape:[0,], global_offset:[2,]``
    - ``rank4 -- local_shape:[0,], global_offset:[2,]``
    - ``rank5 -- local_shape:[0,], global_offset:[2,]``
    - ``rank6 -- local_shape:[0,], global_offset:[2,]``
    - ``rank7 -- local_shape:[0,], global_offset:[2,]``
    """
    my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank)

    if my_coordinate is None:
        # If the rank is not in the mesh, return an empty shape and offset.
        return ((), ())
    else:
        local_shape = list(global_shape)
        global_offset = [0] * len(global_shape)

        for idx, placement in enumerate(placements):
            mesh_dim_size = mesh.size(idx)
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                local_offset = [0] * len(global_shape)
                assert shard_dim < len(
                    local_shape
                ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}"
                shard_size, shard_offset = placement._local_shard_size_on_dim(
                    local_shape[shard_dim],
                    mesh_dim_size,
                    my_coordinate[idx],
                    return_offset=True,
                )

                local_shape[shard_dim] = shard_size
                local_offset[shard_dim] = shard_offset

                # On a given dimension, if local_offset[shard_dim] is smaller than global_offset[shard_dim],
                # the dimension has already been sharded by a previous placement, so we cannot simply
                # replace global_offset[shard_dim] with local_offset[shard_dim]. Instead, we add
                # local_offset[shard_dim] to the existing global_offset[shard_dim].
                if global_offset[shard_dim] <= local_offset[shard_dim]:
                    global_offset[shard_dim] = local_offset[shard_dim]
                else:
                    global_offset[shard_dim] += local_offset[shard_dim]

        return tuple(local_shape), tuple(global_offset)
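For context on how these helpers fit into checkpointing, here is a minimal usage sketch. It assumes a 4-rank job launched with torchrun, and the import path for the new module is assumed (adjust to wherever this file lives in the tree); everything else uses public DTensor APIs.

import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# Assumed import path for the helpers added in this commit; adjust as needed.
from olmo_core.distributed.tensors.dtensor_utils import get_local_shape_and_global_offset

# Run under `torchrun --nproc-per-node=4 ...` so a default process group exists.
mesh = init_device_mesh("cuda", (4,))

# Shard an (8, 4) weight row-wise across the 4 ranks.
full_weight = torch.randn(8, 4)
dtensor = distribute_tensor(full_weight, mesh, [Shard(0)])

# Each rank learns which slice of the global (8, 4) tensor it owns. A checkpoint
# writer can use this to place `dtensor.to_local()` into the correct region of the
# saved global tensor, e.g. on rank 1: local_shape == (2, 4), global_offset == (2, 0).
local_shape, global_offset = get_local_shape_and_global_offset(dtensor)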