Torch 2.3 patch (mosaicml#2849)
* add monkeypatch for verify_options

* patch

* fix

* fix

* partial precommit

* bit of cleanup

* doc

* debug

* fix version pinning

* precommit

* checkdown

* lint

---------

Co-authored-by: Evan Racah <[email protected]>
Co-authored-by: Mihir Patel <[email protected]>
3 people authored Jan 14, 2024
1 parent 027c3d0 commit a2ae299
Showing 3 changed files with 129 additions and 22 deletions.
3 changes: 2 additions & 1 deletion composer/loggers/mlflow_logger.py
@@ -94,7 +94,7 @@ def __init__(
self._last_flush_time = time.time()
self._flush_interval = flush_interval

self._experiment_id = None
self._experiment_id: Optional[str] = None
self._run_id = None

if self._enabled:
@@ -150,6 +150,7 @@ def init(self, state: State, logger: Logger) -> None:
self._run_id = env_run_id
else:
# Search for an existing run tagged with this Composer run.
assert self._experiment_id is not None
existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id],
filter_string=f'tags.composer_run_name = "{state.run_name}"',
output_format='list')
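The mlflow_logger.py change above annotates _experiment_id as Optional[str] and adds an assert so a static type checker can narrow the attribute to str before it reaches mlflow.search_runs. A minimal sketch of that narrowing pattern, using a hypothetical class rather than the actual logger:

    from typing import Optional

    class ExperimentHolder:
        """Hypothetical stand-in for the logger's experiment handling."""

        def __init__(self) -> None:
            # Annotating as Optional[str] records that the id may be unset at first.
            self._experiment_id: Optional[str] = None

        def set_experiment(self, experiment_id: str) -> None:
            self._experiment_id = experiment_id

        def search(self) -> str:
            # The assert narrows Optional[str] to str for the line below, mirroring
            # the assert added before mlflow.search_runs in the diff above.
            assert self._experiment_id is not None
            return f'searching runs in experiment {self._experiment_id}'

    holder = ExperimentHolder()
    holder.set_experiment('12345')
    print(holder.search())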
35 changes: 35 additions & 0 deletions composer/trainer/mosaic_fsdp.py
@@ -90,3 +90,38 @@ def patch_pytorch():
# Monkeypatch dtensor support
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore

# Monkeypatch state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
state_dict._verify_options = _verify_options_t2p2p0

elif version.parse(torch.__version__) < version.parse('2.3.1'):
# Monkey patch for torch < 2.3.1, i.e. torch == 2.3.0
# Note: this is the same patch as for 2.2.0; we just add a new if branch
# for clarity and modularity of changes.

# Allow 2D HSDP
from torch.distributed.fsdp import _runtime_utils
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# # Better overlap communication and computation
# from torch.distributed.fsdp import _runtime_utils

# from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
# _wait_for_computation_stream, forward)
# _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
# _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
# _runtime_utils._root_pre_forward = _root_pre_forward
# FullyShardedDataParallel.forward = forward

# Monkeypatch dtensor support
from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
FullyShardedDataParallel.__init__ = init_fn_t2p2p0

# Monkeypatch state_dict
from torch.distributed.checkpoint import state_dict # type: ignore

from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
state_dict._verify_options = _verify_options_t2p2p0
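The new branch added to patch_pytorch() above follows the same version-gated monkeypatching pattern as the existing branches: check torch.__version__, import the private torch module, and swap in Composer's replacement. A rough, self-contained sketch of that pattern, where _patched_verify_options is a hypothetical stand-in for _verify_options_t2p2p0 (assumes torch >= 2.2.0 so that torch.distributed.checkpoint.state_dict exists):

    import torch
    from packaging import version

    def apply_state_dict_patch() -> None:
        """Sketch of the version-gated monkeypatch used in patch_pytorch()."""
        if version.parse(torch.__version__) < version.parse('2.2.0'):
            return  # older torch versions take different branches in the real code

        from torch.distributed.checkpoint import state_dict  # type: ignore

        original_verify_options = state_dict._verify_options

        def _patched_verify_options(*args, **kwargs):
            # Hypothetical replacement; the real patch installs _verify_options_t2p2p0,
            # which normalizes the FQNs produced by activation-checkpointing wrappers.
            return original_verify_options(*args, **kwargs)

        state_dict._verify_options = _patched_verify_options

    apply_state_dict_patch()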
113 changes: 92 additions & 21 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -13,7 +13,9 @@
import logging
import math
import warnings
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check
import contextlib
from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, cast, no_type_check

import torch
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
@@ -38,6 +40,7 @@
torch.__version__) < version.parse('2.2.0'):
from torch.distributed.fsdp._common_utils import _FSDPState


log = logging.getLogger(__name__)

SHARDING_MAP = {
@@ -227,7 +230,7 @@ def _custom_recursive_wrap_t1p13p1(
modified version of
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/wrap.py#L353
which recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguements for every wrapped module.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -318,7 +321,7 @@ def custom_auto_wrap_t1p13p1(
modified version of
https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1252
FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguements for every wrapped module.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -373,7 +376,7 @@ def _custom_recursive_wrap_t2p0p1(
modified version of
https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/wrap.py#L320
which recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguements for every wrapped module.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -471,7 +474,7 @@ def _custom_auto_wrap_t2p0p1(
modified version of
https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/_wrap_utils.py#L31
FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
This modification enables the user to pass custom FSDP arguements for every wrapped module.
This modification enables the user to pass custom FSDP arguments for every wrapped module.
The added process_group_cache enables different FSDP modules to, when appropriate, use the
same process group instead of instantiating a new process group.
@@ -760,7 +763,7 @@ def _sharded_pre_load_state_dict_hook(


if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse(
torch.__version__) < version.parse('2.2.1'):
torch.__version__) < version.parse('2.3.1'):
import copy

from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
@@ -771,7 +774,7 @@ def _sharded_pre_load_state_dict_hook(
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType,
_get_default_comm_hook_state, _init_intra_and_inter_node_groups,
_is_valid_hybrid_shard_pg_type)
_is_valid_hybrid_shard_pg_type, _init_extension)
from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap,
_check_orig_params_flattened, _init_buffer_state,
_init_core_state, _init_device_handle,
@@ -826,7 +829,7 @@ def chunk_dtensor_t2p2p0(
tensor = tensor.clone().detach()

# When a layer is not involved in TP, then the tensor will not be a DTensor.
# e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer.
# e.g. When a layer is not specified in the parallelize_plan, TP will have no effect on the layer.
# e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer.
if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):

@@ -869,17 +872,6 @@ def chunk_dtensor_t2p2p0(
DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0
DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0

def _init_extension_t2p2p0(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
# TODO: we need to add additional check once we support FSDP + PiPPy.
# This check is currently sufficient, since we only support FSDP + TP.
if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
state._fsdp_extension = DTensorExtensions()
else:
# We need to explicilty set _fsdp_extension to None.
# Otherwise, we will run into an infinite recursion when getting the attribute.
state._fsdp_extension = None
return state

def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
#parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
#if parent_mesh is not None:
@@ -1052,7 +1044,7 @@ def init_fn_t2p2p0(
_init_prefetching_state(self, backward_prefetch, forward_prefetch)
_init_buffer_state(self, module)
# extension needs to be set before `_init_param_handle_from_module()`
_init_extension_t2p2p0(self, device_mesh)
_init_extension(self, device_mesh)
_init_param_handle_from_module(
self,
module,
@@ -1070,6 +1062,85 @@ def init_fn_t2p2p0(
_init_state_dict_state(self)
_register_all_state_dict_hooks(self)

from torch.distributed.checkpoint.state_dict import StateDictOptions, _StateDictInfo

def _verify_options_t2p2p0(
model: nn.Module,
optims: Tuple[torch.optim.Optimizer, ...],
optim_only: bool,
*,
submodules: Optional[Set[nn.Module]] = None,
options: Optional[StateDictOptions] = None,
) -> _StateDictInfo:
"""Verify the model and options passed by the user and generates _StateDictInfo."""
from torch.distributed.checkpoint.state_dict import StateDictOptions, _get_fqns, _StateDictInfo
from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (OptimStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig,
StateDictConfig, StateDictType)

if optim_only and not optims:
raise RuntimeError('Optimizers are not passed in but optim_only is set to True.')

options = options or StateDictOptions()
assert options is not None # pyright

fqn_param_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
all_fqns = set()
for name, param in model.named_parameters():
fqns = _get_fqns(model, name)
fqns = {fqn.replace('_checkpoint_wrapped_module.', '') for fqn in fqns}
fqn_param_mapping[param] = fqns
for fqn in fqns:
fqn_param_mapping[fqn] = param
all_fqns.add(fqn)

submodule_prefixes = set()
if submodules:
submodules = set(submodules)
for name, module in model.named_modules():
if module not in submodules:
continue
fqns = _get_fqns(model, name)
assert len(fqns) == 1, 'Submodule FQN should only have 1 instance'
for fqn in fqns:
submodule_prefixes.add(f'{fqn}.')
fsdp_modules = FSDP.fsdp_modules(model)
state_dict_config: StateDictConfig
optim_state_dict_config: OptimStateDictConfig
fsdp_context: Callable
if fsdp_modules:
# The FSDP API only works if at least one FSDP instance exists.
if options.full_state_dict:
state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload)
optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=options.cpu_offload,
rank0_only=options.cpu_offload)
state_dict_type = StateDictType.FULL_STATE_DICT
else:
state_dict_config = ShardedStateDictConfig()
optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=options.cpu_offload,)
state_dict_type = StateDictType.SHARDED_STATE_DICT

fsdp_context = functools.partial(
FSDP.state_dict_type,
module=model,
state_dict_type=state_dict_type,
state_dict_config=state_dict_config,
optim_state_dict_config=optim_state_dict_config,
)
else:
fsdp_context = contextlib.nullcontext
return _StateDictInfo(
**asdict(options),
fqn_param_mapping=fqn_param_mapping,
all_fqns=all_fqns,
submodule_prefixes=submodule_prefixes,
fsdp_context=fsdp_context,
fsdp_modules=cast(List[nn.Module], fsdp_modules),
handle_model=not optim_only,
handle_optim=(len(optims) > 0),
)


def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
"""Indicates whether FlatParamHandle has the default process group.
@@ -1153,7 +1224,7 @@ def _root_pre_forward(
_p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
if not state._is_root:
# Always cast forward inputs in the root of this local FSDP unit for mixed
# precision, as this is where mixed precision could be configed.
# precision, as this is where mixed precision could be configured.
# This is more useful for auto wrapping that is recommended in composable path.
# For manual wrapping, cast forward inputs on each local FSDP unit root will
# increase some overhead, so not turned on for model wrapper path right now where
