From ae92b8b4b7be031179ff954c65b55acc239ffd05 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 11:29:13 -0700 Subject: [PATCH 01/27] Add some helpers for dealing with `DTensor`s --- .../distributed/tensors/dtensor_utils.py | 102 ++++++++++++++++++ src/olmo_core/distributed/utils.py | 16 +++ src/olmo_core/utils.py | 4 +- 3 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 src/olmo_core/distributed/tensors/dtensor_utils.py diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py new file mode 100644 index 00000000..257df20a --- /dev/null +++ b/src/olmo_core/distributed/tensors/dtensor_utils.py @@ -0,0 +1,102 @@ +""" +Helper functions for dealing with PyTorch's :class:`DTensor`. +""" + +from typing import Optional, Sequence, Tuple + +from torch.distributed._tensor.placement_types import Placement, Shard +from torch.distributed.device_mesh import DeviceMesh + +from olmo_core.utils import ShapeType + +from ..utils import get_mesh_coordinates + + +# Adapted from `torch.distributed._tensor._utils.py`. +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + rank: Optional[int] = None, +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + :param global_shape: The shape of the global unsharded tensor. + :param mesh: The device mesh. + :param placements: The placements of the :class:`DTensor`. + :param rank: The global rank to compute the local shape and global offsets for. If ``None``, + defaults to the current rank. + + Example (2 host with 4GPUs each): + + # Below is a DeviceMesh with mesh_shape of ``(2, 4)`` + mesh = DeviceMesh(device_type="cuda", mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ]) + + Let's say we distribute a global_tensor of shape ``(8,4)`` over the above DeviceMesh + with a placements of ``[Shard(0), Shard(0)]``. + + The local shape and global offset will be as follows: + rank0 -- local_shape:[1, 4], global_offset:[0, 0] + rank1 -- local_shape:[1, 4], global_offset:[1, 0] + rank2 -- local_shape:[1, 4], global_offset:[2, 0] + rank5 -- local_shape:[1, 4], global_offset:[5, 0] + rank3 -- local_shape:[1, 4], global_offset:[3, 0] + rank4 -- local_shape:[1, 4], global_offset:[4, 0] + rank6 -- local_shape:[1, 4], global_offset:[6, 0] + rank7 -- local_shape:[1, 4], global_offset:[7, 0] + + Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with + a placements of ``[Shard(0)]``. We will not have non-empty local tensor for all the ranks. 
+ + The local shape and global offset will be as follows: + rank0 -- local_shape:[1,], global_offset:[0,] + rank1 -- local_shape:[1,], global_offset:[1,] + rank2 -- local_shape:[0,], global_offset:[2,] + rank5 -- local_shape:[0,], global_offset:[2,] + rank3 -- local_shape:[0,], global_offset:[2,] + rank4 -- local_shape:[0,], global_offset:[2,] + rank6 -- local_shape:[0,], global_offset:[2,] + rank7 -- local_shape:[0,], global_offset:[2,] + """ + my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank) + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index a63a46d7..d526170d 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh def is_distributed() -> bool: @@ -83,3 +84,18 @@ def get_gradient_divide_factor(world_size: int) -> float: while world_size % factor == 0 and world_size / factor > factor: factor *= 2 return float(factor) + + +def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = 0) -> Optional[List[int]]: + """ + Calculate the coordinates of a global rank on a device mesh. + + :param mesh: The device mesh. + :param rank: The global rank. If ``None``, the current global rank is used. + + :return: The coordinates or ``None`` if the rank is not part of the mesh. 
+ """ + rank = rank if rank is not None else get_rank() + rank_coords = (mesh.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + return rank_coords[0].tolist() if rank_coords.size(0) > 0 else None diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index 6338f1fb..9ffd12a3 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -3,7 +3,7 @@ import os import time from enum import Enum -from typing import Any, Callable, Iterable +from typing import Any, Callable, Iterable, List, Tuple, Union import numpy as np import torch @@ -28,6 +28,8 @@ def __repr__(self) -> str: return f"'{str(self)}'" +ShapeType = Union[torch.Size, List[int], Tuple[int, ...]] + # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) _float8_e5m2 = getattr(torch, "float8_e5m2", None) From ab5961df3b65f53dc4a12807335fb88c96cdaefc Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 12:41:49 -0700 Subject: [PATCH 02/27] Make checkpoint loading more efficient in some cases --- src/olmo_core/distributed/checkpoint.py | 54 +++++++++++++++++-------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index ab9f0a4f..cc0093fc 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -353,6 +353,15 @@ def load( metadata = metadata or self.get_metadata(dir, no_dist=no_dist) safetensors_mfl = _safetensors_mfl or SafeTensorsMultiFileLoader() + def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: str, filename: str): + if len((shape_in_file := loader.get_shape(key))) != 1: + raise ValueError(f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}") + + if (dtype := loader.get_dtype(key)) != tensor.dtype: + raise ValueError( + f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" + ) + # Load each tensor from the slices in each file. for key in state_dict.keys(): log.debug("Loading tensor '%s' from state dict...", key) @@ -360,24 +369,44 @@ def load( tensor = state_dict[key] flat_view = self._get_flat_view(tensor) - # Rank 0 will always be present, other ranks will not be for regular unsharded tensors. + # Make sure full unsharded shapes match. + if flat_view.full_shape != tensor_storage_metadata.shape: + raise ValueError( + f"Shape mismatched for '{key}', expected {flat_view.full_shape}, found {tensor_storage_metadata.shape}" + ) + + # Get the local offsets to load. all_offsets: Tuple[Tuple[int, int], ...] if flat_view.is_sharded: all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] else: + # NOTE: Rank 0 will always be present, other ranks will not be for regular unsharded tensors. if get_rank() in flat_view.flattened_offsets_per_rank: all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] else: all_offsets = next(iter(flat_view.flattened_offsets_per_rank.values())) - if flat_view.full_shape != tensor_storage_metadata.shape: - raise ValueError( - f"Shape mismatched for '{key}', expected {flat_view.full_shape}, found {tensor_storage_metadata.shape}" - ) + for filename, all_offsets_in_file in tensor_storage_metadata.flattened_offsets_per_file.items(): + if not all_offsets_in_file: + continue + + if all_offsets == all_offsets_in_file: + # Load the whole slice within the file at once. 
+ with safetensors_mfl.open(f"{dir}/{filename}") as loader: + validate_shard_in_file(tensor, loader, key, filename) + numel_in_file = loader.get_numel(key) + if numel_in_file > 0: + flat_view.view.copy_(loader.get_flat_slice(key)) + break + + for offsets in all_offsets: + if offsets[1] - offsets[0] == 0: + continue - for offsets in all_offsets: - for filename, all_offsets_in_file in tensor_storage_metadata.flattened_offsets_per_file.items(): for offsets_in_file in all_offsets_in_file: + if offsets_in_file[1] - offsets_in_file[0] == 0: + continue + # Check for overlap in offsets, and if there is overlap, load the slice from disk. if ( offsets_in_file[0] <= offsets[0] < offsets_in_file[1] @@ -385,16 +414,7 @@ def load( or (offsets[0] < offsets_in_file[0] and offsets_in_file[1] < offsets[1]) ): with safetensors_mfl.open(f"{dir}/{filename}") as loader: - if len((shape_in_file := loader.get_shape(key))) != 1: - raise ValueError( - f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}" - ) - - if (dtype := loader.get_dtype(key)) != tensor.dtype: - raise ValueError( - f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" - ) - + validate_shard_in_file(tensor, loader, key, filename) numel_in_file = loader.get_numel(key) if numel_in_file == 0: continue From 31807749efc34d825cb1ddf173950646aba81a83 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 15:41:54 -0700 Subject: [PATCH 03/27] Add checkpointing support for DTensor --- docs/source/distributed/checkpoint.rst | 2 +- src/olmo_core/distributed/checkpoint.py | 198 +++++++++++++----- .../distributed/tensors/dtensor_utils.py | 12 ++ src/test/distributed/checkpoint_test.py | 73 +++++++ 4 files changed, 226 insertions(+), 59 deletions(-) diff --git a/docs/source/distributed/checkpoint.rst b/docs/source/distributed/checkpoint.rst index f2797d93..b6d8df57 100644 --- a/docs/source/distributed/checkpoint.rst +++ b/docs/source/distributed/checkpoint.rst @@ -2,5 +2,5 @@ ========================== .. automodule:: olmo_core.distributed.checkpoint - :members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata + :members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata, TensorShardSpec :member-order: bysource diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index cc0093fc..c9c26891 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -47,8 +47,10 @@ import torch.distributed as dist import torch.nn as nn from cached_path import cached_path -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict +from torch.distributed._tensor import DTensor +import olmo_core.distributed.tensors.dtensor_utils as dtensor_utils from olmo_core.exceptions import OLMoUserError from olmo_core.io import ( PathOrStr, @@ -225,7 +227,8 @@ class Checkpointer: """ A distributed checkpointer for saving and loading *non-nested* state dictionaries, i.e. where keys are strings and values are either regular :class:`torch.Tensor` instances, - :class:`torch.nn.Parameter` instances, or any sharded tensors from this library. + :class:`torch.nn.Parameter` instances, :class:`DTensor` instances, or any sharded tensors + from this library. 
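A rough usage sketch, mirroring the calls in ``checkpoint_test.py`` (the directory and the
tensor below are placeholders, not part of the patch)::

    checkpointer = Checkpointer()
    state_dict = {"w": some_tensor}  # a DTensor, ShardedFlatTensor, or plain torch.Tensor
    checkpointer.save("/tmp/ckpt", state_dict)  # writes one safetensors file per rank plus a metadata file
    checkpointer.load("/tmp/ckpt", state_dict)  # loads in place, resharding if the topology changed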
For saving and loading model and optimizer states together, use :func:`save_model_and_optim_state()` and :func:`load_model_and_optim_state()` instead. @@ -290,10 +293,8 @@ def save( tensor_save_plan = global_save_plan.tensors[key] local_flat_tensor = flat_views[key] - if (local_offsets := tensor_save_plan.flattened_offsets_per_rank.get(local_rank)) is not None: - local_numel = 0 - for start_idx, end_idx in local_offsets: - local_numel += end_idx - start_idx + if (local_shard_spec := tensor_save_plan.shard_spec_per_rank.get(local_rank)) is not None: + local_numel = local_shard_spec.local_numel assert local_numel == local_flat_tensor.numel() local_state_dict[key] = local_flat_tensor @@ -378,15 +379,12 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: # Get the local offsets to load. all_offsets: Tuple[Tuple[int, int], ...] if flat_view.is_sharded: - all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] + all_offsets = flat_view.get_local_flattened_offsets() else: - # NOTE: Rank 0 will always be present, other ranks will not be for regular unsharded tensors. - if get_rank() in flat_view.flattened_offsets_per_rank: - all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] - else: - all_offsets = next(iter(flat_view.flattened_offsets_per_rank.values())) + all_offsets = ((0, tensor.numel()),) - for filename, all_offsets_in_file in tensor_storage_metadata.flattened_offsets_per_file.items(): + for filename in tensor_storage_metadata.shard_spec_per_file: + all_offsets_in_file = tensor_storage_metadata.get_flattened_offsets_in_file(filename) if not all_offsets_in_file: continue @@ -554,19 +552,31 @@ def get_metadata(self, dir: str, no_dist: bool = False) -> StorageMetadata: Get the storage metadata from a checkpoint directory. """ dir = self._normalize_dir(dir) + metadata: Optional[StorageMetadata] = None if no_dist or get_rank() == 0: with open(cached_path(f"{dir}/{self.METADATA_FILENAME}")) as f: json_metadata = json.load(f) - # For backwards compat, covert offsets `tuple[int, int]` into `tuple[tuple[int, int], ...]` + + # Coerce fields if needed for backwards compatibility. for tensor_metadata in json_metadata["tensors"].values(): - for path in tensor_metadata["flattened_offsets_per_file"]: - offsets = tensor_metadata["flattened_offsets_per_file"][path] - if offsets and isinstance(offsets[0], int): - tensor_metadata["flattened_offsets_per_file"][path] = [offsets] + if "flattened_offsets_per_file" in tensor_metadata: + for path in tensor_metadata["flattened_offsets_per_file"]: + offsets = tensor_metadata["flattened_offsets_per_file"][path] + # covert offsets `tuple[int, int]` into `tuple[tuple[int, int], ...]` + if offsets and isinstance(offsets[0], int): + tensor_metadata["flattened_offsets_per_file"][path] = [offsets] + + tensor_metadata["shard_spec_per_file"] = { + path: {"flattened_offsets": offsets} + for path, offsets in tensor_metadata.pop("flattened_offsets_per_file").items() + } + metadata = StorageMetadata(**json_metadata) + if not no_dist: metadata = scatter_object(metadata) + assert metadata is not None return metadata @@ -582,7 +592,7 @@ def _copy_into(self, target: torch.Tensor, source: torch.Tensor): def _get_flat_view(self, tensor: torch.Tensor) -> TensorFlatView: full_shape: Tuple[int, ...] 
is_sharded: bool = False - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] = {} + shard_spec_per_rank: Dict[int, TensorShardSpec] = {} if isinstance(tensor, ShardedFlatTensor): full_shape = tensor.unsharded_shape is_sharded = True @@ -593,15 +603,25 @@ def _get_flat_view(self, tensor: torch.Tensor) -> TensorFlatView: if tensor.process_group is None else dist.get_global_rank(tensor.process_group, pg_rank) ) - flattened_offsets_per_rank[global_rank] = offsets + shard_spec_per_rank[global_rank] = TensorShardSpec(flattened_offsets=offsets) + elif isinstance(tensor, DTensor): + full_shape = tuple(tensor.shape) + is_sharded = True + for global_rank in tensor.device_mesh.mesh.flatten(): + local_shape, global_offset = dtensor_utils.get_local_shape_and_global_offset( + tensor, rank=int(global_rank.item()) + ) + shard_spec_per_rank[global_rank] = TensorShardSpec( + local_shape=local_shape, global_offset=global_offset + ) else: full_shape = tuple(tensor.shape) - flattened_offsets_per_rank = {get_rank(): ((0, tensor.numel()),)} + shard_spec_per_rank[get_rank()] = TensorShardSpec(flattened_offsets=((0, tensor.numel()),)) return TensorFlatView( view=_get_local_tensor_data(tensor).view(-1), full_shape=full_shape, is_sharded=is_sharded, - flattened_offsets_per_rank=flattened_offsets_per_rank, + shard_spec_per_rank=shard_spec_per_rank, ) def _get_global_save_plan_and_metadata( @@ -614,16 +634,17 @@ def _get_global_save_plan_and_metadata( flat_view = self._get_flat_view(tensor) tensors_flat_view[key] = flat_view.view tensors_save_plan[key] = TensorSavePlan( - flattened_offsets_per_rank=flat_view.flattened_offsets_per_rank, is_sharded=flat_view.is_sharded + is_sharded=flat_view.is_sharded, + shard_spec_per_rank=flat_view.shard_spec_per_rank, ) tensors_metadata[key] = TensorStorageMetadata( - flattened_offsets_per_file={ - self._filename_for_rank(rank): offsets - for rank, offsets in flat_view.flattened_offsets_per_rank.items() - }, shape=flat_view.full_shape, is_sharded=flat_view.is_sharded, dtype=TORCH_DTYPE_TO_STR[tensor.dtype], + shard_spec_per_file={ + self._filename_for_rank(rank): shard_spec + for rank, shard_spec in flat_view.shard_spec_per_rank.items() + }, ) # All-gather save plans across ranks, merge and validate. @@ -639,9 +660,7 @@ def _get_global_save_plan_and_metadata( if not plan.is_sharded and not final_plan.is_sharded: # default to first rank with a save plan for this tensor pass - elif not set(plan.flattened_offsets_per_rank).intersection( - final_plan.flattened_offsets_per_rank - ): + elif not set(plan.shard_spec_per_rank).intersection(final_plan.shard_spec_per_rank): # tensor may be sharded in separate process groups, that's okay. pass else: @@ -664,9 +683,7 @@ def _get_global_save_plan_and_metadata( if not metadata.is_sharded and not final_metadata.is_sharded: # default to first rank with metadata for this tensor pass - elif not set(metadata.flattened_offsets_per_file).intersection( - final_metadata.flattened_offsets_per_file - ): + elif not set(metadata.shard_spec_per_file).intersection(final_metadata.shard_spec_per_file): # tensor may be sharded in separate process groups, that's okay. 
pass else: @@ -689,13 +706,59 @@ def _normalize_dir(self, dir: PathOrStr) -> str: return dir -class TensorStorageMetadata(BaseModel): - flattened_offsets_per_file: Dict[str, Tuple[Tuple[int, int], ...]] +class TensorShardSpec(BaseModel): + model_config = ConfigDict(frozen=True) + + flattened_offsets: Optional[Tuple[Tuple[int, int], ...]] = None + """ + Offsets within the full flattened tensor that the given shard corresponds to. + """ + + local_shape: Optional[Tuple[int, ...]] = None """ - Maps file name to the offsets within the full flattened tensor that the shard in the file - corresponds to. + The (unflattened) shape of the local shard. """ + global_offset: Optional[Tuple[int, ...]] = None + """ + The starting offset for each dimension in the global unsharded (unflattened) tensor that the + local shard corresponds to. + """ + + @property + def local_numel(self) -> int: + if self.local_shape is not None: + return reduce(lambda x, y: x * y, self.local_shape, 1) + elif self.flattened_offsets is not None: + local_numel = 0 + for start_idx, end_idx in self.flattened_offsets: + local_numel += end_idx - start_idx + return local_numel + else: + raise ValueError("missing required fields to determine local numel") + + def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]: + if self.flattened_offsets is not None: + return self.flattened_offsets + elif self.local_shape is not None and self.global_offset is not None: + assert len(self.local_shape) == len(self.global_offset) == len(full_shape) + if len(full_shape) == 1: # 1D tensor + return ((self.global_offset[0], self.global_offset[0] + self.local_numel),) + elif len(full_shape) == 2: + offsets = [] + for row in range(self.global_offset[0], self.global_offset[0] + self.local_shape[0]): + offset_start = row * full_shape[1] + self.global_offset[1] + offset_end = offset_start + self.local_shape[1] + offsets.append((offset_start, offset_end)) + return tuple(offsets) + else: + # TODO: generalize + raise NotImplementedError("only 1D and 2D DTensors are supported") + else: + raise ValueError("missing required fields to produce flattened offsets") + + +class TensorStorageMetadata(BaseModel): shape: Tuple[int, ...] """ The shape of the full (unflattened) tensor. @@ -711,6 +774,11 @@ class TensorStorageMetadata(BaseModel): The data type of the tensor. """ + shard_spec_per_file: Dict[str, TensorShardSpec] + """ + Maps each filename to the sharding spec of the local shard within that file. 
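For example, a ``(16, 8)`` tensor sharded along dim 0 across two ranks could be recorded as
(the filenames here are hypothetical; the real names come from ``Checkpointer._filename_for_rank``)::

    shard_spec_per_file = {
        "rank_0.safetensors": TensorShardSpec(local_shape=(8, 8), global_offset=(0, 0)),
        "rank_1.safetensors": TensorShardSpec(local_shape=(8, 8), global_offset=(8, 0)),
    }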
+ """ + @property def torch_dtype(self) -> torch.dtype: return TORCH_DTYPES[self.dtype] @@ -723,20 +791,11 @@ def materialize_empty( tensor.fill_(torch.nan) return tensor - def materialize_from_sharded( - self, tensor: torch.Tensor, device: Optional[torch.device] = None - ) -> torch.Tensor: - if isinstance(tensor, ShardedFlatTensor): - if tensor.unsharded_shape != self.shape: - raise ValueError( - f"unexpected shape for sharded tensor, expected {self.shape}, got {tensor.unsharded_shape}" - ) - tensor = torch.empty(tensor.shape, device=device, dtype=self.torch_dtype) - if tensor.dtype.is_floating_point: - tensor.fill_(torch.nan) - return tensor + def get_flattened_offsets_in_file(self, filename: str) -> Optional[Tuple[Tuple[int, int], ...]]: + if (shard_spec := self.shard_spec_per_file.get(filename)) is not None: + return shard_spec.get_flattened_offsets(self.shape) else: - raise NotImplementedError(f"`materialize_from_sharded()` not implemented for {tensor}") + return None class StorageMetadata(BaseModel): @@ -744,15 +803,14 @@ class StorageMetadata(BaseModel): class TensorSavePlan(BaseModel): - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] + is_sharded: bool """ - Maps global process rank to the offsets within the full flattened tensor that the shard for the - rank corresponds to. Some ranks may be omitted. + If the tensor is sharded. """ - is_sharded: bool + shard_spec_per_rank: Dict[int, TensorShardSpec] """ - If the tensor is sharded. + Maps each rank to the sharding spec of the local shard from that rank. Some ranks may be omitted. """ @@ -773,11 +831,20 @@ class TensorFlatView: If the tensor is sharded. """ - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] + shard_spec_per_rank: Dict[int, TensorShardSpec] """ - A mapping of *global* rank to offsets into the full flattened tensor. + Maps each rank to the sharding spec of the local shard from that rank. 
""" + def get_flattened_offsets_for_rank(self, rank: int) -> Optional[Tuple[Tuple[int, int], ...]]: + if (shard_spec := self.shard_spec_per_rank.get(rank)) is not None: + return shard_spec.get_flattened_offsets(self.full_shape) + else: + return None + + def get_local_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: + return self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) + class SavePlan(BaseModel): tensors: Dict[str, TensorSavePlan] @@ -1105,11 +1172,26 @@ def _patch_key(model: nn.Module, key: str) -> str: def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor: - return tensor.data + if isinstance(tensor, DTensor): + return tensor.to_local() + else: + return tensor.data def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor: if isinstance(param, ShardedFlatTensor): return param.wrap(tensor, requires_grad=False) + elif isinstance(param, DTensor): + return DTensor( # type: ignore + tensor, + param.device_mesh, + param.placements, + shape=param.size(), + dtype=tensor.dtype, + requires_grad=False, + stride=param.stride(), + ) + elif isinstance(param, nn.Parameter) and isinstance(param.data, DTensor): + return _wrap_tensor_for_sharded_parameter(tensor, param.data) else: return tensor diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py index 257df20a..f2404dda 100644 --- a/src/olmo_core/distributed/tensors/dtensor_utils.py +++ b/src/olmo_core/distributed/tensors/dtensor_utils.py @@ -4,6 +4,7 @@ from typing import Optional, Sequence, Tuple +from torch.distributed._tensor import DTensor from torch.distributed._tensor.placement_types import Placement, Shard from torch.distributed.device_mesh import DeviceMesh @@ -12,6 +13,17 @@ from ..utils import get_mesh_coordinates +def get_local_shape_and_global_offset( + dtensor: DTensor, rank: Optional[int] = None +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + global_shape = dtensor.shape + mesh = dtensor.device_mesh + placements = dtensor.placements + local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=rank) + assert local_shape == dtensor.to_local().shape + return local_shape, global_offset + + # Adapted from `torch.distributed._tensor._utils.py`. 
def compute_local_shape_and_global_offset( global_shape: ShapeType, diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 2ec2c994..81e6c456 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -5,11 +5,13 @@ import torch import torch.distributed as dist from cached_path import cached_path +from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh from olmo_core.distributed.checkpoint import ( Checkpointer, OptimStateDict, SafeTensorsLoader, + TensorShardSpec, _flatten_optimizer_state, _get_model_state_dict_for_checkpoint, _unflatten_optimizer_state, @@ -35,6 +37,49 @@ ) +def test_tensor_shard_spec_for_dtensor_1D(): + full_shape = (16,) + shard_spec = TensorShardSpec(local_shape=(8,), global_offset=(0,)) + assert shard_spec.get_flattened_offsets(full_shape) == ((0, 8),) + + +def test_tensor_shard_spec_for_dtensor_2D_colwise(): + # For example: + # from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh + # mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=0)]) + full_shape = (16, 8) + shard_spec = TensorShardSpec(local_shape=(4, 8), global_offset=(4, 0)) + assert shard_spec.get_flattened_offsets(full_shape) == ((32, 40), (40, 48), (48, 56), (56, 64)) + + +def test_tensor_shard_spec_for_dtensor_2D_rowwise(): + # For example: + # from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh + # mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=1)]) + full_shape = (16, 8) + shard_spec = TensorShardSpec(local_shape=(16, 2), global_offset=(0, 2)) + assert shard_spec.get_flattened_offsets(full_shape) == ( + (2, 4), # row 0 + (10, 12), # row 1 + (18, 20), # row 2 + (26, 28), # row 3 + (34, 36), # row 4 + (42, 44), # row 5 + (50, 52), # row 6 + (58, 60), # row 7 + (66, 68), # row 8 + (74, 76), # row 9 + (82, 84), # row 10 + (90, 92), # row 11 + (98, 100), # row 12 + (106, 108), # row 13 + (114, 116), # row 14 + (122, 124), # row 15 + ) + + def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir): checkpointer = Checkpointer() @@ -92,6 +137,34 @@ def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_ assert full_state_dict["y"].shape == (2, 3) +def save_and_load_checkpoint_with_dtensors(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + state_dict_to_save = { + "1d": distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_colwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_rowwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=1)]), + } + + state_dict_to_load = { + "1d": distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_colwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_rowwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=1)]), + } + + checkpointer.load(dir, state_dict_to_load) # type: ignore[arg-type] + + for key in state_dict_to_load: + torch.testing.assert_close(state_dict_to_save[key], state_dict_to_load[key]) + + +@requires_multi_gpu +def test_save_and_load_checkpoint_with_dtensors(tmp_path): + 
run_distributed_test(save_and_load_checkpoint_with_dtensors, backend="nccl", func_args=(tmp_path,)) + + def save_and_load_checkpoint_with_different_sharding_spec(dir): for idx, (offsets_to_save, offsets_to_load) in enumerate( [ From 8d367b724e6901cd71bfafe464d3d3518c024b31 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 15:43:36 -0700 Subject: [PATCH 04/27] fix test --- src/test/distributed/checkpoint_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 81e6c456..1a0e46bd 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -154,6 +154,7 @@ def save_and_load_checkpoint_with_dtensors(dir): "2d_rowwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=1)]), } + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] checkpointer.load(dir, state_dict_to_load) # type: ignore[arg-type] for key in state_dict_to_load: From 6e1ef025179c8e4a8abd66d8c290fcba5bcf6802 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 15:50:03 -0700 Subject: [PATCH 05/27] fix --- src/olmo_core/distributed/checkpoint.py | 3 ++- src/olmo_core/distributed/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index c9c26891..34b3f84b 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -608,8 +608,9 @@ def _get_flat_view(self, tensor: torch.Tensor) -> TensorFlatView: full_shape = tuple(tensor.shape) is_sharded = True for global_rank in tensor.device_mesh.mesh.flatten(): + global_rank = int(global_rank.item()) local_shape, global_offset = dtensor_utils.get_local_shape_and_global_offset( - tensor, rank=int(global_rank.item()) + tensor, rank=global_rank ) shard_spec_per_rank[global_rank] = TensorShardSpec( local_shape=local_shape, global_offset=global_offset diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index d526170d..eb71568d 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -86,7 +86,7 @@ def get_gradient_divide_factor(world_size: int) -> float: return float(factor) -def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = 0) -> Optional[List[int]]: +def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = None) -> Optional[List[int]]: """ Calculate the coordinates of a global rank on a device mesh. @@ -96,6 +96,6 @@ def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = 0) -> Optional[ :return: The coordinates or ``None`` if the rank is not part of the mesh. 
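As a sketch with a hypothetical 2x2 mesh (not taken from the tests)::

    mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1], [2, 3]])
    get_mesh_coordinates(mesh, rank=3)  # -> [1, 1]
    get_mesh_coordinates(mesh, rank=7)  # -> None, rank 7 is not in the mesh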
""" rank = rank if rank is not None else get_rank() - rank_coords = (mesh.mesh == get_rank()).nonzero() + rank_coords = (mesh.mesh == rank).nonzero() assert rank_coords.size(0) in (0, 1) return rank_coords[0].tolist() if rank_coords.size(0) > 0 else None From a581b06ec7b8180b0c79cf44d6f8e5e4276b85fa Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 16:10:42 -0700 Subject: [PATCH 06/27] clean up --- src/olmo_core/distributed/checkpoint.py | 32 ++++++++++++++----------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 34b3f84b..05785b1f 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -376,19 +376,14 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: f"Shape mismatched for '{key}', expected {flat_view.full_shape}, found {tensor_storage_metadata.shape}" ) - # Get the local offsets to load. - all_offsets: Tuple[Tuple[int, int], ...] - if flat_view.is_sharded: - all_offsets = flat_view.get_local_flattened_offsets() - else: - all_offsets = ((0, tensor.numel()),) + if flat_view.shard_spec.local_numel == 0: + continue # nothing to load - for filename in tensor_storage_metadata.shard_spec_per_file: - all_offsets_in_file = tensor_storage_metadata.get_flattened_offsets_in_file(filename) - if not all_offsets_in_file: + for filename, shard_spec_in_file in tensor_storage_metadata.shard_spec_per_file.items(): + if shard_spec_in_file.local_numel == 0: continue - if all_offsets == all_offsets_in_file: + if flat_view.shard_spec == shard_spec_in_file: # Load the whole slice within the file at once. with safetensors_mfl.open(f"{dir}/{filename}") as loader: validate_shard_in_file(tensor, loader, key, filename) @@ -397,7 +392,11 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_view.view.copy_(loader.get_flat_slice(key)) break - for offsets in all_offsets: + all_offsets_in_file = tensor_storage_metadata.get_flattened_offsets_in_file(filename) + if not all_offsets_in_file: + continue + + for offsets in flat_view.local_flattened_offsets: if offsets[1] - offsets[0] == 0: continue @@ -837,15 +836,20 @@ class TensorFlatView: Maps each rank to the sharding spec of the local shard from that rank. 
""" + @property + def shard_spec(self) -> TensorShardSpec: + return self.shard_spec_per_rank[get_rank()] + + @cached_property + def local_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: + return self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) + def get_flattened_offsets_for_rank(self, rank: int) -> Optional[Tuple[Tuple[int, int], ...]]: if (shard_spec := self.shard_spec_per_rank.get(rank)) is not None: return shard_spec.get_flattened_offsets(self.full_shape) else: return None - def get_local_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: - return self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) - class SavePlan(BaseModel): tensors: Dict[str, TensorSavePlan] From de2ec70be6a5292d34065116bf07554ac797fa3a Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 9 May 2024 16:26:16 -0700 Subject: [PATCH 07/27] use some threads when unsharding by default --- src/olmo_core/distributed/checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 05785b1f..d3432293 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -63,7 +63,7 @@ serialize_to_tensor, upload, ) -from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES +from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES, default_thread_count from .tensors import ShardedFlatTensor, ShardingSpec from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object @@ -510,8 +510,12 @@ def unshard( :param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary. :param no_dist: Set to true to avoid any distributed communication whatsoever. + :param num_threads: The maximum number of threads to use to unshard the checkpoint. + Increasing ``num_threads`` can lead to a substantial speed up, especially when loading + from a remote checkpoint. Set to ``0`` to disable threading. """ dir = self._normalize_dir(dir) + num_threads = num_threads if num_threads is not None else default_thread_count() if rank0_only and no_dist and get_rank() != 0: raise ValueError( From 9d8bf89df386cb892369b738a176cc1d928d1614 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:01:16 -0700 Subject: [PATCH 08/27] improve docs --- docs/source/distributed/tensors.rst | 4 ++ src/olmo_core/distributed/checkpoint.py | 3 +- .../distributed/tensors/dtensor_utils.py | 50 ++++++++++++------- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/docs/source/distributed/tensors.rst b/docs/source/distributed/tensors.rst index ea2af14b..25fc5ac5 100644 --- a/docs/source/distributed/tensors.rst +++ b/docs/source/distributed/tensors.rst @@ -4,3 +4,7 @@ .. automodule:: olmo_core.distributed.tensors :members: :member-order: bysource + +.. automodule:: olmo_core.distributed.tensors.dtensor_utils + :members: + :member-order: bysource diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index d3432293..b98b622d 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -8,7 +8,8 @@ -------- - Sharded distributed models, such OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's - :class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box. 
+ :class:`~torch.distributed.fsdp.FullyShardedDataParallel` (with ``use_orig_params=True``) + are supported out-of-the-box. - Utilizes `safetensors `_ under the hood for fast, efficient, and safe serialization/deserialization. - Save with one distributed topology, seamlessly load with a different one. For example, diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py index f2404dda..d26dbcee 100644 --- a/src/olmo_core/distributed/tensors/dtensor_utils.py +++ b/src/olmo_core/distributed/tensors/dtensor_utils.py @@ -16,6 +16,16 @@ def get_local_shape_and_global_offset( dtensor: DTensor, rank: Optional[int] = None ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor` + instance. + + :param dtensor: A DTensor instance. + :param rank: The global rank to compute the local shape and global offsets for. If ``None``, + defaults to the current rank. + + :returns: The local shape and global offset. + """ global_shape = dtensor.shape mesh = dtensor.device_mesh placements = dtensor.placements @@ -41,9 +51,11 @@ def compute_local_shape_and_global_offset( :param rank: The global rank to compute the local shape and global offsets for. If ``None``, defaults to the current rank. - Example (2 host with 4GPUs each): + :returns: The local shape and global offset. + + Example (2 host with 4GPUs each):: - # Below is a DeviceMesh with mesh_shape of ``(2, 4)`` + # Below is a DeviceMesh with mesh_shape of (2, 4) mesh = DeviceMesh(device_type="cuda", mesh=[ [0, 1, 2, 3], [4, 5, 6, 7] @@ -53,27 +65,29 @@ def compute_local_shape_and_global_offset( with a placements of ``[Shard(0), Shard(0)]``. The local shape and global offset will be as follows: - rank0 -- local_shape:[1, 4], global_offset:[0, 0] - rank1 -- local_shape:[1, 4], global_offset:[1, 0] - rank2 -- local_shape:[1, 4], global_offset:[2, 0] - rank5 -- local_shape:[1, 4], global_offset:[5, 0] - rank3 -- local_shape:[1, 4], global_offset:[3, 0] - rank4 -- local_shape:[1, 4], global_offset:[4, 0] - rank6 -- local_shape:[1, 4], global_offset:[6, 0] - rank7 -- local_shape:[1, 4], global_offset:[7, 0] + + - ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]`` + - ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]`` + - ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]`` + - ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]`` + - ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]`` + - ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]`` + - ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]`` + - ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]`` Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with a placements of ``[Shard(0)]``. We will not have non-empty local tensor for all the ranks. 
The local shape and global offset will be as follows: - rank0 -- local_shape:[1,], global_offset:[0,] - rank1 -- local_shape:[1,], global_offset:[1,] - rank2 -- local_shape:[0,], global_offset:[2,] - rank5 -- local_shape:[0,], global_offset:[2,] - rank3 -- local_shape:[0,], global_offset:[2,] - rank4 -- local_shape:[0,], global_offset:[2,] - rank6 -- local_shape:[0,], global_offset:[2,] - rank7 -- local_shape:[0,], global_offset:[2,] + + - ``rank0 -- local_shape:[1,], global_offset:[0,]`` + - ``rank1 -- local_shape:[1,], global_offset:[1,]`` + - ``rank2 -- local_shape:[0,], global_offset:[2,]`` + - ``rank5 -- local_shape:[0,], global_offset:[2,]`` + - ``rank3 -- local_shape:[0,], global_offset:[2,]`` + - ``rank4 -- local_shape:[0,], global_offset:[2,]`` + - ``rank6 -- local_shape:[0,], global_offset:[2,]`` + - ``rank7 -- local_shape:[0,], global_offset:[2,]`` """ my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank) From 41f20f10e64b784e01b2f32e034b964e7435df8f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:07:08 -0700 Subject: [PATCH 09/27] Add another test --- src/test/distributed/checkpoint_test.py | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 1a0e46bd..177ff6a0 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -166,6 +166,38 @@ def test_save_and_load_checkpoint_with_dtensors(tmp_path): run_distributed_test(save_and_load_checkpoint_with_dtensors, backend="nccl", func_args=(tmp_path,)) +def save_and_load_checkpoint_with_different_dtensor_topology(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + og_tensor = torch.randn(8, 6, device=get_default_device()) + + # Ensure tensor matches on all ranks (could use scatter here too, but whatever). + dist.all_reduce(og_tensor) + + state_dict_to_save = { + "x": distribute_tensor(og_tensor, mesh, [Shard(dim=0)]), + } + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] + + state_dict_to_load = { + "x": distribute_tensor(torch.randn(8, 6, device=get_default_device()), mesh, [Shard(dim=1)]), + } + checkpointer.load(dir, state_dict_to_load) # type: ignore[arg-type] + + # Gather full tensor from the state dict to load and make sure it matches the full OG tensor. 
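+    # (`full_tensor()` all-gathers the DTensor's local shards across the device mesh onto every rank.)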
+ full_loaded_tensor = state_dict_to_load["x"].full_tensor() + torch.testing.assert_close(og_tensor, full_loaded_tensor) + + +@requires_multi_gpu +def test_save_and_load_checkpoint_with_different_dtensor_topology(tmp_path): + run_distributed_test( + save_and_load_checkpoint_with_different_dtensor_topology, backend="nccl", func_args=(tmp_path,) + ) + + def save_and_load_checkpoint_with_different_sharding_spec(dir): for idx, (offsets_to_save, offsets_to_load) in enumerate( [ From d3fac0a807f2acebb6394dd3aa31851eee628c86 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:23:54 -0700 Subject: [PATCH 10/27] fix --- src/olmo_core/distributed/checkpoint.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index b98b622d..e5188517 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -40,7 +40,7 @@ from dataclasses import dataclass from functools import cached_property, reduce from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, TypedDict +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, TypedDict import safetensors as sft import safetensors.torch as sft_torch @@ -397,7 +397,7 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: if not all_offsets_in_file: continue - for offsets in flat_view.local_flattened_offsets: + for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): if offsets[1] - offsets[0] == 0: continue @@ -419,7 +419,7 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: # Start and end index of the slice within `flat_tensor` that we're going to load # from a slice of `flat_tensor_to_load`. - flat_tensor_start, flat_tensor_end = 0, flat_view.view.numel() + flat_tensor_start, flat_tensor_end = 0, flat_view_slice.numel() # Start and end index of the slice within `flat_tensor_to_load` that we're going # to load into the slice of `flat_tensor`. 
flat_tensor_to_load_start, flat_tensor_to_load_end = 0, numel_in_file @@ -471,13 +471,14 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: key, flat_tensor_to_load_start, flat_tensor_to_load_end ) if ( - load_shape := flat_view.view[flat_tensor_start:flat_tensor_end].shape + load_shape := flat_view_slice[flat_tensor_start:flat_tensor_end].shape ) != flat_tensor_to_load.shape: raise RuntimeError( - f"error loading {key} from {filename} with offsets ({flat_tensor_start}, {flat_tensor_end}), " - f"expected {load_shape}, found {flat_tensor_to_load.shape}" + f"error loading tensor '{key}' from file '{filename}' with offsets " + f"({flat_tensor_start}, {flat_tensor_end}), " + f"expected shape {tuple(load_shape)}, found {tuple(flat_tensor_to_load.shape)}" ) - flat_view.view[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) + flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) del flat_tensor_to_load @@ -849,6 +850,15 @@ def shard_spec(self) -> TensorShardSpec: def local_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: return self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) + def get_local_flattened_offsets_with_slice( + self, + ) -> Generator[Tuple[Tuple[int, int], torch.Tensor], None, None]: + numel_so_far = 0 + for offset_start, offset_end in self.local_flattened_offsets: + numel_in_slice = offset_end - offset_start + yield (offset_start, offset_end), self.view[numel_so_far : numel_so_far + numel_in_slice] + numel_so_far += numel_in_slice + def get_flattened_offsets_for_rank(self, rank: int) -> Optional[Tuple[Tuple[int, int], ...]]: if (shard_spec := self.shard_spec_per_rank.get(rank)) is not None: return shard_spec.get_flattened_offsets(self.full_shape) From 49e5b8bff5e6434075e82fb262bc1cc4f5189da3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:30:31 -0700 Subject: [PATCH 11/27] try again --- src/olmo_core/distributed/checkpoint.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index e5188517..af2af08e 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -401,8 +401,10 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: if offsets[1] - offsets[0] == 0: continue + numel_in_file_so_far = 0 for offsets_in_file in all_offsets_in_file: - if offsets_in_file[1] - offsets_in_file[0] == 0: + numel_in_file_slice = offsets_in_file[1] - offsets_in_file[0] + if numel_in_file_slice == 0: continue # Check for overlap in offsets, and if there is overlap, load the slice from disk. @@ -422,7 +424,10 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_tensor_start, flat_tensor_end = 0, flat_view_slice.numel() # Start and end index of the slice within `flat_tensor_to_load` that we're going # to load into the slice of `flat_tensor`. - flat_tensor_to_load_start, flat_tensor_to_load_end = 0, numel_in_file + flat_tensor_to_load_start, flat_tensor_to_load_end = ( + numel_in_file_so_far, + numel_in_file_so_far + numel_in_file_slice, + ) # There are 5 scenarios to consider in terms of where the tensors overlap. 
# Suppose the original flat tensor has 6 elements: 'x x x x x x' # ------------------------------------------- @@ -481,6 +486,7 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) del flat_tensor_to_load + numel_in_file_so_far += numel_in_file_slice state_dict[key] = self._copy_into(tensor, flat_view.view) del flat_view From e61ec90267d2fd23f2dc7b4fed83a9f940a8de84 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:36:21 -0700 Subject: [PATCH 12/27] try again --- src/olmo_core/distributed/checkpoint.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index af2af08e..4bcfe560 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -424,10 +424,7 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_tensor_start, flat_tensor_end = 0, flat_view_slice.numel() # Start and end index of the slice within `flat_tensor_to_load` that we're going # to load into the slice of `flat_tensor`. - flat_tensor_to_load_start, flat_tensor_to_load_end = ( - numel_in_file_so_far, - numel_in_file_so_far + numel_in_file_slice, - ) + flat_tensor_to_load_start, flat_tensor_to_load_end = 0, numel_in_file_slice # There are 5 scenarios to consider in terms of where the tensors overlap. # Suppose the original flat tensor has 6 elements: 'x x x x x x' # ------------------------------------------- @@ -473,7 +470,9 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: # Load the slice. flat_tensor_to_load = loader.get_flat_slice( - key, flat_tensor_to_load_start, flat_tensor_to_load_end + key, + numel_in_file_so_far + flat_tensor_to_load_start, + numel_in_file_so_far + flat_tensor_to_load_end, ) if ( load_shape := flat_view_slice[flat_tensor_start:flat_tensor_end].shape From 53c9519baf6a326815e4b52fce088d355cbc09b4 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 09:52:04 -0700 Subject: [PATCH 13/27] clean up --- src/olmo_core/distributed/checkpoint.py | 44 ++++++++++++------------- src/test/distributed/checkpoint_test.py | 8 ++--- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 4bcfe560..75959961 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -393,16 +393,19 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_view.view.copy_(loader.get_flat_slice(key)) break - all_offsets_in_file = tensor_storage_metadata.get_flattened_offsets_in_file(filename) - if not all_offsets_in_file: + if tensor_storage_metadata.get_numel_in_file(filename) == 0: continue + # TODO: (optimization) if the offsets in the file are a subset of the offsets in the + # flat view, load the entire slice from the file all at once and then copy into the + # flat view slice-by-slice. 
+ for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): if offsets[1] - offsets[0] == 0: continue numel_in_file_so_far = 0 - for offsets_in_file in all_offsets_in_file: + for offsets_in_file in tensor_storage_metadata.get_flattened_offsets_in_file(filename): numel_in_file_slice = offsets_in_file[1] - offsets_in_file[0] if numel_in_file_slice == 0: continue @@ -748,20 +751,18 @@ def local_numel(self) -> int: else: raise ValueError("missing required fields to determine local numel") - def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]: + def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[int, int], None, None]: if self.flattened_offsets is not None: - return self.flattened_offsets + yield from self.flattened_offsets elif self.local_shape is not None and self.global_offset is not None: assert len(self.local_shape) == len(self.global_offset) == len(full_shape) if len(full_shape) == 1: # 1D tensor - return ((self.global_offset[0], self.global_offset[0] + self.local_numel),) + yield (self.global_offset[0], self.global_offset[0] + self.local_numel) elif len(full_shape) == 2: - offsets = [] for row in range(self.global_offset[0], self.global_offset[0] + self.local_shape[0]): offset_start = row * full_shape[1] + self.global_offset[1] offset_end = offset_start + self.local_shape[1] - offsets.append((offset_start, offset_end)) - return tuple(offsets) + yield (offset_start, offset_end) else: # TODO: generalize raise NotImplementedError("only 1D and 2D DTensors are supported") @@ -802,11 +803,17 @@ def materialize_empty( tensor.fill_(torch.nan) return tensor - def get_flattened_offsets_in_file(self, filename: str) -> Optional[Tuple[Tuple[int, int], ...]]: + def get_flattened_offsets_in_file(self, filename: str) -> Generator[Tuple[int, int], None, None]: + if (shard_spec := self.shard_spec_per_file.get(filename)) is not None: + yield from shard_spec.get_flattened_offsets(self.shape) + else: + yield from [] + + def get_numel_in_file(self, filename: str) -> int: if (shard_spec := self.shard_spec_per_file.get(filename)) is not None: - return shard_spec.get_flattened_offsets(self.shape) + return shard_spec.local_numel else: - return None + return 0 class StorageMetadata(BaseModel): @@ -851,25 +858,18 @@ class TensorFlatView: def shard_spec(self) -> TensorShardSpec: return self.shard_spec_per_rank[get_rank()] - @cached_property - def local_flattened_offsets(self) -> Tuple[Tuple[int, int], ...]: - return self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) + def get_local_flattened_offsets(self) -> Generator[Tuple[int, int], None, None]: + yield from self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) def get_local_flattened_offsets_with_slice( self, ) -> Generator[Tuple[Tuple[int, int], torch.Tensor], None, None]: numel_so_far = 0 - for offset_start, offset_end in self.local_flattened_offsets: + for offset_start, offset_end in self.get_local_flattened_offsets(): numel_in_slice = offset_end - offset_start yield (offset_start, offset_end), self.view[numel_so_far : numel_so_far + numel_in_slice] numel_so_far += numel_in_slice - def get_flattened_offsets_for_rank(self, rank: int) -> Optional[Tuple[Tuple[int, int], ...]]: - if (shard_spec := self.shard_spec_per_rank.get(rank)) is not None: - return shard_spec.get_flattened_offsets(self.full_shape) - else: - return None - class SavePlan(BaseModel): tensors: Dict[str, TensorSavePlan] diff --git 
a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 177ff6a0..e0e983c1 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -40,7 +40,7 @@ def test_tensor_shard_spec_for_dtensor_1D(): full_shape = (16,) shard_spec = TensorShardSpec(local_shape=(8,), global_offset=(0,)) - assert shard_spec.get_flattened_offsets(full_shape) == ((0, 8),) + assert list(shard_spec.get_flattened_offsets(full_shape)) == [(0, 8)] def test_tensor_shard_spec_for_dtensor_2D_colwise(): @@ -50,7 +50,7 @@ def test_tensor_shard_spec_for_dtensor_2D_colwise(): # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=0)]) full_shape = (16, 8) shard_spec = TensorShardSpec(local_shape=(4, 8), global_offset=(4, 0)) - assert shard_spec.get_flattened_offsets(full_shape) == ((32, 40), (40, 48), (48, 56), (56, 64)) + assert list(shard_spec.get_flattened_offsets(full_shape)) == [(32, 40), (40, 48), (48, 56), (56, 64)] def test_tensor_shard_spec_for_dtensor_2D_rowwise(): @@ -60,7 +60,7 @@ def test_tensor_shard_spec_for_dtensor_2D_rowwise(): # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=1)]) full_shape = (16, 8) shard_spec = TensorShardSpec(local_shape=(16, 2), global_offset=(0, 2)) - assert shard_spec.get_flattened_offsets(full_shape) == ( + assert list(shard_spec.get_flattened_offsets(full_shape)) == [ (2, 4), # row 0 (10, 12), # row 1 (18, 20), # row 2 @@ -77,7 +77,7 @@ def test_tensor_shard_spec_for_dtensor_2D_rowwise(): (106, 108), # row 13 (114, 116), # row 14 (122, 124), # row 15 - ) + ] def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir): From 4f5299c072e6559d0c99dac8dcff61e9fc603187 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 11:20:06 -0700 Subject: [PATCH 14/27] optimize --- src/olmo_core/distributed/checkpoint.py | 156 ++++++++++++++++++------ src/test/distributed/checkpoint_test.py | 73 +++++++++++ 2 files changed, 189 insertions(+), 40 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 75959961..b4f75ca7 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -64,7 +64,12 @@ serialize_to_tensor, upload, ) -from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES, default_thread_count +from olmo_core.utils import ( + TORCH_DTYPE_TO_STR, + TORCH_DTYPES, + StrEnum, + default_thread_count, +) from .tensors import ShardedFlatTensor, ShardingSpec from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object @@ -378,50 +383,40 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: ) if flat_view.shard_spec.local_numel == 0: - continue # nothing to load + continue # nothing to load into for filename, shard_spec_in_file in tensor_storage_metadata.shard_spec_per_file.items(): if shard_spec_in_file.local_numel == 0: - continue - - if flat_view.shard_spec == shard_spec_in_file: - # Load the whole slice within the file at once. - with safetensors_mfl.open(f"{dir}/{filename}") as loader: - validate_shard_in_file(tensor, loader, key, filename) - numel_in_file = loader.get_numel(key) - if numel_in_file > 0: - flat_view.view.copy_(loader.get_flat_slice(key)) - break - - if tensor_storage_metadata.get_numel_in_file(filename) == 0: - continue - - # TODO: (optimization) if the offsets in the file are a subset of the offsets in the - # flat view, load the entire slice from the file all at once and then copy into the - # flat view slice-by-slice. 
- - for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): - if offsets[1] - offsets[0] == 0: - continue - - numel_in_file_so_far = 0 - for offsets_in_file in tensor_storage_metadata.get_flattened_offsets_in_file(filename): - numel_in_file_slice = offsets_in_file[1] - offsets_in_file[0] - if numel_in_file_slice == 0: + continue # nothing to load from + + # Compute overlap between the slice we want to load and the slice in the given file. + overlap = flat_view.compute_overlap_with(shard_spec_in_file) + if overlap is None: + continue # no overlap with data in file, so nothing to load + + with safetensors_mfl.open(f"{dir}/{filename}") as loader: + validate_shard_in_file(tensor, loader, key, filename) + + if overlap == OverlapType.EQUAL: + flat_view.view.copy_(loader.get_flat_slice(key)) + break + + # TODO: (optimization) if the offsets in the file are a subset of the offsets in the + # flat view, load the entire slice from the file all at once and then copy into the + # flat view slice-by-slice. + + for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): + if offsets[1] - offsets[0] == 0: continue - # Check for overlap in offsets, and if there is overlap, load the slice from disk. - if ( - offsets_in_file[0] <= offsets[0] < offsets_in_file[1] - or offsets_in_file[0] < offsets[1] <= offsets_in_file[1] - or (offsets[0] < offsets_in_file[0] and offsets_in_file[1] < offsets[1]) - ): - with safetensors_mfl.open(f"{dir}/{filename}") as loader: - validate_shard_in_file(tensor, loader, key, filename) - numel_in_file = loader.get_numel(key) - if numel_in_file == 0: - continue + numel_in_file_so_far = 0 + for offsets_in_file in tensor_storage_metadata.get_flattened_offsets_in_file(filename): + numel_in_file_slice = offsets_in_file[1] - offsets_in_file[0] + if numel_in_file_slice == 0: + continue + # Check for overlap in offsets, and if there is overlap, load the slice from disk. + if _offsets_overlap(offsets, offsets_in_file): # Start and end index of the slice within `flat_tensor` that we're going to load # from a slice of `flat_tensor_to_load`. flat_tensor_start, flat_tensor_end = 0, flat_view_slice.numel() @@ -488,7 +483,7 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) del flat_tensor_to_load - numel_in_file_so_far += numel_in_file_slice + numel_in_file_so_far += numel_in_file_slice state_dict[key] = self._copy_into(tensor, flat_view.view) del flat_view @@ -720,6 +715,13 @@ def _normalize_dir(self, dir: PathOrStr) -> str: return dir +class OverlapType(StrEnum): + EQUAL = "EQUAL" + SUPERSET = "SUPERSET" + SUBSET = "SUBSET" + MIXED = "MIXED" + + class TensorShardSpec(BaseModel): model_config = ConfigDict(frozen=True) @@ -752,6 +754,11 @@ def local_numel(self) -> int: raise ValueError("missing required fields to determine local numel") def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[int, int], None, None]: + """ + Get flattened offsets into the full flattened tensor that the given shard corresponds to. + If ``self.flattened_offsets`` is set, this just returns a generator over those, otherwise + it computes them from ``self.local_shape`` and ``self.global_offset``. 
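+
+        For example, a shard with ``local_shape=(4, 8)`` and ``global_offset=(4, 0)`` of a
+        tensor with ``full_shape=(16, 8)`` yields ``(32, 40), (40, 48), (48, 56), (56, 64)``.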
+ """ if self.flattened_offsets is not None: yield from self.flattened_offsets elif self.local_shape is not None and self.global_offset is not None: @@ -769,6 +776,72 @@ def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[ else: raise ValueError("missing required fields to produce flattened offsets") + def get_merged_flattened_offsets(self, full_shape: Tuple[int, int]) -> Generator[Tuple[int, int], None, None]: + """ + Like :meth:`get_flattened_offset` but it merges consecutive offsets that are contiguous. + """ + current_start: Optional[int] = None + current_end: Optional[int] = None + for offset_start, offset_end in self.get_flattened_offsets(full_shape): + if offset_end - offset_start == 0: + continue + if current_start is None or current_end is None: + current_start = offset_start + current_end = offset_end + elif current_end == offset_start: + current_end = offset_end + else: + yield (current_start, current_end) + current_start = offset_start + current_end = offset_end + + if current_start is not None and current_end is not None: + yield (current_start, current_end) + + def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, ...]) -> Optional[OverlapType]: + if self == other: + return OverlapType.EQUAL + + if self.flattened_offsets is not None or other.flattened_offsets is not None: + results: Set[OverlapType] = set() + for offsets in self.get_merged_flattened_offsets(full_shape): + for other_offsets in other.get_merged_flattened_offsets(full_shape): + if offsets == other_offsets: + results.add(OverlapType.EQUAL) + elif offsets[0] <= other_offsets[0] and other_offsets[1] <= offsets[1]: + results.add(OverlapType.SUPERSET) + elif other_offsets[0] <= offsets[0] and offsets[1] <= other_offsets[1]: + results.add(OverlapType.SUBSET) + elif _offsets_overlap(offsets, other_offsets): + results.add(OverlapType.MIXED) + + if not results: + return None + elif len(results) == 1: + return list(results)[0] + elif results == {OverlapType.EQUAL, OverlapType.SUPERSET}: + return OverlapType.SUPERSET + elif results == {OverlapType.EQUAL, OverlapType.SUBSET}: + return OverlapType.SUBSET + else: + return OverlapType.MIXED + + return None + + +def _offsets_overlap(offsets: Tuple[int, int], other_offsets: Tuple[int, int]) -> bool: + """ + Check if a pair of offsets have any overlap. + """ + if ( + other_offsets[0] <= offsets[0] < other_offsets[1] + or other_offsets[0] < offsets[1] <= other_offsets[1] + or (offsets[0] < other_offsets[0] and other_offsets[1] < offsets[1]) + ): + return True + else: + return False + class TensorStorageMetadata(BaseModel): shape: Tuple[int, ...] 
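# A minimal standalone sketch of the interval logic introduced above, with hypothetical
# helper names. Offsets are half-open [start, end) ranges into the flattened tensor;
# merging contiguous ranges first keeps the overlap classification cheap. The expected
# values below are taken from the unit tests added in this patch.
from typing import Iterable, List, Tuple

def merge_contiguous(offsets: Iterable[Tuple[int, int]]) -> List[Tuple[int, int]]:
    merged: List[Tuple[int, int]] = []
    for start, end in offsets:
        if end - start == 0:
            continue  # skip empty ranges
        if merged and merged[-1][1] == start:
            merged[-1] = (merged[-1][0], end)  # extend the previous contiguous range
        else:
            merged.append((start, end))
    return merged

def ranges_overlap(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    # Equivalent to `_offsets_overlap` for the non-empty, half-open ranges compared here.
    return a[0] < b[1] and b[0] < a[1]

assert merge_contiguous([(0, 3), (6, 9), (9, 12), (15, 18)]) == [(0, 3), (6, 12), (15, 18)]
assert ranges_overlap((0, 3), (1, 4))
assert not ranges_overlap((2, 5), (7, 9))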
@@ -870,6 +943,9 @@ def get_local_flattened_offsets_with_slice( yield (offset_start, offset_end), self.view[numel_so_far : numel_so_far + numel_in_slice] numel_so_far += numel_in_slice + def compute_overlap_with(self, other: TensorShardSpec) -> Optional[OverlapType]: + return self.shard_spec.compute_overlap_with(other, self.full_shape) + class SavePlan(BaseModel): tensors: Dict[str, TensorSavePlan] diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index e0e983c1..89361644 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -10,10 +10,12 @@ from olmo_core.distributed.checkpoint import ( Checkpointer, OptimStateDict, + OverlapType, SafeTensorsLoader, TensorShardSpec, _flatten_optimizer_state, _get_model_state_dict_for_checkpoint, + _offsets_overlap, _unflatten_optimizer_state, init_optimizer_state, load_model_and_optim_state, @@ -37,6 +39,77 @@ ) +def test_offsets_overlap(): + assert _offsets_overlap((0, 3), (1, 4)) + assert _offsets_overlap((0, 6), (0, 6)) + assert _offsets_overlap((0, 6), (0, 12)) + assert _offsets_overlap((1, 6), (0, 12)) + assert _offsets_overlap((0, 6), (2, 4)) + assert _offsets_overlap((0, 6), (5, 6)) + assert _offsets_overlap((0, 6), (0, 2)) + + assert _offsets_overlap((1, 4), (0, 3)) + assert _offsets_overlap((0, 6), (0, 6)) + assert _offsets_overlap((0, 12), (0, 6)) + assert _offsets_overlap((0, 12), (1, 6)) + assert _offsets_overlap((2, 4), (0, 6)) + assert _offsets_overlap((5, 6), (0, 6)) + assert _offsets_overlap((0, 2), (0, 6)) + + assert not _offsets_overlap((2, 5), (7, 9)) + + +def test_tensor_shard_spec_get_merged_flattened_offsets(): + assert list(TensorShardSpec(flattened_offsets=((0, 3),)).get_merged_flattened_offsets((128, 256))) == [(0, 3)] + + assert list(TensorShardSpec(flattened_offsets=((0, 3), (3, 6))).get_merged_flattened_offsets((128, 256))) == [ + (0, 6) + ] + + assert list( + TensorShardSpec(flattened_offsets=((0, 3), (6, 9), (9, 12), (15, 18))).get_merged_flattened_offsets( + (128, 256) + ) + ) == [(0, 3), (6, 12), (15, 18)] + + +def test_tensor_shard_spec_compute_overlap(): + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (3, 6))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 6),)), (128, 256) + ) + == OverlapType.EQUAL + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 3), (6, 9))), (128, 256) + ) + == OverlapType.SUPERSET + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 15),)), (128, 256) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((2, 5),)), (128, 256) + ) + == OverlapType.MIXED + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((12, 15),)), (128, 256) + ) + is None + ) + + def test_tensor_shard_spec_for_dtensor_1D(): full_shape = (16,) shard_spec = TensorShardSpec(local_shape=(8,), global_offset=(0,)) From 329fa00c60fbff5e6bfbee8f7ad245e7178b5e33 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 11:22:27 -0700 Subject: [PATCH 15/27] clean up --- src/olmo_core/distributed/checkpoint.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py 
index b4f75ca7..16315553 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -360,15 +360,6 @@ def load( metadata = metadata or self.get_metadata(dir, no_dist=no_dist) safetensors_mfl = _safetensors_mfl or SafeTensorsMultiFileLoader() - def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: str, filename: str): - if len((shape_in_file := loader.get_shape(key))) != 1: - raise ValueError(f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}") - - if (dtype := loader.get_dtype(key)) != tensor.dtype: - raise ValueError( - f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" - ) - # Load each tensor from the slices in each file. for key in state_dict.keys(): log.debug("Loading tensor '%s' from state dict...", key) @@ -395,7 +386,15 @@ def validate_shard_in_file(tensor: torch.Tensor, loader: SafeTensorsLoader, key: continue # no overlap with data in file, so nothing to load with safetensors_mfl.open(f"{dir}/{filename}") as loader: - validate_shard_in_file(tensor, loader, key, filename) + # Validate the shard in the file. + if len((shape_in_file := loader.get_shape(key))) != 1: + raise ValueError( + f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}" + ) + if (dtype := loader.get_dtype(key)) != tensor.dtype: + raise ValueError( + f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" + ) if overlap == OverlapType.EQUAL: flat_view.view.copy_(loader.get_flat_slice(key)) @@ -826,7 +825,7 @@ def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, .. else: return OverlapType.MIXED - return None + return OverlapType.MIXED def _offsets_overlap(offsets: Tuple[int, int], other_offsets: Tuple[int, int]) -> bool: From 5395a8cdbb5c6e49a59ff258595f8b17991fd1d9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 13:28:42 -0700 Subject: [PATCH 16/27] Optimizer further, add another test --- src/olmo_core/distributed/checkpoint.py | 65 ++++++++++++++++++++--- src/test/distributed/checkpoint_test.py | 70 ++++++++++++++++++++++++- 2 files changed, 126 insertions(+), 9 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 16315553..326ee358 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -376,6 +376,8 @@ def load( if flat_view.shard_spec.local_numel == 0: continue # nothing to load into + # Loop over each file and load from the file if there is any overlap between the shard + # in the file and the shard in the local state dict. for filename, shard_spec_in_file in tensor_storage_metadata.shard_spec_per_file.items(): if shard_spec_in_file.local_numel == 0: continue # nothing to load from @@ -400,9 +402,11 @@ def load( flat_view.view.copy_(loader.get_flat_slice(key)) break - # TODO: (optimization) if the offsets in the file are a subset of the offsets in the - # flat view, load the entire slice from the file all at once and then copy into the - # flat view slice-by-slice. + slice_in_file: Optional[torch.Tensor] = None + if overlap == OverlapType.SUPERSET: + # Optimization: pre-load the entire slice in the file when the slice to load + # is a superset of the slice in the file. + slice_in_file = loader.get_flat_slice(key) for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): if offsets[1] - offsets[0] == 0: @@ -466,11 +470,19 @@ def load( ) # Load the slice. 
- flat_tensor_to_load = loader.get_flat_slice( - key, - numel_in_file_so_far + flat_tensor_to_load_start, - numel_in_file_so_far + flat_tensor_to_load_end, - ) + if slice_in_file is not None: + flat_tensor_to_load = slice_in_file[ + numel_in_file_so_far + + flat_tensor_to_load_start : numel_in_file_so_far + + flat_tensor_to_load_end + ] + else: + flat_tensor_to_load = loader.get_flat_slice( + key, + numel_in_file_so_far + flat_tensor_to_load_start, + numel_in_file_so_far + flat_tensor_to_load_end, + ) + if ( load_shape := flat_view_slice[flat_tensor_start:flat_tensor_end].shape ) != flat_tensor_to_load.shape: @@ -479,10 +491,12 @@ def load( f"({flat_tensor_start}, {flat_tensor_end}), " f"expected shape {tuple(load_shape)}, found {tuple(flat_tensor_to_load.shape)}" ) + flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) del flat_tensor_to_load numel_in_file_so_far += numel_in_file_slice + del slice_in_file state_dict[key] = self._copy_into(tensor, flat_view.view) del flat_view @@ -825,6 +839,41 @@ def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, .. else: return OverlapType.MIXED + if ( + self.local_shape is not None + and self.global_offset is not None + and other.local_shape is not None + and other.global_offset is not None + ): + results_per_dim: Set[Optional[OverlapType]] = set() + for dim in range(len(self.local_shape)): + dim_offsets = (self.global_offset[dim], self.global_offset[dim] + self.local_shape[dim]) + other_dim_offsets = (other.global_offset[dim], other.global_offset[dim] + other.local_shape[dim]) + if dim_offsets == other_dim_offsets: + results_per_dim.add(OverlapType.EQUAL) + elif dim_offsets[0] <= other_dim_offsets[0] and other_dim_offsets[1] <= dim_offsets[1]: + results_per_dim.add(OverlapType.SUPERSET) + elif other_dim_offsets[0] <= dim_offsets[0] and dim_offsets[1] <= other_dim_offsets[1]: + results_per_dim.add(OverlapType.SUBSET) + elif _offsets_overlap(dim_offsets, other_dim_offsets): + results_per_dim.add(OverlapType.MIXED) + else: + results_per_dim.add(None) + + if None in results_per_dim: + # At least one dimension doesn't have any overlap between `self` and `other`, + # which means no overlap at all. + return None + elif len(results_per_dim) == 1: + return list(results_per_dim)[0] + elif results_per_dim == {OverlapType.EQUAL, OverlapType.SUPERSET}: + return OverlapType.SUPERSET + elif results_per_dim == {OverlapType.EQUAL, OverlapType.SUBSET}: + return OverlapType.SUBSET + else: + return OverlapType.MIXED + + # Fall back to mixed to be safe. 
return OverlapType.MIXED diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 89361644..a9035695 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -73,7 +73,7 @@ def test_tensor_shard_spec_get_merged_flattened_offsets(): ) == [(0, 3), (6, 12), (15, 18)] -def test_tensor_shard_spec_compute_overlap(): +def test_tensor_shard_spec_compute_overlap_with_flattened_offsets(): assert ( TensorShardSpec(flattened_offsets=((0, 3), (3, 6))).compute_overlap_with( TensorShardSpec(flattened_offsets=((0, 6),)), (128, 256) @@ -110,6 +110,50 @@ def test_tensor_shard_spec_compute_overlap(): ) +def test_tensor_shard_spec_compute_overlap_with_dtensor_fields(): + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.EQUAL + ) + + assert ( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 8), global_offset=(1, 0)), (16, 8) + ) + == OverlapType.SUPERSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(1, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 4), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.MIXED + ) + + assert ( + TensorShardSpec(local_shape=(2, 4), global_offset=(1, 2)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 4), global_offset=(1, 2)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 4), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.MIXED + ) + + def test_tensor_shard_spec_for_dtensor_1D(): full_shape = (16,) shard_spec = TensorShardSpec(local_shape=(8,), global_offset=(0,)) @@ -271,6 +315,30 @@ def test_save_and_load_checkpoint_with_different_dtensor_topology(tmp_path): ) +def save_and_unshard_dtensor(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + og_tensor = torch.randn(8, 6, device=get_default_device()) + + # Ensure tensor matches on all ranks (could use scatter here too, but whatever). 
+ dist.all_reduce(og_tensor) + + state_dict_to_save = { + "x": distribute_tensor(og_tensor, mesh, [Shard(dim=0)]), + } + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] + + full_state_dict = checkpointer.unshard(dir, device=get_default_device()) + torch.testing.assert_close(og_tensor, full_state_dict["x"]) + + +@requires_multi_gpu +def test_save_and_unshard_dtensor(tmp_path): + run_distributed_test(save_and_unshard_dtensor, backend="nccl", func_args=(tmp_path,)) + + def save_and_load_checkpoint_with_different_sharding_spec(dir): for idx, (offsets_to_save, offsets_to_load) in enumerate( [ From ff852a2200f6a7d497bdb3db2872bc80bea753f8 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 13:48:13 -0700 Subject: [PATCH 17/27] fix type hint --- src/olmo_core/distributed/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 326ee358..1e8fc373 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -789,7 +789,7 @@ def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[ else: raise ValueError("missing required fields to produce flattened offsets") - def get_merged_flattened_offsets(self, full_shape: Tuple[int, int]) -> Generator[Tuple[int, int], None, None]: + def get_merged_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[int, int], None, None]: """ Like :meth:`get_flattened_offset` but it merges consecutive offsets that are contiguous. """ From efc5d6b8d3d4dde27b87d9d58366c5dd8b8ebda1 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 13:57:10 -0700 Subject: [PATCH 18/27] clean up --- src/olmo_core/distributed/checkpoint.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 1e8fc373..bf60dd7d 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -458,17 +458,6 @@ def load( # Scenarios (A), (D) flat_tensor_end -= offsets[1] - offsets_in_file[1] - log.debug( - "Loading '%s'\n offsets: %s\n offsets in file: %s\n load into: (%s, %s)\n load from: (%s, %s)", - key, - offsets, - offsets_in_file, - flat_tensor_start, - flat_tensor_end, - flat_tensor_to_load_start, - flat_tensor_to_load_end, - ) - # Load the slice. if slice_in_file is not None: flat_tensor_to_load = slice_in_file[ @@ -487,9 +476,11 @@ def load( load_shape := flat_view_slice[flat_tensor_start:flat_tensor_end].shape ) != flat_tensor_to_load.shape: raise RuntimeError( - f"error loading tensor '{key}' from file '{filename}' with offsets " - f"({flat_tensor_start}, {flat_tensor_end}), " - f"expected shape {tuple(load_shape)}, found {tuple(flat_tensor_to_load.shape)}" + f"Error loading tensor '{key}' with offsets {offsets} " + f"from file '{filename}' with offsets {offsets_in_file}.\n" + f"Loading into slice ({flat_tensor_start}, {flat_tensor_end}) from " + f"slice ({flat_tensor_to_load_start}, {flat_tensor_to_load_end}) failed, " + f"expected shape {tuple(load_shape)}, found {tuple(flat_tensor_to_load.shape)}." 
) flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) From 763e1dff4d84261789460e3cdb46405181466fb9 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 14:07:00 -0700 Subject: [PATCH 19/27] clean up --- src/olmo_core/distributed/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index bf60dd7d..c023d36a 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -495,7 +495,7 @@ def load( if _check_for_nans: # Check for NaNs which would indicate we didn't fill the state dict correctly. if state_dict[key].isnan().any().item(): - raise RuntimeError(f"error loading {key} from checkpoint, nans encountered") + raise RuntimeError(f"error loading '{key}' from checkpoint, nans encountered") @torch.no_grad() def unshard( From 256380192f481154979b3b97e476a2eac1dca700 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 15:42:58 -0700 Subject: [PATCH 20/27] add a tensor parallel model test --- src/test/distributed/checkpoint_test.py | 54 +++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index a9035695..823d73b5 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -4,8 +4,15 @@ import pytest import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from cached_path import cached_path from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) from olmo_core.distributed.checkpoint import ( Checkpointer, @@ -813,3 +820,50 @@ def test_save_and_load_torch_fsdp_model( use_orig_params, ), ) + + +def run_save_and_load_tensor_parallel_model(dir): + tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + class FeedForward(nn.Module): + def __init__(self, dim: int = 16): + super().__init__() + self.dim = dim + self.w1 = nn.Linear(dim, dim) + self.w2 = nn.Linear(dim, dim) + self.w3 = nn.Linear(dim, dim) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + feed_forward = FeedForward().cuda() + parallelize_module( + feed_forward, + tp_mesh, + { + # by default ColwiseParallel input layouts is replicated + # and RowwiseParallel output layouts is replicated + "w1": ColwiseParallel(), + "w2": RowwiseParallel(), + "w3": ColwiseParallel(), + }, + ) + optim = torch.optim.AdamW(feed_forward.parameters()) + + # Take a forward and backward pass. + feed_forward(torch.rand((2, feed_forward.dim), device="cuda")).sum().backward() + + # Take an optimizer step. + + # Save checkpoint. 
+ save_model_and_optim_state(dir, feed_forward, optim) + + +@requires_multi_gpu +def test_save_and_load_tensor_parallel_model(tmp_path): + run_distributed_test( + run_save_and_load_tensor_parallel_model, + backend="nccl", + start_method="spawn", + func_args=(tmp_path,), + ) From f9260eb384d9b73762bf144713634ea0d9017bcc Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 15:46:45 -0700 Subject: [PATCH 21/27] remove faulty assert --- src/olmo_core/distributed/tensors/dtensor_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py index d26dbcee..9b92922c 100644 --- a/src/olmo_core/distributed/tensors/dtensor_utils.py +++ b/src/olmo_core/distributed/tensors/dtensor_utils.py @@ -30,7 +30,6 @@ def get_local_shape_and_global_offset( mesh = dtensor.device_mesh placements = dtensor.placements local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=rank) - assert local_shape == dtensor.to_local().shape return local_shape, global_offset From 9e278f30fbe39f1a37a6fddb90e1eaf77d5bbe6f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 15:51:11 -0700 Subject: [PATCH 22/27] Don't wrap sharded state tensors that are already sharded --- src/olmo_core/distributed/checkpoint.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index c023d36a..10f471d8 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -1319,6 +1319,9 @@ def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor: def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor: + if isinstance(tensor, (ShardedFlatTensor, DTensor)): + return tensor + if isinstance(param, ShardedFlatTensor): return param.wrap(tensor, requires_grad=False) elif isinstance(param, DTensor): From 0ab1db55909bfaf5735a726489c1402911a3b397 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 15:54:32 -0700 Subject: [PATCH 23/27] parametrize test --- src/test/distributed/checkpoint_test.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 823d73b5..e1af96dc 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -822,7 +822,7 @@ def test_save_and_load_torch_fsdp_model( ) -def run_save_and_load_tensor_parallel_model(dir): +def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint): tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),)) class FeedForward(nn.Module): @@ -853,17 +853,22 @@ def forward(self, x): # Take a forward and backward pass. feed_forward(torch.rand((2, feed_forward.dim), device="cuda")).sum().backward() - # Take an optimizer step. + if take_step_before_checkpoint: + # Take an optimizer step. + optim.step() # Save checkpoint. 
save_model_and_optim_state(dir, feed_forward, optim) @requires_multi_gpu -def test_save_and_load_tensor_parallel_model(tmp_path): +@pytest.mark.parametrize( + "take_step_before_checkpoint", [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")] +) +def test_save_and_load_tensor_parallel_model(tmp_path, take_step_before_checkpoint): run_distributed_test( run_save_and_load_tensor_parallel_model, backend="nccl", start_method="spawn", - func_args=(tmp_path,), + func_args=(tmp_path, take_step_before_checkpoint), ) From 97f2e6bb28fbc9447a7797395c1febe07f98b5a0 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 16:05:32 -0700 Subject: [PATCH 24/27] fix how we initialize optim state --- src/olmo_core/distributed/checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 10f471d8..2ecbe2a2 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -1120,8 +1120,7 @@ def init_optimizer_state(optim: torch.optim.Optimizer): # Some parameters may be empty for sharded models, in which case the state does not need # to be initialized. if p.numel() > 0: - p.grad = p.data.new(p.size()).zero_() - p.grad.requires_grad_(False) + p.grad = torch.zeros_like(p, memory_format=torch.preserve_format) optim.step() optim.zero_grad() From 44433c064851751a7ad9d07e38a499db15d7ae2e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Fri, 10 May 2024 16:09:00 -0700 Subject: [PATCH 25/27] Aaand try loading --- src/test/distributed/checkpoint_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index e1af96dc..7dbe2c80 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -860,6 +860,11 @@ def forward(self, x): # Save checkpoint. save_model_and_optim_state(dir, feed_forward, optim) + # Now load the checkpoint with a different topology, in this case an unsharded model. + unsharded_feed_forward = FeedForward().cuda() + unsharded_optim = torch.optim.AdamW(unsharded_feed_forward.parameters()) + load_model_and_optim_state(dir, unsharded_feed_forward, unsharded_optim) + @requires_multi_gpu @pytest.mark.parametrize( From 9b80a5dba63492417cd729d504450fc57249129b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 13 May 2024 13:17:52 -0700 Subject: [PATCH 26/27] Add GPU tests --- .github/workflows/main.yml | 58 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 20af56cd..7f9b5a17 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -101,3 +101,61 @@ jobs: run: | . 
.venv/bin/activate pip uninstall -y ai2-olmo-core + + gpu_tests: + name: GPU Tests + runs-on: ubuntu-latest + timeout-minutes: 8 + env: + BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} + BEAKER_IMAGE: olmo-torch2-test + BEAKER_WORKSPACE: ai2/llm-testing + steps: + - name: Determine current commit SHA (pull request) + if: github.event_name == 'pull_request' + run: | + echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV + + - name: Determine current commit SHA (push) + if: github.event_name != 'pull_request' + run: | + echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV + + - name: GPU Tests + uses: allenai/beaker-run-action@v1.2 + if: env.BEAKER_TOKEN != '' + with: + spec: | + version: v2 + description: GPU Tests + budget: ai2/oe-training + tasks: + - name: tests + image: + beaker: ${{ env.BEAKER_IMAGE }} + context: + priority: normal + preemptible: true + resources: + gpuCount: 2 + constraints: + cluster: + - ai2/general-cirrascale + - ai2/general-cirrascale-a100-80g-ib + - ai2/allennlp-cirrascale + - ai2/allennlp-elanding-a100-40g + - ai2/pluto-cirrascale + - ai2/jupiter-cirrascale + envVars: + - name: CUBLAS_WORKSPACE_CONFIG + value: ":16:8" + - name: TOKENIZERS_PARALLELISM + value: "false" + command: + - "bash" + - "-c" + - "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && pytest -v src/test -m gpu" + result: + path: /unused + token: ${{ env.BEAKER_TOKEN }} + workspace: ${{ env.BEAKER_WORKSPACE }} From 19a8f68b71d78a64c3254a1e1ebd33782854a109 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 13 May 2024 13:27:18 -0700 Subject: [PATCH 27/27] run GPU tests in a matrix --- .github/workflows/main.yml | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 7f9b5a17..9c274fcd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -102,14 +102,26 @@ jobs: . 
.venv/bin/activate pip uninstall -y ai2-olmo-core - gpu_tests: - name: GPU Tests + gpu_checks: + name: ${{ matrix.task.name }} runs-on: ubuntu-latest timeout-minutes: 8 env: BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_IMAGE: olmo-torch2-test BEAKER_WORKSPACE: ai2/llm-testing + strategy: + fail-fast: false + matrix: + task: + - name: Test (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/fsdp*' --ignore-glob='src/test/distributed/checkpoint*' + + - name: Test checkpoint (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint* + + - name: Test FSDP (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/fsdp/ steps: - name: Determine current commit SHA (pull request) if: github.event_name == 'pull_request' @@ -127,7 +139,7 @@ jobs: with: spec: | version: v2 - description: GPU Tests + description: OLMo-core ${{ matrix.task.name }} budget: ai2/oe-training tasks: - name: tests @@ -151,10 +163,14 @@ jobs: value: ":16:8" - name: TOKENIZERS_PARALLELISM value: "false" + - name: AWS_ACCESS_KEY_ID + secret: AWS_ACCESS_KEY_ID + - name: AWS_SECRET_ACCESS_KEY + secret: AWS_SECRET_ACCESS_KEY command: - "bash" - "-c" - - "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && pytest -v src/test -m gpu" + - "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && ${{ matrix.task.run }}" result: path: /unused token: ${{ env.BEAKER_TOKEN }}