diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 20af56cd..9c274fcd 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -101,3 +101,77 @@ jobs: run: | . .venv/bin/activate pip uninstall -y ai2-olmo-core + + gpu_checks: + name: ${{ matrix.task.name }} + runs-on: ubuntu-latest + timeout-minutes: 8 + env: + BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} + BEAKER_IMAGE: olmo-torch2-test + BEAKER_WORKSPACE: ai2/llm-testing + strategy: + fail-fast: false + matrix: + task: + - name: Test (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/ --ignore-glob='src/test/distributed/fsdp*' --ignore-glob='src/test/distributed/checkpoint*' + + - name: Test checkpoint (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/checkpoint* + + - name: Test FSDP (GPU) + run: pytest -v --color=yes --durations=3 -m gpu src/test/distributed/fsdp/ + steps: + - name: Determine current commit SHA (pull request) + if: github.event_name == 'pull_request' + run: | + echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV + + - name: Determine current commit SHA (push) + if: github.event_name != 'pull_request' + run: | + echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV + + - name: GPU Tests + uses: allenai/beaker-run-action@v1.2 + if: env.BEAKER_TOKEN != '' + with: + spec: | + version: v2 + description: OLMo-core ${{ matrix.task.name }} + budget: ai2/oe-training + tasks: + - name: tests + image: + beaker: ${{ env.BEAKER_IMAGE }} + context: + priority: normal + preemptible: true + resources: + gpuCount: 2 + constraints: + cluster: + - ai2/general-cirrascale + - ai2/general-cirrascale-a100-80g-ib + - ai2/allennlp-cirrascale + - ai2/allennlp-elanding-a100-40g + - ai2/pluto-cirrascale + - ai2/jupiter-cirrascale + envVars: + - name: CUBLAS_WORKSPACE_CONFIG + value: ":16:8" + - name: TOKENIZERS_PARALLELISM + value: "false" + - name: AWS_ACCESS_KEY_ID + secret: AWS_ACCESS_KEY_ID + - name: AWS_SECRET_ACCESS_KEY + secret: AWS_SECRET_ACCESS_KEY + command: + - "bash" + - "-c" + - "git clone https://github.com/allenai/OLMo-core.git && cd OLMo-core && git checkout ${{ env.COMMIT_SHA }} && pip install -e .[all] && ${{ matrix.task.run }}" + result: + path: /unused + token: ${{ env.BEAKER_TOKEN }} + workspace: ${{ env.BEAKER_WORKSPACE }} diff --git a/docs/source/distributed/checkpoint.rst b/docs/source/distributed/checkpoint.rst index f2797d93..b6d8df57 100644 --- a/docs/source/distributed/checkpoint.rst +++ b/docs/source/distributed/checkpoint.rst @@ -2,5 +2,5 @@ ========================== .. automodule:: olmo_core.distributed.checkpoint - :members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata + :members: save_model_and_optim_state, load_model_and_optim_state, unshard_model_state, unshard_optim_state, Checkpointer, StorageMetadata, TensorStorageMetadata, TensorShardSpec :member-order: bysource diff --git a/docs/source/distributed/tensors.rst b/docs/source/distributed/tensors.rst index ea2af14b..25fc5ac5 100644 --- a/docs/source/distributed/tensors.rst +++ b/docs/source/distributed/tensors.rst @@ -4,3 +4,7 @@ .. automodule:: olmo_core.distributed.tensors :members: :member-order: bysource + +.. 
automodule:: olmo_core.distributed.tensors.dtensor_utils + :members: + :member-order: bysource diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index ab9f0a4f..2ecbe2a2 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -8,7 +8,8 @@ -------- - Sharded distributed models, such OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's - :class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box. + :class:`~torch.distributed.fsdp.FullyShardedDataParallel` (with ``use_orig_params=True``) + are supported out-of-the-box. - Utilizes `safetensors `_ under the hood for fast, efficient, and safe serialization/deserialization. - Save with one distributed topology, seamlessly load with a different one. For example, @@ -39,7 +40,7 @@ from dataclasses import dataclass from functools import cached_property, reduce from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Tuple, TypedDict +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, TypedDict import safetensors as sft import safetensors.torch as sft_torch @@ -47,8 +48,10 @@ import torch.distributed as dist import torch.nn as nn from cached_path import cached_path -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict +from torch.distributed._tensor import DTensor +import olmo_core.distributed.tensors.dtensor_utils as dtensor_utils from olmo_core.exceptions import OLMoUserError from olmo_core.io import ( PathOrStr, @@ -61,7 +64,12 @@ serialize_to_tensor, upload, ) -from olmo_core.utils import TORCH_DTYPE_TO_STR, TORCH_DTYPES +from olmo_core.utils import ( + TORCH_DTYPE_TO_STR, + TORCH_DTYPES, + StrEnum, + default_thread_count, +) from .tensors import ShardedFlatTensor, ShardingSpec from .utils import all_gather_object, barrier, get_rank, get_world_size, scatter_object @@ -225,7 +233,8 @@ class Checkpointer: """ A distributed checkpointer for saving and loading *non-nested* state dictionaries, i.e. where keys are strings and values are either regular :class:`torch.Tensor` instances, - :class:`torch.nn.Parameter` instances, or any sharded tensors from this library. + :class:`torch.nn.Parameter` instances, :class:`DTensor` instances, or any sharded tensors + from this library. For saving and loading model and optimizer states together, use :func:`save_model_and_optim_state()` and :func:`load_model_and_optim_state()` instead. @@ -290,10 +299,8 @@ def save( tensor_save_plan = global_save_plan.tensors[key] local_flat_tensor = flat_views[key] - if (local_offsets := tensor_save_plan.flattened_offsets_per_rank.get(local_rank)) is not None: - local_numel = 0 - for start_idx, end_idx in local_offsets: - local_numel += end_idx - start_idx + if (local_shard_spec := tensor_save_plan.shard_spec_per_rank.get(local_rank)) is not None: + local_numel = local_shard_spec.local_numel assert local_numel == local_flat_tensor.numel() local_state_dict[key] = local_flat_tensor @@ -360,51 +367,65 @@ def load( tensor = state_dict[key] flat_view = self._get_flat_view(tensor) - # Rank 0 will always be present, other ranks will not be for regular unsharded tensors. - all_offsets: Tuple[Tuple[int, int], ...] 
- if flat_view.is_sharded: - all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] - else: - if get_rank() in flat_view.flattened_offsets_per_rank: - all_offsets = flat_view.flattened_offsets_per_rank[get_rank()] - else: - all_offsets = next(iter(flat_view.flattened_offsets_per_rank.values())) - + # Make sure full unsharded shapes match. if flat_view.full_shape != tensor_storage_metadata.shape: raise ValueError( f"Shape mismatched for '{key}', expected {flat_view.full_shape}, found {tensor_storage_metadata.shape}" ) - for offsets in all_offsets: - for filename, all_offsets_in_file in tensor_storage_metadata.flattened_offsets_per_file.items(): - for offsets_in_file in all_offsets_in_file: - # Check for overlap in offsets, and if there is overlap, load the slice from disk. - if ( - offsets_in_file[0] <= offsets[0] < offsets_in_file[1] - or offsets_in_file[0] < offsets[1] <= offsets_in_file[1] - or (offsets[0] < offsets_in_file[0] and offsets_in_file[1] < offsets[1]) - ): - with safetensors_mfl.open(f"{dir}/{filename}") as loader: - if len((shape_in_file := loader.get_shape(key))) != 1: - raise ValueError( - f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}" - ) + if flat_view.shard_spec.local_numel == 0: + continue # nothing to load into - if (dtype := loader.get_dtype(key)) != tensor.dtype: - raise ValueError( - f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" - ) + # Loop over each file and load from the file if there is any overlap between the shard + # in the file and the shard in the local state dict. + for filename, shard_spec_in_file in tensor_storage_metadata.shard_spec_per_file.items(): + if shard_spec_in_file.local_numel == 0: + continue # nothing to load from + + # Compute overlap between the slice we want to load and the slice in the given file. + overlap = flat_view.compute_overlap_with(shard_spec_in_file) + if overlap is None: + continue # no overlap with data in file, so nothing to load + + with safetensors_mfl.open(f"{dir}/{filename}") as loader: + # Validate the shard in the file. + if len((shape_in_file := loader.get_shape(key))) != 1: + raise ValueError( + f"Expected a 1D tensor at {key} in {filename}, found shape {shape_in_file}" + ) + if (dtype := loader.get_dtype(key)) != tensor.dtype: + raise ValueError( + f"Data type mismatch between tensor to load ({dtype}) and to load into ({tensor.dtype})" + ) - numel_in_file = loader.get_numel(key) - if numel_in_file == 0: - continue + if overlap == OverlapType.EQUAL: + flat_view.view.copy_(loader.get_flat_slice(key)) + break + slice_in_file: Optional[torch.Tensor] = None + if overlap == OverlapType.SUPERSET: + # Optimization: pre-load the entire slice in the file when the slice to load + # is a superset of the slice in the file. + slice_in_file = loader.get_flat_slice(key) + + for offsets, flat_view_slice in flat_view.get_local_flattened_offsets_with_slice(): + if offsets[1] - offsets[0] == 0: + continue + + numel_in_file_so_far = 0 + for offsets_in_file in tensor_storage_metadata.get_flattened_offsets_in_file(filename): + numel_in_file_slice = offsets_in_file[1] - offsets_in_file[0] + if numel_in_file_slice == 0: + continue + + # Check for overlap in offsets, and if there is overlap, load the slice from disk. + if _offsets_overlap(offsets, offsets_in_file): # Start and end index of the slice within `flat_tensor` that we're going to load # from a slice of `flat_tensor_to_load`. 
- flat_tensor_start, flat_tensor_end = 0, flat_view.view.numel() + flat_tensor_start, flat_tensor_end = 0, flat_view_slice.numel() # Start and end index of the slice within `flat_tensor_to_load` that we're going # to load into the slice of `flat_tensor`. - flat_tensor_to_load_start, flat_tensor_to_load_end = 0, numel_in_file + flat_tensor_to_load_start, flat_tensor_to_load_end = 0, numel_in_file_slice # There are 5 scenarios to consider in terms of where the tensors overlap. # Suppose the original flat tensor has 6 elements: 'x x x x x x' # ------------------------------------------- @@ -437,31 +458,36 @@ def load( # Scenarios (A), (D) flat_tensor_end -= offsets[1] - offsets_in_file[1] - log.debug( - "Loading '%s'\n offsets: %s\n offsets in file: %s\n load into: (%s, %s)\n load from: (%s, %s)", - key, - offsets, - offsets_in_file, - flat_tensor_start, - flat_tensor_end, - flat_tensor_to_load_start, - flat_tensor_to_load_end, - ) - # Load the slice. - flat_tensor_to_load = loader.get_flat_slice( - key, flat_tensor_to_load_start, flat_tensor_to_load_end - ) + if slice_in_file is not None: + flat_tensor_to_load = slice_in_file[ + numel_in_file_so_far + + flat_tensor_to_load_start : numel_in_file_so_far + + flat_tensor_to_load_end + ] + else: + flat_tensor_to_load = loader.get_flat_slice( + key, + numel_in_file_so_far + flat_tensor_to_load_start, + numel_in_file_so_far + flat_tensor_to_load_end, + ) + if ( - load_shape := flat_view.view[flat_tensor_start:flat_tensor_end].shape + load_shape := flat_view_slice[flat_tensor_start:flat_tensor_end].shape ) != flat_tensor_to_load.shape: raise RuntimeError( - f"error loading {key} from {filename} with offsets ({flat_tensor_start}, {flat_tensor_end}), " - f"expected {load_shape}, found {flat_tensor_to_load.shape}" + f"Error loading tensor '{key}' with offsets {offsets} " + f"from file '{filename}' with offsets {offsets_in_file}.\n" + f"Loading into slice ({flat_tensor_start}, {flat_tensor_end}) from " + f"slice ({flat_tensor_to_load_start}, {flat_tensor_to_load_end}) failed, " + f"expected shape {tuple(load_shape)}, found {tuple(flat_tensor_to_load.shape)}." ) - flat_view.view[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) + + flat_view_slice[flat_tensor_start:flat_tensor_end].copy_(flat_tensor_to_load) del flat_tensor_to_load + numel_in_file_so_far += numel_in_file_slice + del slice_in_file state_dict[key] = self._copy_into(tensor, flat_view.view) del flat_view @@ -469,7 +495,7 @@ def load( if _check_for_nans: # Check for NaNs which would indicate we didn't fill the state dict correctly. if state_dict[key].isnan().any().item(): - raise RuntimeError(f"error loading {key} from checkpoint, nans encountered") + raise RuntimeError(f"error loading '{key}' from checkpoint, nans encountered") @torch.no_grad() def unshard( @@ -493,8 +519,12 @@ def unshard( :param rank0_only: Set to true if you only want to load the unsharded state to rank 0 in a distributed context. Other ranks will receive an empty dictionary. :param no_dist: Set to true to avoid any distributed communication whatsoever. + :param num_threads: The maximum number of threads to use to unshard the checkpoint. + Increasing ``num_threads`` can lead to a substantial speed up, especially when loading + from a remote checkpoint. Set to ``0`` to disable threading. 
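+
+        Example (an illustrative sketch; assumes ``dir`` points at a checkpoint previously
+        written with :meth:`save()`)::
+
+            checkpointer = Checkpointer()
+            full_state_dict = checkpointer.unshard(dir, num_threads=4)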
""" dir = self._normalize_dir(dir) + num_threads = num_threads if num_threads is not None else default_thread_count() if rank0_only and no_dist and get_rank() != 0: raise ValueError( @@ -534,19 +564,31 @@ def get_metadata(self, dir: str, no_dist: bool = False) -> StorageMetadata: Get the storage metadata from a checkpoint directory. """ dir = self._normalize_dir(dir) + metadata: Optional[StorageMetadata] = None if no_dist or get_rank() == 0: with open(cached_path(f"{dir}/{self.METADATA_FILENAME}")) as f: json_metadata = json.load(f) - # For backwards compat, covert offsets `tuple[int, int]` into `tuple[tuple[int, int], ...]` + + # Coerce fields if needed for backwards compatibility. for tensor_metadata in json_metadata["tensors"].values(): - for path in tensor_metadata["flattened_offsets_per_file"]: - offsets = tensor_metadata["flattened_offsets_per_file"][path] - if offsets and isinstance(offsets[0], int): - tensor_metadata["flattened_offsets_per_file"][path] = [offsets] + if "flattened_offsets_per_file" in tensor_metadata: + for path in tensor_metadata["flattened_offsets_per_file"]: + offsets = tensor_metadata["flattened_offsets_per_file"][path] + # covert offsets `tuple[int, int]` into `tuple[tuple[int, int], ...]` + if offsets and isinstance(offsets[0], int): + tensor_metadata["flattened_offsets_per_file"][path] = [offsets] + + tensor_metadata["shard_spec_per_file"] = { + path: {"flattened_offsets": offsets} + for path, offsets in tensor_metadata.pop("flattened_offsets_per_file").items() + } + metadata = StorageMetadata(**json_metadata) + if not no_dist: metadata = scatter_object(metadata) + assert metadata is not None return metadata @@ -562,7 +604,7 @@ def _copy_into(self, target: torch.Tensor, source: torch.Tensor): def _get_flat_view(self, tensor: torch.Tensor) -> TensorFlatView: full_shape: Tuple[int, ...] 
is_sharded: bool = False - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] = {} + shard_spec_per_rank: Dict[int, TensorShardSpec] = {} if isinstance(tensor, ShardedFlatTensor): full_shape = tensor.unsharded_shape is_sharded = True @@ -573,15 +615,26 @@ def _get_flat_view(self, tensor: torch.Tensor) -> TensorFlatView: if tensor.process_group is None else dist.get_global_rank(tensor.process_group, pg_rank) ) - flattened_offsets_per_rank[global_rank] = offsets + shard_spec_per_rank[global_rank] = TensorShardSpec(flattened_offsets=offsets) + elif isinstance(tensor, DTensor): + full_shape = tuple(tensor.shape) + is_sharded = True + for global_rank in tensor.device_mesh.mesh.flatten(): + global_rank = int(global_rank.item()) + local_shape, global_offset = dtensor_utils.get_local_shape_and_global_offset( + tensor, rank=global_rank + ) + shard_spec_per_rank[global_rank] = TensorShardSpec( + local_shape=local_shape, global_offset=global_offset + ) else: full_shape = tuple(tensor.shape) - flattened_offsets_per_rank = {get_rank(): ((0, tensor.numel()),)} + shard_spec_per_rank[get_rank()] = TensorShardSpec(flattened_offsets=((0, tensor.numel()),)) return TensorFlatView( view=_get_local_tensor_data(tensor).view(-1), full_shape=full_shape, is_sharded=is_sharded, - flattened_offsets_per_rank=flattened_offsets_per_rank, + shard_spec_per_rank=shard_spec_per_rank, ) def _get_global_save_plan_and_metadata( @@ -594,16 +647,17 @@ def _get_global_save_plan_and_metadata( flat_view = self._get_flat_view(tensor) tensors_flat_view[key] = flat_view.view tensors_save_plan[key] = TensorSavePlan( - flattened_offsets_per_rank=flat_view.flattened_offsets_per_rank, is_sharded=flat_view.is_sharded + is_sharded=flat_view.is_sharded, + shard_spec_per_rank=flat_view.shard_spec_per_rank, ) tensors_metadata[key] = TensorStorageMetadata( - flattened_offsets_per_file={ - self._filename_for_rank(rank): offsets - for rank, offsets in flat_view.flattened_offsets_per_rank.items() - }, shape=flat_view.full_shape, is_sharded=flat_view.is_sharded, dtype=TORCH_DTYPE_TO_STR[tensor.dtype], + shard_spec_per_file={ + self._filename_for_rank(rank): shard_spec + for rank, shard_spec in flat_view.shard_spec_per_rank.items() + }, ) # All-gather save plans across ranks, merge and validate. @@ -619,9 +673,7 @@ def _get_global_save_plan_and_metadata( if not plan.is_sharded and not final_plan.is_sharded: # default to first rank with a save plan for this tensor pass - elif not set(plan.flattened_offsets_per_rank).intersection( - final_plan.flattened_offsets_per_rank - ): + elif not set(plan.shard_spec_per_rank).intersection(final_plan.shard_spec_per_rank): # tensor may be sharded in separate process groups, that's okay. pass else: @@ -644,9 +696,7 @@ def _get_global_save_plan_and_metadata( if not metadata.is_sharded and not final_metadata.is_sharded: # default to first rank with metadata for this tensor pass - elif not set(metadata.flattened_offsets_per_file).intersection( - final_metadata.flattened_offsets_per_file - ): + elif not set(metadata.shard_spec_per_file).intersection(final_metadata.shard_spec_per_file): # tensor may be sharded in separate process groups, that's okay. 
pass else: @@ -669,13 +719,170 @@ def _normalize_dir(self, dir: PathOrStr) -> str: return dir -class TensorStorageMetadata(BaseModel): - flattened_offsets_per_file: Dict[str, Tuple[Tuple[int, int], ...]] +class OverlapType(StrEnum): + EQUAL = "EQUAL" + SUPERSET = "SUPERSET" + SUBSET = "SUBSET" + MIXED = "MIXED" + + +class TensorShardSpec(BaseModel): + model_config = ConfigDict(frozen=True) + + flattened_offsets: Optional[Tuple[Tuple[int, int], ...]] = None + """ + Offsets within the full flattened tensor that the given shard corresponds to. + """ + + local_shape: Optional[Tuple[int, ...]] = None + """ + The (unflattened) shape of the local shard. + """ + + global_offset: Optional[Tuple[int, ...]] = None + """ + The starting offset for each dimension in the global unsharded (unflattened) tensor that the + local shard corresponds to. + """ + + @property + def local_numel(self) -> int: + if self.local_shape is not None: + return reduce(lambda x, y: x * y, self.local_shape, 1) + elif self.flattened_offsets is not None: + local_numel = 0 + for start_idx, end_idx in self.flattened_offsets: + local_numel += end_idx - start_idx + return local_numel + else: + raise ValueError("missing required fields to determine local numel") + + def get_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[int, int], None, None]: + """ + Get flattened offsets into the full flattened tensor that the given shard corresponds to. + If ``self.flattened_offsets`` is set, this just returns a generator over those, otherwise + it computes them from ``self.local_shape`` and ``self.global_offset``. + """ + if self.flattened_offsets is not None: + yield from self.flattened_offsets + elif self.local_shape is not None and self.global_offset is not None: + assert len(self.local_shape) == len(self.global_offset) == len(full_shape) + if len(full_shape) == 1: # 1D tensor + yield (self.global_offset[0], self.global_offset[0] + self.local_numel) + elif len(full_shape) == 2: + for row in range(self.global_offset[0], self.global_offset[0] + self.local_shape[0]): + offset_start = row * full_shape[1] + self.global_offset[1] + offset_end = offset_start + self.local_shape[1] + yield (offset_start, offset_end) + else: + # TODO: generalize + raise NotImplementedError("only 1D and 2D DTensors are supported") + else: + raise ValueError("missing required fields to produce flattened offsets") + + def get_merged_flattened_offsets(self, full_shape: Tuple[int, ...]) -> Generator[Tuple[int, int], None, None]: + """ + Like :meth:`get_flattened_offset` but it merges consecutive offsets that are contiguous. 
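+
+        For example (an illustrative case, mirrored by the unit tests)::
+
+            spec = TensorShardSpec(flattened_offsets=((0, 3), (6, 9), (9, 12), (15, 18)))
+            list(spec.get_merged_flattened_offsets((128, 256)))
+            # -> [(0, 3), (6, 12), (15, 18)]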
+ """ + current_start: Optional[int] = None + current_end: Optional[int] = None + for offset_start, offset_end in self.get_flattened_offsets(full_shape): + if offset_end - offset_start == 0: + continue + if current_start is None or current_end is None: + current_start = offset_start + current_end = offset_end + elif current_end == offset_start: + current_end = offset_end + else: + yield (current_start, current_end) + current_start = offset_start + current_end = offset_end + + if current_start is not None and current_end is not None: + yield (current_start, current_end) + + def compute_overlap_with(self, other: TensorShardSpec, full_shape: Tuple[int, ...]) -> Optional[OverlapType]: + if self == other: + return OverlapType.EQUAL + + if self.flattened_offsets is not None or other.flattened_offsets is not None: + results: Set[OverlapType] = set() + for offsets in self.get_merged_flattened_offsets(full_shape): + for other_offsets in other.get_merged_flattened_offsets(full_shape): + if offsets == other_offsets: + results.add(OverlapType.EQUAL) + elif offsets[0] <= other_offsets[0] and other_offsets[1] <= offsets[1]: + results.add(OverlapType.SUPERSET) + elif other_offsets[0] <= offsets[0] and offsets[1] <= other_offsets[1]: + results.add(OverlapType.SUBSET) + elif _offsets_overlap(offsets, other_offsets): + results.add(OverlapType.MIXED) + + if not results: + return None + elif len(results) == 1: + return list(results)[0] + elif results == {OverlapType.EQUAL, OverlapType.SUPERSET}: + return OverlapType.SUPERSET + elif results == {OverlapType.EQUAL, OverlapType.SUBSET}: + return OverlapType.SUBSET + else: + return OverlapType.MIXED + + if ( + self.local_shape is not None + and self.global_offset is not None + and other.local_shape is not None + and other.global_offset is not None + ): + results_per_dim: Set[Optional[OverlapType]] = set() + for dim in range(len(self.local_shape)): + dim_offsets = (self.global_offset[dim], self.global_offset[dim] + self.local_shape[dim]) + other_dim_offsets = (other.global_offset[dim], other.global_offset[dim] + other.local_shape[dim]) + if dim_offsets == other_dim_offsets: + results_per_dim.add(OverlapType.EQUAL) + elif dim_offsets[0] <= other_dim_offsets[0] and other_dim_offsets[1] <= dim_offsets[1]: + results_per_dim.add(OverlapType.SUPERSET) + elif other_dim_offsets[0] <= dim_offsets[0] and dim_offsets[1] <= other_dim_offsets[1]: + results_per_dim.add(OverlapType.SUBSET) + elif _offsets_overlap(dim_offsets, other_dim_offsets): + results_per_dim.add(OverlapType.MIXED) + else: + results_per_dim.add(None) + + if None in results_per_dim: + # At least one dimension doesn't have any overlap between `self` and `other`, + # which means no overlap at all. + return None + elif len(results_per_dim) == 1: + return list(results_per_dim)[0] + elif results_per_dim == {OverlapType.EQUAL, OverlapType.SUPERSET}: + return OverlapType.SUPERSET + elif results_per_dim == {OverlapType.EQUAL, OverlapType.SUBSET}: + return OverlapType.SUBSET + else: + return OverlapType.MIXED + + # Fall back to mixed to be safe. + return OverlapType.MIXED + + +def _offsets_overlap(offsets: Tuple[int, int], other_offsets: Tuple[int, int]) -> bool: """ - Maps file name to the offsets within the full flattened tensor that the shard in the file - corresponds to. + Check if a pair of offsets have any overlap. 
""" + if ( + other_offsets[0] <= offsets[0] < other_offsets[1] + or other_offsets[0] < offsets[1] <= other_offsets[1] + or (offsets[0] < other_offsets[0] and other_offsets[1] < offsets[1]) + ): + return True + else: + return False + +class TensorStorageMetadata(BaseModel): shape: Tuple[int, ...] """ The shape of the full (unflattened) tensor. @@ -691,6 +898,11 @@ class TensorStorageMetadata(BaseModel): The data type of the tensor. """ + shard_spec_per_file: Dict[str, TensorShardSpec] + """ + Maps each filename to the sharding spec of the local shard within that file. + """ + @property def torch_dtype(self) -> torch.dtype: return TORCH_DTYPES[self.dtype] @@ -703,20 +915,17 @@ def materialize_empty( tensor.fill_(torch.nan) return tensor - def materialize_from_sharded( - self, tensor: torch.Tensor, device: Optional[torch.device] = None - ) -> torch.Tensor: - if isinstance(tensor, ShardedFlatTensor): - if tensor.unsharded_shape != self.shape: - raise ValueError( - f"unexpected shape for sharded tensor, expected {self.shape}, got {tensor.unsharded_shape}" - ) - tensor = torch.empty(tensor.shape, device=device, dtype=self.torch_dtype) - if tensor.dtype.is_floating_point: - tensor.fill_(torch.nan) - return tensor + def get_flattened_offsets_in_file(self, filename: str) -> Generator[Tuple[int, int], None, None]: + if (shard_spec := self.shard_spec_per_file.get(filename)) is not None: + yield from shard_spec.get_flattened_offsets(self.shape) else: - raise NotImplementedError(f"`materialize_from_sharded()` not implemented for {tensor}") + yield from [] + + def get_numel_in_file(self, filename: str) -> int: + if (shard_spec := self.shard_spec_per_file.get(filename)) is not None: + return shard_spec.local_numel + else: + return 0 class StorageMetadata(BaseModel): @@ -724,15 +933,14 @@ class StorageMetadata(BaseModel): class TensorSavePlan(BaseModel): - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] + is_sharded: bool """ - Maps global process rank to the offsets within the full flattened tensor that the shard for the - rank corresponds to. Some ranks may be omitted. + If the tensor is sharded. """ - is_sharded: bool + shard_spec_per_rank: Dict[int, TensorShardSpec] """ - If the tensor is sharded. + Maps each rank to the sharding spec of the local shard from that rank. Some ranks may be omitted. """ @@ -753,11 +961,30 @@ class TensorFlatView: If the tensor is sharded. """ - flattened_offsets_per_rank: Dict[int, Tuple[Tuple[int, int], ...]] + shard_spec_per_rank: Dict[int, TensorShardSpec] """ - A mapping of *global* rank to offsets into the full flattened tensor. + Maps each rank to the sharding spec of the local shard from that rank. 
""" + @property + def shard_spec(self) -> TensorShardSpec: + return self.shard_spec_per_rank[get_rank()] + + def get_local_flattened_offsets(self) -> Generator[Tuple[int, int], None, None]: + yield from self.shard_spec_per_rank[get_rank()].get_flattened_offsets(self.full_shape) + + def get_local_flattened_offsets_with_slice( + self, + ) -> Generator[Tuple[Tuple[int, int], torch.Tensor], None, None]: + numel_so_far = 0 + for offset_start, offset_end in self.get_local_flattened_offsets(): + numel_in_slice = offset_end - offset_start + yield (offset_start, offset_end), self.view[numel_so_far : numel_so_far + numel_in_slice] + numel_so_far += numel_in_slice + + def compute_overlap_with(self, other: TensorShardSpec) -> Optional[OverlapType]: + return self.shard_spec.compute_overlap_with(other, self.full_shape) + class SavePlan(BaseModel): tensors: Dict[str, TensorSavePlan] @@ -893,8 +1120,7 @@ def init_optimizer_state(optim: torch.optim.Optimizer): # Some parameters may be empty for sharded models, in which case the state does not need # to be initialized. if p.numel() > 0: - p.grad = p.data.new(p.size()).zero_() - p.grad.requires_grad_(False) + p.grad = torch.zeros_like(p, memory_format=torch.preserve_format) optim.step() optim.zero_grad() @@ -1085,11 +1311,29 @@ def _patch_key(model: nn.Module, key: str) -> str: def _get_local_tensor_data(tensor: torch.Tensor) -> torch.Tensor: - return tensor.data + if isinstance(tensor, DTensor): + return tensor.to_local() + else: + return tensor.data def _wrap_tensor_for_sharded_parameter(tensor: torch.Tensor, param: Optional[torch.Tensor]) -> torch.Tensor: + if isinstance(tensor, (ShardedFlatTensor, DTensor)): + return tensor + if isinstance(param, ShardedFlatTensor): return param.wrap(tensor, requires_grad=False) + elif isinstance(param, DTensor): + return DTensor( # type: ignore + tensor, + param.device_mesh, + param.placements, + shape=param.size(), + dtype=tensor.dtype, + requires_grad=False, + stride=param.stride(), + ) + elif isinstance(param, nn.Parameter) and isinstance(param.data, DTensor): + return _wrap_tensor_for_sharded_parameter(tensor, param.data) else: return tensor diff --git a/src/olmo_core/distributed/tensors/dtensor_utils.py b/src/olmo_core/distributed/tensors/dtensor_utils.py new file mode 100644 index 00000000..9b92922c --- /dev/null +++ b/src/olmo_core/distributed/tensors/dtensor_utils.py @@ -0,0 +1,127 @@ +""" +Helper functions for dealing with PyTorch's :class:`DTensor`. +""" + +from typing import Optional, Sequence, Tuple + +from torch.distributed._tensor import DTensor +from torch.distributed._tensor.placement_types import Placement, Shard +from torch.distributed.device_mesh import DeviceMesh + +from olmo_core.utils import ShapeType + +from ..utils import get_mesh_coordinates + + +def get_local_shape_and_global_offset( + dtensor: DTensor, rank: Optional[int] = None +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Like :func:`compute_local_shape_and_global_offset`, but acts directly on a :class:`DTensor` + instance. + + :param dtensor: A DTensor instance. + :param rank: The global rank to compute the local shape and global offsets for. If ``None``, + defaults to the current rank. + + :returns: The local shape and global offset. 
+ """ + global_shape = dtensor.shape + mesh = dtensor.device_mesh + placements = dtensor.placements + local_shape, global_offset = compute_local_shape_and_global_offset(global_shape, mesh, placements, rank=rank) + return local_shape, global_offset + + +# Adapted from `torch.distributed._tensor._utils.py`. +def compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh: DeviceMesh, + placements: Sequence[Placement], + rank: Optional[int] = None, +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + :param global_shape: The shape of the global unsharded tensor. + :param mesh: The device mesh. + :param placements: The placements of the :class:`DTensor`. + :param rank: The global rank to compute the local shape and global offsets for. If ``None``, + defaults to the current rank. + + :returns: The local shape and global offset. + + Example (2 host with 4GPUs each):: + + # Below is a DeviceMesh with mesh_shape of (2, 4) + mesh = DeviceMesh(device_type="cuda", mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ]) + + Let's say we distribute a global_tensor of shape ``(8,4)`` over the above DeviceMesh + with a placements of ``[Shard(0), Shard(0)]``. + + The local shape and global offset will be as follows: + + - ``rank0 -- local_shape:[1, 4], global_offset:[0, 0]`` + - ``rank1 -- local_shape:[1, 4], global_offset:[1, 0]`` + - ``rank2 -- local_shape:[1, 4], global_offset:[2, 0]`` + - ``rank5 -- local_shape:[1, 4], global_offset:[5, 0]`` + - ``rank3 -- local_shape:[1, 4], global_offset:[3, 0]`` + - ``rank4 -- local_shape:[1, 4], global_offset:[4, 0]`` + - ``rank6 -- local_shape:[1, 4], global_offset:[6, 0]`` + - ``rank7 -- local_shape:[1, 4], global_offset:[7, 0]`` + + Let's say we distribute a global_tensor of shape ``(2,)`` over the above DeviceMesh with + a placements of ``[Shard(0)]``. We will not have non-empty local tensor for all the ranks. + + The local shape and global offset will be as follows: + + - ``rank0 -- local_shape:[1,], global_offset:[0,]`` + - ``rank1 -- local_shape:[1,], global_offset:[1,]`` + - ``rank2 -- local_shape:[0,], global_offset:[2,]`` + - ``rank5 -- local_shape:[0,], global_offset:[2,]`` + - ``rank3 -- local_shape:[0,], global_offset:[2,]`` + - ``rank4 -- local_shape:[0,], global_offset:[2,]`` + - ``rank6 -- local_shape:[0,], global_offset:[2,]`` + - ``rank7 -- local_shape:[0,], global_offset:[2,]`` + """ + my_coordinate = mesh.get_coordinate() if rank is None else get_mesh_coordinates(mesh, rank) + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. 
+ # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index a63a46d7..eb71568d 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -3,6 +3,7 @@ import torch import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh def is_distributed() -> bool: @@ -83,3 +84,18 @@ def get_gradient_divide_factor(world_size: int) -> float: while world_size % factor == 0 and world_size / factor > factor: factor *= 2 return float(factor) + + +def get_mesh_coordinates(mesh: DeviceMesh, rank: Optional[int] = None) -> Optional[List[int]]: + """ + Calculate the coordinates of a global rank on a device mesh. + + :param mesh: The device mesh. + :param rank: The global rank. If ``None``, the current global rank is used. + + :return: The coordinates or ``None`` if the rank is not part of the mesh. + """ + rank = rank if rank is not None else get_rank() + rank_coords = (mesh.mesh == rank).nonzero() + assert rank_coords.size(0) in (0, 1) + return rank_coords[0].tolist() if rank_coords.size(0) > 0 else None diff --git a/src/olmo_core/utils.py b/src/olmo_core/utils.py index 6338f1fb..9ffd12a3 100644 --- a/src/olmo_core/utils.py +++ b/src/olmo_core/utils.py @@ -3,7 +3,7 @@ import os import time from enum import Enum -from typing import Any, Callable, Iterable +from typing import Any, Callable, Iterable, List, Tuple, Union import numpy as np import torch @@ -28,6 +28,8 @@ def __repr__(self) -> str: return f"'{str(self)}'" +ShapeType = Union[torch.Size, List[int], Tuple[int, ...]] + # torch.float8 formats require 2.1; we do not support these dtypes on earlier versions _float8_e4m3fn = getattr(torch, "float8_e4m3fn", None) _float8_e5m2 = getattr(torch, "float8_e5m2", None) diff --git a/src/test/distributed/checkpoint_test.py b/src/test/distributed/checkpoint_test.py index 2ec2c994..7dbe2c80 100644 --- a/src/test/distributed/checkpoint_test.py +++ b/src/test/distributed/checkpoint_test.py @@ -4,14 +4,25 @@ import pytest import torch import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F from cached_path import cached_path +from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + RowwiseParallel, + parallelize_module, +) from olmo_core.distributed.checkpoint import ( Checkpointer, OptimStateDict, + OverlapType, SafeTensorsLoader, + TensorShardSpec, _flatten_optimizer_state, _get_model_state_dict_for_checkpoint, + _offsets_overlap, _unflatten_optimizer_state, init_optimizer_state, load_model_and_optim_state, @@ -35,6 +46,164 @@ ) +def test_offsets_overlap(): + assert _offsets_overlap((0, 3), (1, 4)) + assert _offsets_overlap((0, 6), (0, 6)) + assert _offsets_overlap((0, 6), (0, 12)) + assert _offsets_overlap((1, 6), (0, 12)) + assert _offsets_overlap((0, 6), (2, 4)) + assert _offsets_overlap((0, 6), (5, 6)) + assert _offsets_overlap((0, 6), (0, 2)) + + assert _offsets_overlap((1, 4), (0, 3)) + assert _offsets_overlap((0, 6), (0, 6)) + assert _offsets_overlap((0, 12), (0, 6)) 
+ assert _offsets_overlap((0, 12), (1, 6)) + assert _offsets_overlap((2, 4), (0, 6)) + assert _offsets_overlap((5, 6), (0, 6)) + assert _offsets_overlap((0, 2), (0, 6)) + + assert not _offsets_overlap((2, 5), (7, 9)) + + +def test_tensor_shard_spec_get_merged_flattened_offsets(): + assert list(TensorShardSpec(flattened_offsets=((0, 3),)).get_merged_flattened_offsets((128, 256))) == [(0, 3)] + + assert list(TensorShardSpec(flattened_offsets=((0, 3), (3, 6))).get_merged_flattened_offsets((128, 256))) == [ + (0, 6) + ] + + assert list( + TensorShardSpec(flattened_offsets=((0, 3), (6, 9), (9, 12), (15, 18))).get_merged_flattened_offsets( + (128, 256) + ) + ) == [(0, 3), (6, 12), (15, 18)] + + +def test_tensor_shard_spec_compute_overlap_with_flattened_offsets(): + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (3, 6))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 6),)), (128, 256) + ) + == OverlapType.EQUAL + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 3), (6, 9))), (128, 256) + ) + == OverlapType.SUPERSET + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((0, 15),)), (128, 256) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((2, 5),)), (128, 256) + ) + == OverlapType.MIXED + ) + + assert ( + TensorShardSpec(flattened_offsets=((0, 3), (6, 12))).compute_overlap_with( + TensorShardSpec(flattened_offsets=((12, 15),)), (128, 256) + ) + is None + ) + + +def test_tensor_shard_spec_compute_overlap_with_dtensor_fields(): + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.EQUAL + ) + + assert ( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 8), global_offset=(1, 0)), (16, 8) + ) + == OverlapType.SUPERSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(1, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 8), global_offset=(0, 0)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 4), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.MIXED + ) + + assert ( + TensorShardSpec(local_shape=(2, 4), global_offset=(1, 2)).compute_overlap_with( + TensorShardSpec(local_shape=(4, 8), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.SUBSET + ) + + assert ( + TensorShardSpec(local_shape=(2, 4), global_offset=(1, 2)).compute_overlap_with( + TensorShardSpec(local_shape=(2, 4), global_offset=(0, 0)), (16, 8) + ) + == OverlapType.MIXED + ) + + +def test_tensor_shard_spec_for_dtensor_1D(): + full_shape = (16,) + shard_spec = TensorShardSpec(local_shape=(8,), global_offset=(0,)) + assert list(shard_spec.get_flattened_offsets(full_shape)) == [(0, 8)] + + +def test_tensor_shard_spec_for_dtensor_2D_colwise(): + # For example: + # from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh + # mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=0)]) + full_shape = (16, 8) + shard_spec = TensorShardSpec(local_shape=(4, 8), global_offset=(4, 0)) + assert 
list(shard_spec.get_flattened_offsets(full_shape)) == [(32, 40), (40, 48), (48, 56), (56, 64)] + + +def test_tensor_shard_spec_for_dtensor_2D_rowwise(): + # For example: + # from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh + # mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + # distribute_tensor(torch.randn(16, 8), mesh, [Shard(dim=1)]) + full_shape = (16, 8) + shard_spec = TensorShardSpec(local_shape=(16, 2), global_offset=(0, 2)) + assert list(shard_spec.get_flattened_offsets(full_shape)) == [ + (2, 4), # row 0 + (10, 12), # row 1 + (18, 20), # row 2 + (26, 28), # row 3 + (34, 36), # row 4 + (42, 44), # row 5 + (50, 52), # row 6 + (58, 60), # row 7 + (66, 68), # row 8 + (74, 76), # row 9 + (82, 84), # row 10 + (90, 92), # row 11 + (98, 100), # row 12 + (106, 108), # row 13 + (114, 116), # row 14 + (122, 124), # row 15 + ] + + def save_and_load_checkpoint_with_regular_and_sharded_tensors(dir): checkpointer = Checkpointer() @@ -92,6 +261,91 @@ def test_save_and_load_checkpoint_with_regular_and_sharded_tensors(backend, tmp_ assert full_state_dict["y"].shape == (2, 3) +def save_and_load_checkpoint_with_dtensors(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + state_dict_to_save = { + "1d": distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_colwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_rowwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=1)]), + } + + state_dict_to_load = { + "1d": distribute_tensor(torch.randn(16, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_colwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=0)]), + "2d_rowwise": distribute_tensor(torch.randn(16, 8, device=get_default_device()), mesh, [Shard(dim=1)]), + } + + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] + checkpointer.load(dir, state_dict_to_load) # type: ignore[arg-type] + + for key in state_dict_to_load: + torch.testing.assert_close(state_dict_to_save[key], state_dict_to_load[key]) + + +@requires_multi_gpu +def test_save_and_load_checkpoint_with_dtensors(tmp_path): + run_distributed_test(save_and_load_checkpoint_with_dtensors, backend="nccl", func_args=(tmp_path,)) + + +def save_and_load_checkpoint_with_different_dtensor_topology(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + og_tensor = torch.randn(8, 6, device=get_default_device()) + + # Ensure tensor matches on all ranks (could use scatter here too, but whatever). + dist.all_reduce(og_tensor) + + state_dict_to_save = { + "x": distribute_tensor(og_tensor, mesh, [Shard(dim=0)]), + } + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] + + state_dict_to_load = { + "x": distribute_tensor(torch.randn(8, 6, device=get_default_device()), mesh, [Shard(dim=1)]), + } + checkpointer.load(dir, state_dict_to_load) # type: ignore[arg-type] + + # Gather full tensor from the state dict to load and make sure it matches the full OG tensor. 
+ full_loaded_tensor = state_dict_to_load["x"].full_tensor() + torch.testing.assert_close(og_tensor, full_loaded_tensor) + + +@requires_multi_gpu +def test_save_and_load_checkpoint_with_different_dtensor_topology(tmp_path): + run_distributed_test( + save_and_load_checkpoint_with_different_dtensor_topology, backend="nccl", func_args=(tmp_path,) + ) + + +def save_and_unshard_dtensor(dir): + checkpointer = Checkpointer() + + mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + og_tensor = torch.randn(8, 6, device=get_default_device()) + + # Ensure tensor matches on all ranks (could use scatter here too, but whatever). + dist.all_reduce(og_tensor) + + state_dict_to_save = { + "x": distribute_tensor(og_tensor, mesh, [Shard(dim=0)]), + } + checkpointer.save(dir, state_dict_to_save) # type: ignore[arg-type] + + full_state_dict = checkpointer.unshard(dir, device=get_default_device()) + torch.testing.assert_close(og_tensor, full_state_dict["x"]) + + +@requires_multi_gpu +def test_save_and_unshard_dtensor(tmp_path): + run_distributed_test(save_and_unshard_dtensor, backend="nccl", func_args=(tmp_path,)) + + def save_and_load_checkpoint_with_different_sharding_spec(dir): for idx, (offsets_to_save, offsets_to_load) in enumerate( [ @@ -566,3 +820,60 @@ def test_save_and_load_torch_fsdp_model( use_orig_params, ), ) + + +def run_save_and_load_tensor_parallel_model(dir, take_step_before_checkpoint): + tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),)) + + class FeedForward(nn.Module): + def __init__(self, dim: int = 16): + super().__init__() + self.dim = dim + self.w1 = nn.Linear(dim, dim) + self.w2 = nn.Linear(dim, dim) + self.w3 = nn.Linear(dim, dim) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + feed_forward = FeedForward().cuda() + parallelize_module( + feed_forward, + tp_mesh, + { + # by default ColwiseParallel input layouts is replicated + # and RowwiseParallel output layouts is replicated + "w1": ColwiseParallel(), + "w2": RowwiseParallel(), + "w3": ColwiseParallel(), + }, + ) + optim = torch.optim.AdamW(feed_forward.parameters()) + + # Take a forward and backward pass. + feed_forward(torch.rand((2, feed_forward.dim), device="cuda")).sum().backward() + + if take_step_before_checkpoint: + # Take an optimizer step. + optim.step() + + # Save checkpoint. + save_model_and_optim_state(dir, feed_forward, optim) + + # Now load the checkpoint with a different topology, in this case an unsharded model. + unsharded_feed_forward = FeedForward().cuda() + unsharded_optim = torch.optim.AdamW(unsharded_feed_forward.parameters()) + load_model_and_optim_state(dir, unsharded_feed_forward, unsharded_optim) + + +@requires_multi_gpu +@pytest.mark.parametrize( + "take_step_before_checkpoint", [pytest.param(True, id="after-step"), pytest.param(False, id="pre-step")] +) +def test_save_and_load_tensor_parallel_model(tmp_path, take_step_before_checkpoint): + run_distributed_test( + run_save_and_load_tensor_parallel_model, + backend="nccl", + start_method="spawn", + func_args=(tmp_path, take_step_before_checkpoint), + )
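
A condensed sketch of the DTensor checkpointing flow exercised by the new tests above (illustrative only; assumes torch.distributed is already initialized with a GPU per rank, e.g. via torchrun, and "/tmp/dtensor-ckpt" is a hypothetical checkpoint directory):

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh

    from olmo_core.distributed.checkpoint import Checkpointer

    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    checkpointer = Checkpointer()

    # Save a state dict containing a DTensor sharded over dim 0.
    to_save = {"w": distribute_tensor(torch.randn(16, 8, device="cuda"), mesh, [Shard(dim=0)])}
    checkpointer.save("/tmp/dtensor-ckpt", to_save)

    # Load it back into a DTensor sharded over dim 1, i.e. a different distributed topology.
    to_load = {"w": distribute_tensor(torch.randn(16, 8, device="cuda"), mesh, [Shard(dim=1)])}
    checkpointer.load("/tmp/dtensor-ckpt", to_load)

    # Or gather the full, unsharded state dict on every rank.
    full_state_dict = checkpointer.unshard("/tmp/dtensor-ckpt", device=torch.device("cuda"))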