diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index bf6ebaa228..d4ce020780 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -41,7 +41,7 @@ def patch_pytorch():
         ChunkShardingSpec.build_metadata = build_metadata
 
     elif version.parse(torch.__version__) < version.parse('2.1.1'):
-        # Monkey path for torch < 2.1.1 ie torch == 2.1.0
+        # Monkey patch for torch < 2.1.1 ie torch == 2.1.0
 
         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
@@ -55,8 +55,21 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
     elif version.parse(torch.__version__) < version.parse('2.2.0'):
-        # Monkey path for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
+        # Monkey patch for torch < 2.2.0 ie torch == 2.1.1, 2.1.2
 
         # Allow 2D HSDP
         from torch.distributed.fsdp import _runtime_utils
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+        # Better overlap communication and computation
+        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
+        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
+
+    elif version.parse(torch.__version__) < version.parse('2.2.1'):
+        # Monkey patch for torch < 2.2.1 ie torch == 2.2.0
+
+        # Better overlap communication and computation
+        from torch.distributed.fsdp import _runtime_utils
+
+        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
+        _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index da08772a63..5f90c34049 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -21,6 +21,7 @@
 from torch.distributed import ProcessGroup
 from torch.distributed._shard.sharding_spec import ShardMetadata
 from torch.distributed._shard.sharding_spec._internals import get_chunked_dim_size, get_split_size
+from torch.distributed.distributed_c10d import get_process_group_ranks
 from torch.distributed.fsdp import (BackwardPrefetch, CPUOffload, FullyShardedDataParallel, MixedPrecision,
                                     ShardingStrategy)
 from torch.distributed.fsdp._fsdp_extensions import _ext_pre_load_state_dict_transform
@@ -31,7 +32,7 @@ if TYPE_CHECKING:
     if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
-            torch.__version__) < version.parse('2.0.2'):
+            torch.__version__) < version.parse('2.2.0'):
         from torch.distributed.fsdp._common_utils import _FSDPState
 
 log = logging.getLogger(__name__)
 
@@ -753,3 +754,208 @@ def _sharded_pre_load_state_dict_hook(
             state_dict[fqn_from_global_root] = param.to_local()
 
     _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
+
+
+def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
+    """Indicates whether the _FSDPState has the default process group.
+
+    Args:
+        state (_FSDPState): FSDP State object
+
+    Returns:
+        bool: True if the ProcessGroup of the _FSDPState object is the default process group. False
+            otherwise.
+    """
+    if state.process_group is None:
+        # If no process group is attached to the _FSDPState, assume it uses default process group.
+        return True
+    return len(get_process_group_ranks(state.process_group)) == dist.get_world_size()
+
+
+def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
+    """Gets the ranks included in the ProcessGroup of an _FSDPState.
+
+    Args:
+        state (_FSDPState): FSDP State object
+
+    Returns:
+        Tuple[int]: Ranks for the FSDP State's process group.
+    """
+    if state.process_group is None:
+        # If no process group is attached to the _FSDPState, assume it uses default process group.
+        return tuple(range(dist.get_world_size()))
+    else:
+        return tuple(get_process_group_ranks(state.process_group))
+
+
+@no_type_check
+def _share_state_and_init_handle_attrs_t2p1(
+    root_state: '_FSDPState',
+    root_module: nn.Module,
+) -> None:
+    """Shares state from ``root_state`` to other FSDP states.
+
+    Shares data structure state from the ``root_state`` to all FSDP states in
+    ``root_module`` 's module tree, and initializes handle attributes. These are
+    done together to require a single loop over the states. This function has
+    been modified to assign a different unshard stream to each process group.
+    """
+    from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
+                                                       _validate_and_get_hybrid_shard_state)
+    from torch.distributed.utils import _p_assert
+
+    handle = root_state._handle
+    if handle:
+        handle.init_flat_param_attributes()
+    _validate_and_get_hybrid_shard_state(root_module)
+    attr_name_to_values: Dict[str, Set[Any]] = {}
+    for attr_name in HOMOGENEOUS_ATTR_NAMES:
+        attr_name_to_values[attr_name] = set()
+    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
+    root_state._device_mesh = _init_device_mesh(root_state)
+    # Update _has_optim_in_backward for each handle.
+    for handle in root_state._all_handles:
+        flat_param = handle.flat_param
+        if hasattr(flat_param, '_in_backward_optimizers'):
+            raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
+        handle._has_optim_in_backward = flat_param._params is not None and any(
+            hasattr(param, '_in_backward_optimizers') for param in flat_param._params)
+
+    # Patching so that _FSDPStates with different process groups have separate unshard streams.
+    # Keep track of any new unshard streams we may have to add for specific process groups.
+    fsdp_pg_unshard_streams = {}
+    unshard_priority = root_state._unshard_stream.priority
+    for fsdp_state in root_state._all_fsdp_states:
+        for attr_name in HOMOGENEOUS_ATTR_NAMES:
+            _p_assert(
+                hasattr(fsdp_state, attr_name),
+                f'FSDP state missing attribute {attr_name}',
+            )
+            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
+        if fsdp_state is root_state:
+            continue
+        # Relax the assert for non-root FSDP instances in case the nested
+        # initialized module is wrapped again in FSDP later (e.g. after
+        # training to run inference)
+        _p_assert(
+            fsdp_state._is_root is None or not fsdp_state._is_root,
+            "Non-root FSDP instance's `_is_root` should not have been "
+            'set yet or should have been set to `False`',
+        )
+        fsdp_state._is_root = False
+
+        # Take care of any new unshard streams we have to create for non-default process groups.
+        if fsdp_state_has_default_pg(fsdp_state):
+            # If using default process group, unshard stream is the same as root fsdp instance.
+            fsdp_state._unshard_stream = root_state._unshard_stream
+        else:
+            # Otherwise, unshard stream is separate.
+            state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
+            if state_pg_ranks in fsdp_pg_unshard_streams:
+                # We have created the unshard stream for this process group already. Use it.
+                fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
+            else:
+                # We don't have an unshard stream for this process group yet. Make it.
+                fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
+                fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream
+
+        # All other stream assignments stay common across all of FSDP.
+        fsdp_state._post_backward_stream = root_state._post_backward_stream
+        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
+        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
+        fsdp_state._default_stream = root_state._default_stream
+        fsdp_state._exec_order_data = root_state._exec_order_data
+        fsdp_state._free_event_queue = root_state._free_event_queue
+        fsdp_state._device_mesh = root_state._device_mesh
+        handle = fsdp_state._handle
+        if handle:
+            handle.init_flat_param_attributes()
+    for attr_name, attr_values in attr_name_to_values.items():
+        if len(attr_values) != 1:
+            raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
+
+
+@no_type_check
+def _share_state_and_init_handle_attrs_t2p2(
+    root_state: '_FSDPState',
+    root_module: nn.Module,
+) -> None:
+    """Shares state from ``root_state`` to other FSDP states.
+
+    Shares data structure state from the ``root_state`` to all FSDP states in
+    ``root_module`` 's module tree, and initializes handle attributes. These are
+    done together to require a single loop over the states. This function has
+    been modified to assign a different unshard stream to each process group.
+    """
+    from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
+    from torch.distributed.utils import _p_assert
+
+    handle = root_state._handle
+    if handle:
+        handle.init_flat_param_attributes()
+    _validate_and_get_hybrid_shard_state(root_module)
+    attr_name_to_values: Dict[str, Set[Any]] = {}
+    for attr_name in HOMOGENEOUS_ATTR_NAMES:
+        attr_name_to_values[attr_name] = set()
+    root_state._all_handles = root_state._exec_order_data.all_handles  # share reference
+    # Update _has_optim_in_backward for each handle.
+    for handle in root_state._all_handles:
+        flat_param = handle.flat_param
+        if hasattr(flat_param, '_in_backward_optimizers'):
+            raise RuntimeError('FSDP optimizer in backward only supported with use_orig_params=True!')
+        handle._has_optim_in_backward = flat_param._params is not None and any(
+            hasattr(param, '_in_backward_optimizers') for param in flat_param._params)
+        if handle._has_optim_in_backward:
+            torch._C._log_api_usage_once('fsdp.optimizer_in_backward')
+
+    # Patching so that _FSDPStates with different process groups have separate unshard streams.
+    # Keep track of any new unshard streams we may have to add for specific process groups.
+    fsdp_pg_unshard_streams = {}
+    unshard_priority = root_state._unshard_stream.priority
+    for fsdp_state in root_state._all_fsdp_states:
+        for attr_name in HOMOGENEOUS_ATTR_NAMES:
+            _p_assert(
+                hasattr(fsdp_state, attr_name),
+                f'FSDP state missing attribute {attr_name}',
+            )
+            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
+        if fsdp_state is root_state:
+            continue
+        # Relax the assert for non-root FSDP instances in case the nested
+        # initialized module is wrapped again in FSDP later (e.g. after
+        # training to run inference)
+        _p_assert(
+            fsdp_state._is_root is None or not fsdp_state._is_root,
+            "Non-root FSDP instance's `_is_root` should not have been "
+            'set yet or should have been set to `False`',
+        )
+        fsdp_state._is_root = False
+
+        # Take care of any new unshard streams we have to create for non-default process groups.
+        if fsdp_state_has_default_pg(fsdp_state):
+            # If using default process group, unshard stream is the same as root fsdp instance.
+            fsdp_state._unshard_stream = root_state._unshard_stream
+        else:
+            # Otherwise, unshard stream is separate.
+            state_pg_ranks = fsdp_state_pg_ranks(fsdp_state)
+            if state_pg_ranks in fsdp_pg_unshard_streams:
+                # We have created the unshard stream for this process group already. Use it.
+                fsdp_state._unshard_stream = fsdp_pg_unshard_streams[state_pg_ranks]
+            else:
+                # We don't have an unshard stream for this process group yet. Make it.
+                fsdp_state._unshard_stream = fsdp_state._device_handle.Stream(priority=unshard_priority)
+                fsdp_pg_unshard_streams[state_pg_ranks] = fsdp_state._unshard_stream
+
+        # All other stream assignments stay common across all of FSDP.
+        fsdp_state._post_backward_stream = root_state._post_backward_stream
+        fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream
+        fsdp_state._all_reduce_stream = root_state._all_reduce_stream
+        fsdp_state._default_stream = root_state._default_stream
+        fsdp_state._exec_order_data = root_state._exec_order_data
+        fsdp_state._free_event_queue = root_state._free_event_queue
+        handle = fsdp_state._handle
+        if handle:
+            handle.init_flat_param_attributes()
+    for attr_name, attr_values in attr_name_to_values.items():
+        if len(attr_values) != 1:
+            raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
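
For reference, the pattern both patched functions implement is a cache of unshard streams keyed by the tuple of ranks in each _FSDPState's process group: states that share a process group reuse one stream, while states belonging to a different process group (e.g. a different HSDP shard group) get their own stream, which is what lets their unshard communication overlap. The sketch below illustrates only that caching pattern and is not part of the patch; _DummyStream and stream_for_ranks are hypothetical stand-ins for fsdp_state._device_handle.Stream and the loop inside _share_state_and_init_handle_attrs_t2p1/_t2p2, so it runs without torch or an initialized process group.

from typing import Dict, Tuple


class _DummyStream:
    """Hypothetical stand-in for a device stream (the real code uses fsdp_state._device_handle.Stream)."""

    def __init__(self, priority: int = 0) -> None:
        self.priority = priority


def stream_for_ranks(cache: Dict[Tuple[int, ...], _DummyStream], ranks: Tuple[int, ...],
                     priority: int) -> _DummyStream:
    # Reuse the stream already created for this process group; otherwise create and cache one.
    if ranks not in cache:
        cache[ranks] = _DummyStream(priority=priority)
    return cache[ranks]


unshard_streams: Dict[Tuple[int, ...], _DummyStream] = {}
s1 = stream_for_ranks(unshard_streams, (0, 1, 2, 3), priority=-1)
s2 = stream_for_ranks(unshard_streams, (0, 1, 2, 3), priority=-1)
s3 = stream_for_ranks(unshard_streams, (4, 5, 6, 7), priority=-1)
assert s1 is s2 and s1 is not s3  # one unshard stream per distinct process group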