Add partial state dict functionality for FSDP (mosaicml#2637)
* Use pytorch chunking

commit-id:e4c9b78f

* Add partial state dict functionality for FSDP

commit-id:2a2cae33
b-chu authored Oct 13, 2023
1 parent 1c9d8d1 commit 28cc4ca
Showing 3 changed files with 128 additions and 163 deletions.
10 changes: 7 additions & 3 deletions composer/trainer/mosaic_fsdp.py
@@ -11,7 +11,8 @@
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp import FullyShardedDataParallel

from composer.trainer.mosaic_fsdp_utils import build_metadata, custom_auto_wrap_t1p13p1, shard
from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
custom_auto_wrap_t1p13p1)


def patch_pytorch():
@@ -33,18 +34,21 @@ def patch_pytorch():

# Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p0p1

FullyShardedDataParallel.__init__ = init_fn_t2p0p1 # type: ignore

# Monkey patch sharding method
ChunkShardingSpec.build_metadata = build_metadata
ChunkShardingSpec.shard = shard

elif version.parse(torch.__version__) < version.parse('2.1.1'):
# Monkey patch for torch < 2.1.1, i.e. torch == 2.1.0
from torch.distributed.fsdp import _state_dict_utils

# Monkey patch sharding method
ChunkShardingSpec.build_metadata = build_metadata
ChunkShardingSpec.shard = shard

# Monkey patch partial state dict handling
_state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)

elif version.parse(torch.__version__) >= version.parse('2.1.1'):
raise NotImplementedError('FullyShardedDataParallel is not supported for torch >= 2.1.1')
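
A minimal usage sketch (editor's addition, not part of the diff): the monkey patches are process-wide, so patch_pytorch() has to run before any module is wrapped with FullyShardedDataParallel. The call site and toy model below are illustrative and assume an already-initialized torch.distributed process group.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from composer.trainer.mosaic_fsdp import patch_pytorch

patch_pytorch()            # install the version-gated monkey patches once per process
model = nn.Linear(16, 16)  # toy model; assumes the process group is already initialized
fsdp_model = FSDP(model)   # subsequent FSDP instances pick up the patched __init__ and hooks
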
268 changes: 114 additions & 154 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -8,30 +8,33 @@

import functools
import inspect
import logging
import math
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check

import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
import torch.nn as nn
import torch.nn.functional as F
from packaging import version
from torch import distributed
from torch.distributed import ProcessGroup
from torch.distributed._shard._utils import narrow_tensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharded_tensor.utils import _parse_and_validate_remote_device
from torch.distributed._shard.sharding_spec import ChunkShardingSpec, ShardMetadata
from torch.distributed._shard.sharding_spec import ShardMetadata
from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size
from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision,
ShardingStrategy)
from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
from torch.distributed.utils import _replace_by_prefix

from composer.core import Precision
from composer.utils import dist

if TYPE_CHECKING:
# Only include ShardedTensor when type checking; exclude it
# from run-time to resolve circular dependency.
from torch.distributed._shard.sharded_tensor import ShardedTensor
if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
torch.__version__) < version.parse('2.0.2'):
from torch.distributed.fsdp._common_utils import _FSDPState

SHARDING_MAP = {
'NO_SHARD': ShardingStrategy.NO_SHARD,
@@ -45,6 +48,8 @@
'BACKWARD_POST': BackwardPrefetch.BACKWARD_POST,
}

logger = logging.getLogger(__name__)


def _get_torch_dtype(dtype: Union[Precision, str]):
"""Convert common string representations of dtypes to torch dtypes."""
@@ -619,50 +624,14 @@ def init_fn_t2p0p1(
_register_all_state_dict_hooks(self)


def get_split_size(dim_size: int, chunks: int) -> int:
"""Gets the minimum size per chunk.
A tensor of dim_size 5 and 4 chunks will have chunks of size [2, 1, 1, 1].
Args:
dim_size(int): Size of the dimension being chunked.
chunks(int): Number of chunks to create for ``dim_size``.
Returns:
An int indicating the split size to use.
"""
return dim_size // chunks


def get_chunked_dim_size(dim_size: int, chunks: int, idx: int) -> int:
"""Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` and ``chunks``.
A tensor of dim_size 5 and 4 chunks will have chunks of size [2, 1, 1, 1].
Args:
dim_size(int): Size of the dimension being chunked.
chunks(int): Number of chunks to create for ``dim_size``.
idx(int): The index of chunk whose dim size is being requested.
Returns:
An int indicating the dim size of the chunk.
"""
assert idx >= 0 and idx < chunks
split_size = get_split_size(dim_size, chunks)
if idx < dim_size % chunks:
split_size += 1
return split_size
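
A hedged comparison (editor's sketch, not part of the diff) of the locally defined helpers above, which this commit removes, against the torch.distributed helpers it imports instead: for a dimension of size 5 split over 4 placements, the removed floor-based helpers yield chunks [2, 1, 1, 1], while torch's ceil-based helpers yield [2, 2, 1, 0].

from torch.distributed._shard.sharding_spec._internals import (get_chunked_dim_size, get_split_size)

dim_size, chunks = 5, 4
split_size = get_split_size(dim_size, chunks)  # ceil(5 / 4) == 2
sizes = [get_chunked_dim_size(dim_size, split_size, idx) for idx in range(chunks)]
print(split_size, sizes)  # 2 [2, 2, 1, 0]
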


def build_metadata(
self: ChunkShardingSpec,
self,
tensor_sizes: torch.Size,
tensor_properties: sharded_tensor_meta.TensorProperties,
) -> sharded_tensor_meta.ShardedTensorMetadata:
"""Updates ChunkShardingSpec's build_metadata fn to use a dynamic sharding dimension.
"""Adds nightly change for ChunkShardingSpec.
modified version of
https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py#L77
Change implemented in https://github.com/pytorch/pytorch/pull/108915
"""
tensor_num_dim = len(tensor_sizes)

@@ -671,122 +640,113 @@ def build_metadata(
raise ValueError(f'Invalid sharding dim: {self.dim}')

shards_metadata = []
while True:
sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index]
chunks = len(self.placements)

if sharding_dim_size // chunks == 0:
self.dim += 1 # type: ignore[operator]
else:
break
current_offsets = [0] * tensor_num_dim
sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index]
chunks = len(self.placements)
split_size = get_split_size(sharding_dim_size, chunks)
for idx, placement in enumerate(self.placements):
# generate ShardMetadata for each placement device
chunked_dim_size = get_chunked_dim_size(sharding_dim_size, chunks, idx)
if chunked_dim_size > 0:
shard_size = list(tensor_sizes)
shard_size[self.dim] = chunked_dim_size # type: ignore[index]

shard_metadata = ShardMetadata(
shard_offsets=current_offsets.copy(),
shard_sizes=shard_size,
placement=placement,
)
shards_metadata.append(shard_metadata)
chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
shard_size = list(tensor_sizes)
current_offsets = [0] * tensor_num_dim
current_offsets[self.dim] = split_size * idx # type: ignore[index]
shard_size[self.dim] = chunked_dim_size # type: ignore[index]

shard_metadata = ShardMetadata(
shard_offsets=current_offsets,
shard_sizes=shard_size,
placement=placement,
)
shards_metadata.append(shard_metadata)

current_offsets[self.dim] += chunked_dim_size # type: ignore[index]
self.dim = 0
return sharded_tensor_meta.ShardedTensorMetadata(shards_metadata, tensor_sizes, tensor_properties)
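
A hedged illustration (editor's sketch, not part of the diff) of the layout the patched build_metadata describes: every placement gets a ShardMetadata, offsets advance by a fixed split_size, and the trailing shard may be empty. The (5, 8) tensor and 4 placements are made up for the example.

from torch.distributed._shard.sharding_spec._internals import (get_chunked_dim_size, get_split_size)

tensor_sizes, dim, num_placements = (5, 8), 0, 4
split_size = get_split_size(tensor_sizes[dim], num_placements)
for idx in range(num_placements):
    offsets = [0] * len(tensor_sizes)
    offsets[dim] = split_size * idx
    sizes = list(tensor_sizes)
    sizes[dim] = get_chunked_dim_size(tensor_sizes[dim], split_size, idx)
    print(idx, offsets, sizes)
# 0 [0, 0] [2, 8]
# 1 [2, 0] [2, 8]
# 2 [4, 0] [1, 8]
# 3 [6, 0] [0, 8]
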


def shard(self: ChunkShardingSpec, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> 'ShardedTensor':
"""Updates ChunkShardingSpec's shard fn to use a dynamic sharding dimension.
@no_type_check
def _sharded_pre_load_state_dict_hook(
module: nn.Module,
fsdp_state: '_FSDPState',
state_dict: Dict[str, Any],
prefix: str,
) -> None:
"""Adds nightly change for partial state dict error handling.
modified version of
https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py#L116
https://github.com/pytorch/pytorch/blob/0511df0ee9edeb5c2613805ccfb49beb323b87f9/torch/distributed/fsdp/_state_dict_utils.py#L607-L615
The hook combines the unflattened, sharded parameters (ShardedTensor) to
a new FlatParameter and shards the new FlatParameter to the local chunk.
"""
# relative imports to avoid circular dependency
from torch.distributed._shard.sharded_tensor import ShardedTensor

tensor_properties = sharded_tensor_meta.TensorProperties(
dtype=tensor.dtype,
layout=tensor.layout,
requires_grad=tensor.requires_grad,
memory_format=torch.contiguous_format,
pin_memory=tensor.is_pinned(),
)
current_rank = distributed.get_rank(process_group)
tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
local_shards = []
local_tensor = None
local_metadata = None
tensors_to_scatter = [None] * distributed.get_world_size(process_group)

while True:
sharding_dim_size = tensor.size()[self.dim] # type: ignore[index]
chunks = len(self.placements)

if sharding_dim_size // chunks == 0:
self.dim += 1 # type: ignore[operator]
from torch.distributed._tensor import Replicate
from torch.distributed.distributed_c10d import _get_pg_default_device
from torch.distributed.fsdp._common_utils import FSDP_PREFIX, _has_fsdp_params, _is_composable, _module_handle
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.distributed.fsdp._state_dict_utils import _enter_unshard_params_ctx, _param_name_infos

_lazy_init(fsdp_state, module)
if not _is_composable(fsdp_state):
_replace_by_prefix(state_dict, prefix, prefix + f'{FSDP_PREFIX}')
if not _has_fsdp_params(fsdp_state, module):
return

handle = _module_handle(fsdp_state, module)
if not handle.uses_sharded_strategy:
raise RuntimeError('load_sharded_state_dict can only be called when parameters '
'are flattened and sharded.')

device = fsdp_state.compute_device
for fqn, _, _ in _param_name_infos(module, fsdp_state):
if not _is_composable(fsdp_state):
fqn_from_global_root = f'{prefix}{FSDP_PREFIX}{fqn}'
else:
break
scatter_shape = list(tensor.size())
scatter_shape[self.dim] = get_chunked_dim_size(sharding_dim_size, chunks, 0) # type: ignore[index]

for shard_meta in tensor_meta.shards_metadata:
rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
if current_rank == src_rank:
# Reshape to get shard for this rank and we don't want autograd
# recording here for the narrow op and 'local_shard' should be a
# leaf variable in the autograd graph.
narrowed_tensor = narrow_tensor(tensor, shard_meta)
if shard_meta.shard_sizes[self.dim] < scatter_shape[self.dim]: # type: ignore[index]
# for the last shard that might be smaller than other shards
# resize the narrowed tensor to the same size and use it for
# the scatter collective as dist.scatter requires same size
# inputs on every rank
tensor_to_scatter = (narrowed_tensor.detach().clone().resize_(scatter_shape))
fqn_from_global_root = f'{prefix}{fqn}'
try:
param = state_dict.pop(fqn_from_global_root)
except KeyError:
logger.warning(f'Did not find param with FQN {fqn_from_global_root}, skipping it. ' # noqa: G004
'The weight will not be filled if you expect it to be.')
continue # TODO: Improve unittesting for state_dict finetuning
# cases: https://github.com/pytorch/pytorch/issues/109134

if not fsdp_state._state_dict_config.use_dtensor:
# All-gather the param (ShardedTensor)
param, shards = _ext_pre_load_state_dict_transform(param)

assert len(shards) < 2, ('Expects 0 or 1 shard per rank '
f'but got {len(shards)} shards on rank {fsdp_state.rank}.')
param_numel = param.size().numel()
dim_0_size = param.size()[0]
chunk_size = (math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size)
if len(shards) == 1:
local_tensor = shards[0].tensor.flatten()
pg_device = _get_pg_default_device(fsdp_state.process_group)
if local_tensor.device.type != pg_device.type:
local_tensor = local_tensor.to(pg_device)
num_padding = chunk_size - local_tensor.numel()
if num_padding > 0:
local_tensor = F.pad(local_tensor, [0, num_padding])
else:
tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

tensors_to_scatter[rank] = tensor_to_scatter # type: ignore[index]

if current_rank == rank:
local_tensor = torch.empty(scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
local_metadata = shard_meta

# each rank should have local_tensor and local_metadata initialized if we build
# the metadata list in a correct way.
assert local_tensor is not None
assert local_metadata is not None

# Scatter the shards to all ranks in the pg
# scatter takes the global rank as ``src``
src_for_scatter = src_rank
if (process_group is not None and process_group is not distributed.distributed_c10d._get_default_group()):
src_for_scatter = distributed.distributed_c10d.get_global_rank(process_group, src_for_scatter)

distributed.scatter(
local_tensor,
scatter_list=tensors_to_scatter if current_rank == src_rank else None,
src=src_for_scatter,
group=process_group,
)

if list(local_tensor.size()) != local_metadata.shard_sizes:
# detach again after receiving to ensure local shards remain a leaf node
local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

# Sync requires_grad to local_shard.
local_tensor.requires_grad = tensor.requires_grad

local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
local_tensor = torch.zeros(chunk_size, dtype=param.dtype, device=device)
tensor = torch.empty(
chunk_size * fsdp_state.world_size,
dtype=local_tensor.dtype,
device=device,
)
if local_tensor.is_cpu:
# Tensor could be on FSDP GPU compute device, while local_tensor is on CPU.
# Convert to CPU so all_gather can work.
tensor_dev = tensor.device
tensor = tensor.cpu()
tensor_list = list(torch.chunk(tensor, torch.distributed.get_world_size(fsdp_state.process_group)))
torch.distributed.all_gather(tensor_list, local_tensor, group=fsdp_state.process_group)
tensor.to(tensor_dev)
else:
torch.distributed.all_gather_into_tensor(tensor, local_tensor, group=fsdp_state.process_group)
tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
state_dict[fqn_from_global_root] = tensor
else:
if param.device != fsdp_state._device_mesh.device_type:
param = param.to(fsdp_state._device_mesh.device_type)

st = ShardedTensor._init_from_local_shards_and_global_metadata(local_shards,
tensor_meta,
process_group=process_group)
param = param.redistribute(device_mesh=param.device_mesh, placements=[Replicate()])
state_dict[fqn_from_global_root] = param.to_local()

# Manually set sharding_spec
st._sharding_spec = self
self.dim = 0
return st
_enter_unshard_params_ctx(module, fsdp_state, writeback=True)
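
A hedged arithmetic sketch (editor's addition, not part of the diff) of how the hook above rebuilds a FlatParameter chunk: each rank flattens its shard, pads it up to a fixed chunk_size, all ranks all-gather the padded chunks, and the result is narrowed back to the parameter's true numel. The other change, warning and skipping on a missing FQN instead of raising, needs a full FSDP run to exercise, so only the chunk math is shown; the shapes are illustrative.

import math

world_size = 4
param_shape = (5, 8)                            # illustrative parameter shape
dim_0_size = param_shape[0]
param_numel = param_shape[0] * param_shape[1]   # 40

chunk_size = math.ceil(dim_0_size / world_size) * param_numel // dim_0_size
# ceil(5 / 4) == 2 rows per rank -> 2 * 40 // 5 == 16 elements per padded chunk

gathered_numel = chunk_size * world_size        # 64 elements after all_gather
padding_dropped = gathered_numel - param_numel  # 24 padding elements narrowed away
print(chunk_size, gathered_numel, padding_dropped)  # 16 64 24
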
13 changes: 7 additions & 6 deletions composer/utils/checkpoint.py
@@ -406,15 +406,16 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
model_state_dict = {'state': {'model': state.get_model_state_dict()}}
else:
cur_state_dict = state.state_dict()
if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(cur_state_dict)
cur_state_dict.pop('optimizers')
model_state_dict = {'state': cur_state_dict}

if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(model_state_dict)

dist_cp.load_state_dict(model_state_dict, storage_reader)

state.load_state_dict(
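
A hedged sketch (editor's addition, not part of the diff) of the reordered ignore_keys handling above: the filter now runs on the nested model_state_dict, so key paths are matched from the 'state' level. The '/'-separated glob syntax and the import path of glob_filter are assumptions made for illustration.

from composer.utils.checkpoint import glob_filter

model_state_dict = {'state': {'model': {'fc.weight': 0, 'fc.bias': 1}}}
ignore_keys = glob_filter(['state/model/fc.bias'])  # assumed '/'-separated path syntax
ignore_keys(model_state_dict)                       # modifies the dict in place, as in the diff
print(model_state_dict)                             # expected: {'state': {'model': {'fc.weight': 0}}}
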
