From 28cc4caf654260a8e2ec15fede5b309a08232fcf Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:06:22 -0400 Subject: [PATCH] Add partial state dict functionality for FSDP (#2637) * Use pytorch chunking commit-id:e4c9b78f * Add partial state dict functionality for FSDP commit-id:2a2cae33 --- composer/trainer/mosaic_fsdp.py | 10 +- composer/trainer/mosaic_fsdp_utils.py | 268 +++++++++++--------------- composer/utils/checkpoint.py | 13 +- 3 files changed, 128 insertions(+), 163 deletions(-) diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index adf934b9e6..fd786efe6f 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -11,7 +11,8 @@ from torch.distributed._shard.sharding_spec import ChunkShardingSpec from torch.distributed.fsdp import FullyShardedDataParallel -from composer.trainer.mosaic_fsdp_utils import build_metadata, custom_auto_wrap_t1p13p1, shard +from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata, + custom_auto_wrap_t1p13p1) def patch_pytorch(): @@ -33,18 +34,21 @@ def patch_pytorch(): # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1 + FullyShardedDataParallel.__init__ = init_fn_t2p0p1 # type: ignore # Monkey patch sharding method ChunkShardingSpec.build_metadata = build_metadata - ChunkShardingSpec.shard = shard elif version.parse(torch.__version__) < version.parse('2.1.1'): # Monkey path for torch < 2.1.1 ie torch == 2.1.0 + from torch.distributed.fsdp import _state_dict_utils # Monkey patch sharding method ChunkShardingSpec.build_metadata = build_metadata - ChunkShardingSpec.shard = shard + + # Monkey patch partial state dict handling + _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook) elif version.parse(torch.__version__) >= version.parse('2.1.1'): raise NotImplementedError(f'FullyShardedDataParallel is not supported for torch >= 2.2.0') diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 8b45cc48db..585490cd2c 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -8,30 +8,33 @@ import functools import inspect +import logging +import math import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check import torch import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta import torch.nn as nn +import torch.nn.functional as F from packaging import version from torch import distributed from torch.distributed import ProcessGroup -from torch.distributed._shard._utils import narrow_tensor -from torch.distributed._shard.sharded_tensor.shard import Shard -from torch.distributed._shard.sharded_tensor.utils import _parse_and_validate_remote_device -from torch.distributed._shard.sharding_spec import ChunkShardingSpec, ShardMetadata +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision, ShardingStrategy) +from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform 
+from torch.distributed.utils import _replace_by_prefix from composer.core import Precision from composer.utils import dist if TYPE_CHECKING: - # Only include ShardedTensor when do type checking, exclude it - # from run-time to resolve circular dependency. - from torch.distributed._shard.sharded_tensor import ShardedTensor + if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse( + torch.__version__) < version.parse('2.0.2'): + from torch.distributed.fsdp._common_utils import _FSDPState SHARDING_MAP = { 'NO_SHARD': ShardingStrategy.NO_SHARD, @@ -45,6 +48,8 @@ 'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST, } +logger = logging.getLogger(__name__) + def _get_torch_dtype(dtype: Union[Precision, str]): """Convert common string representations of dtypes to torch dtypes.""" @@ -619,50 +624,14 @@ def init_fn_t2p0p1( _register_all_state_dict_hooks(self) -def get_split_size(dim_size: int, chunks: int) -> int: - """Gets the minimum size per chunk. - - A tensor of dim_size 5 and 4 chunks will have chunks of size [2, 1, 1, 1]. - - Args: - dim_size(int): Size of the dimension being chunked. - chunks(int): Number of chunks to create for ``dim_size``. - - Returns: - An int indicating the split size to use. - """ - return dim_size // chunks - - -def get_chunked_dim_size(dim_size: int, chunks: int, idx: int) -> int: - """Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` and ``chunks``. - - A tensor of dim_size 5 and 4 chunks will have chunks of size [2, 1, 1, 1]. - - Args: - dim_size(int): Size of the dimension being chunked. - chunks(int): Number of chunks to create for ``dim_size``. - idx(int): The index of chunk whose dim size is being requested. - - Returns: - An int indicating the dim size of the chunk. - """ - assert idx >= 0 and idx < chunks - split_size = get_split_size(dim_size, chunks) - if idx < dim_size % chunks: - split_size += 1 - return split_size - - def build_metadata( - self: ChunkShardingSpec, + self, tensor_sizes: torch.Size, tensor_properties: sharded_tensor_meta.TensorProperties, ) -> sharded_tensor_meta.ShardedTensorMetadata: - """Updates ChunkShardingSpec's build_metadata fn to use a dynamic sharding dimension. + """Adds nightly change for ChunkShardingSpec. 
- modified version of - https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py#L77 + Change implemented in https://github.com/pytorch/pytorch/pull/108915 """ tensor_num_dim = len(tensor_sizes) @@ -671,122 +640,113 @@ def build_metadata( raise ValueError(f'Invalid sharding dim: {self.dim}') shards_metadata = [] - while True: - sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] - chunks = len(self.placements) - - if sharding_dim_size // chunks == 0: - self.dim += 1 # type: ignore[operator] - else: - break - current_offsets = [0] * tensor_num_dim + sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) for idx, placement in enumerate(self.placements): # generate ShardMetadata for each placement device - chunked_dim_size = get_chunked_dim_size(sharding_dim_size, chunks, idx) - if chunked_dim_size > 0: - shard_size = list(tensor_sizes) - shard_size[self.dim] = chunked_dim_size # type: ignore[index] - - shard_metadata = ShardMetadata( - shard_offsets=current_offsets.copy(), - shard_sizes=shard_size, - placement=placement, - ) - shards_metadata.append(shard_metadata) + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = list(tensor_sizes) + current_offsets = [0] * tensor_num_dim + current_offsets[self.dim] = split_size * idx # type: ignore[index] + shard_size[self.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=current_offsets, + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) - current_offsets[self.dim] += chunked_dim_size # type: ignore[index] - self.dim = 0 return sharded_tensor_meta.ShardedTensorMetadata(shards_metadata, tensor_sizes, tensor_properties) -def shard(self: ChunkShardingSpec, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> 'ShardedTensor': - """Updates ChunkShardingSpec's shard fn to use a dynamic sharding dimension. +@no_type_check +def _sharded_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: '_FSDPState', + state_dict: Dict[str, Any], + prefix: str, +) -> None: + """Adds nightly change for partial state dict error handling. - modified version of - https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py#L116 + https://github.com/pytorch/pytorch/blob/0511df0ee9edeb5c2613805ccfb49beb323b87f9/torch/distributed/fsdp/_state_dict_utils.py#L607-L615 + + The hook combines the unflattened, sharded parameters (ShardedTensor) to + a new FlatParameter and shards the new FlatParameter to the local chunk. 
""" - # relative imports to avoid circular dependency - from torch.distributed._shard.sharded_tensor import ShardedTensor - - tensor_properties = sharded_tensor_meta.TensorProperties( - dtype=tensor.dtype, - layout=tensor.layout, - requires_grad=tensor.requires_grad, - memory_format=torch.contiguous_format, - pin_memory=tensor.is_pinned(), - ) - current_rank = distributed.get_rank(process_group) - tensor_meta = self.build_metadata(tensor.size(), tensor_properties) - local_shards = [] - local_tensor = None - local_metadata = None - tensors_to_scatter = [None] * distributed.get_world_size(process_group) - - while True: - sharding_dim_size = tensor.size()[self.dim] # type: ignore[index] - chunks = len(self.placements) - - if sharding_dim_size // chunks == 0: - self.dim += 1 # type: ignore[operator] + from torch.distributed._tensor import Replicate + from torch.distributed.distributed_c10d import _get_pg_default_device + from torch.distributed.fsdp._common_utils import FSDP_PREFIX, _has_fsdp_params, _is_composable, _module_handle + from torch.distributed.fsdp._runtime_utils import _lazy_init + from torch.distributed.fsdp._state_dict_utils import _enter_unshard_params_ctx, _param_name_infos + + _lazy_init(fsdp_state, module) + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f'{FSDP_PREFIX}') + if not _has_fsdp_params(fsdp_state, module): + return + + handle = _module_handle(fsdp_state, module) + if not handle.uses_sharded_strategy: + raise RuntimeError('load_sharded_state_dict can only be called when parameters ' + 'are flattened and sharded.') + + device = fsdp_state.compute_device + for fqn, _, _ in _param_name_infos(module, fsdp_state): + if not _is_composable(fsdp_state): + fqn_from_global_root = f'{prefix}{FSDP_PREFIX}{fqn}' else: - break - scatter_shape = list(tensor.size()) - scatter_shape[self.dim] = get_chunked_dim_size(sharding_dim_size, chunks, 0) # type: ignore[index] - - for shard_meta in tensor_meta.shards_metadata: - rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement) - if current_rank == src_rank: - # Reshape to get shard for this rank and we don't want autograd - # recording here for the narrow op and 'local_shard' should be a - # leaf variable in the autograd graph. - narrowed_tensor = narrow_tensor(tensor, shard_meta) - if shard_meta.shard_sizes[self.dim] < scatter_shape[self.dim]: # type: ignore[index] - # for the last shard that might be smaller to other shards - # resize the narrowed tensor to the same size and use it for - # the scatter collective as dist.scatter requires same size - # inputs on every rank - tensor_to_scatter = (narrowed_tensor.detach().clone().resize_(scatter_shape)) + fqn_from_global_root = f'{prefix}{fqn}' + try: + param = state_dict.pop(fqn_from_global_root) + except KeyError: + logger.warning(f'Did not find param with FQN {fqn_from_global_root}, skipping it. 
' # noqa: G004 + 'The weight will not be filled if you expect it to be.') + continue # TODO: Improve unittesting for state_dict finetuning + # cases: https://github.com/pytorch/pytorch/issues/109134 + + if not fsdp_state._state_dict_config.use_dtensor: + # All-gather the param (ShardedTensor) + param, shards = _ext_pre_load_state_dict_transform(param) + + assert len(shards) < 2, ('Expects 0 or 1 shard per rank ' + f'but got {len(shards)} shards on rank {fsdp_state.rank}.') + param_numel = param.size().numel() + dim_0_size = param.size()[0] + chunk_size = (math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size) + if len(shards) == 1: + local_tensor = shards[0].tensor.flatten() + pg_device = _get_pg_default_device(fsdp_state.process_group) + if local_tensor.device.type != pg_device.type: + local_tensor = local_tensor.to(pg_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) else: - tensor_to_scatter = narrowed_tensor.detach().clone().contiguous() - - tensors_to_scatter[rank] = tensor_to_scatter # type: ignore[index] - - if current_rank == rank: - local_tensor = torch.empty(scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device) - local_metadata = shard_meta - - # each rank should have local_tensor and local_metadata initialized if we build - # the metadata list in a correct way. - assert local_tensor is not None - assert local_metadata is not None - - # Scatter the shards to all ranks in the pg - # scatter takes the global rank as ``src`` - src_for_scatter = src_rank - if (process_group is not None and process_group is not distributed.distributed_c10d._get_default_group()): - src_for_scatter = distributed.distributed_c10d.get_global_rank(process_group, src_for_scatter) - - distributed.scatter( - local_tensor, - scatter_list=tensors_to_scatter if current_rank == src_rank else None, - src=src_for_scatter, - group=process_group, - ) - - if list(local_tensor.size()) != local_metadata.shard_sizes: - # detach again after receiving to ensure local shards remain a leaf node - local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach() - - # Sync requires_grad to local_shard. - local_tensor.requires_grad = tensor.requires_grad - - local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) + local_tensor = torch.zeros(chunk_size, dtype=param.dtype, device=device) + tensor = torch.empty( + chunk_size * fsdp_state.world_size, + dtype=local_tensor.dtype, + device=device, + ) + if local_tensor.is_cpu: + # Tensor could be on FSDP GPU compute device, while local_tensor is on CPU. + # Convert to CPU so all_gather can work. 
+ tensor_dev = tensor.device + tensor = tensor.cpu() + tensor_list = list(torch.chunk(tensor, torch.distributed.get_world_size(fsdp_state.process_group))) + torch.distributed.all_gather(tensor_list, local_tensor, group=fsdp_state.process_group) + tensor.to(tensor_dev) + else: + torch.distributed.all_gather_into_tensor(tensor, local_tensor, group=fsdp_state.process_group) + tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) + state_dict[fqn_from_global_root] = tensor + else: + if param.device != fsdp_state._device_mesh.device_type: + param = param.to(fsdp_state._device_mesh.device_type) - st = ShardedTensor._init_from_local_shards_and_global_metadata(local_shards, - tensor_meta, - process_group=process_group) + param = param.redistribute(device_mesh=param.device_mesh, placements=[Replicate()]) + state_dict[fqn_from_global_root] = param.to_local() - # Manually set sharding_spec - st._sharding_spec = self - self.dim = 0 - return st + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 683ed666f5..6e7013a2a2 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -406,15 +406,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): model_state_dict = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() - if ignore_keys: - # Filter provided list of key paths - if not callable(ignore_keys): - ignore_keys = glob_filter(ignore_keys) - # Call function to modify state_dict - ignore_keys(cur_state_dict) cur_state_dict.pop('optimizers') model_state_dict = {'state': cur_state_dict} + if ignore_keys: + # Filter provided list of key paths + if not callable(ignore_keys): + ignore_keys = glob_filter(ignore_keys) + # Call function to modify state_dict + ignore_keys(model_state_dict) + dist_cp.load_state_dict(model_state_dict, storage_reader) state.load_state_dict(
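
The padded chunking that the patched _sharded_pre_load_state_dict_hook performs (compute a per-rank chunk_size from dim 0, zero-pad each flattened local shard to that size, all-gather, then narrow back to the parameter's true numel) can be exercised outside of FSDP. Below is a minimal single-process sketch of that arithmetic only, not the patched code path itself: world_size, the 10x3 example parameter, and the use of torch.cat in place of all_gather_into_tensor are illustrative assumptions.

import math

import torch
import torch.nn.functional as F

world_size = 4  # assumed number of ranks, for illustration only
param = torch.arange(10 * 3, dtype=torch.float32).reshape(10, 3)

# Same formula as the hook: chunk_size is derived from dim 0 so every rank
# contributes an equally sized flat chunk after padding.
param_numel = param.numel()
dim_0_size = param.size(0)
chunk_size = math.ceil(dim_0_size / world_size) * param_numel // dim_0_size

# Emulate the per-rank local shards: flatten the parameter and zero-pad the tail,
# then split it into world_size equal chunks (one per rank).
flat = param.flatten()
padded = F.pad(flat, [0, chunk_size * world_size - flat.numel()])
local_chunks = list(torch.chunk(padded, world_size))

# Stand-in for all_gather_into_tensor: concatenate the per-rank chunks,
# drop the padding, and restore the original shape.
gathered = torch.cat(local_chunks)
restored = gathered.narrow(0, 0, param_numel).reshape(param.size())
assert torch.equal(restored, param)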