diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 06d0a26dda..0987656947 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -9,7 +9,6 @@
 import torch
 from packaging import version
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.distributed.fsdp import FullyShardedDataParallel
 
 from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
                                                 custom_auto_wrap_t1p13p1)
@@ -62,14 +61,25 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
         # Better overlap communication and computation
-        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
+        from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
+                                                        _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
+        _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+        _runtime_utils._root_pre_forward = _root_pre_forward
+        FullyShardedDataParallel.forward = forward
 
     elif version.parse(torch.__version__) < version.parse('2.2.1'):
         # Monkey patch for torch < 2.2.1 ie torch == 2.2.0
 
         # Better overlap communication and computation
         from torch.distributed.fsdp import _runtime_utils
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
-        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
+        from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
+                                                        _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
+        _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+        _runtime_utils._root_pre_forward = _root_pre_forward
+        FullyShardedDataParallel.forward = forward
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 2f8a575ebe..3cf26d79ec 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -788,6 +788,153 @@ def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
     return tuple(get_process_group_ranks(state.process_group))
 
 
+def _wait_for_computation_stream(
+    computation_stream: torch.Stream,
+    root_state: '_FSDPState',
+    pre_unshard_stream: torch.Stream,
+):
+    """Unshard and pre-unshard streams wait for computation stream.
+
+    Has the unshard and pre-unshard streams wait for the computation stream.
+    For example, this should be called in the FSDP root's pre-forward to
+    respect optimizer step computation.
+    """
+    # Tracing does not need to wait
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        return
+    # Ensure all unshard streams wait for the computation stream.
+    unshard_streams = set()
+    for fsdp_state in root_state._all_fsdp_states:
+        unshard_streams.add(fsdp_state._unshard_stream)
+    for unshard_stream in unshard_streams:
+        unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+    # Having the pre-all-gather stream wait for the current stream even if we
+    # do not leverage the pre-all-gather stream is tolerable since this only
+    # runs once per iteration
+    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+
+
+@no_type_check
+def _root_pre_forward(
+    state: '_FSDPState',
+    module: nn.Module,
+    args,
+    kwargs,
+) -> None:
+    """Runs pre-forward logic specific to the root FSDP instance.
+
+    This should run before any individual module's pre-forward. This starts
+    with an attempt at lazy initialization (which only runs non-vacuously once).
+    Otherwise, if this is called on a non-root FSDP instance, then it returns
+    directly.
+    """
+    from torch.distributed.fsdp._common_utils import _is_composable
+    from torch.distributed.fsdp._runtime_utils import (_cast_buffers_to_dtype_and_device,
+                                                       _get_buffers_and_dtypes_for_computation, _lazy_init,
+                                                       _reset_flat_param_grad_info_if_needed, _root_cast_forward_input)
+    from torch.distributed.utils import _p_assert, _to_kwargs
+    with torch.profiler.record_function('FullyShardedDataParallel._root_pre_forward'):
+        _lazy_init(state, module)
+        _p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
+        if not state._is_root:
+            # Always cast forward inputs in the root of this local FSDP unit for mixed
+            # precision, as this is where mixed precision could be configed.
+            # This is more useful for auto wrapping that is recommended in composable path.
+            # For manual wrapping, cast forward inputs on each local FSDP unit root will
+            # increase some overhead, so not turned on for model wrapper path right now where
+            # manual wrapping is more broadly used.
+            if _is_composable(state):
+                return _root_cast_forward_input(state, module, args, kwargs)
+            return args, kwargs
+
+        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
+        # are in full precision and if we should cast them back to lower precision, which happens when
+        # exiting eval() mode.
+        handle = state._handle
+        if handle:
+            should_cast_buffers_to_full_prec = handle._force_full_precision
+        else:
+            should_cast_buffers_to_full_prec = True
+
+        if should_cast_buffers_to_full_prec:
+            _cast_buffers_to_dtype_and_device(
+                buffers=dict(module.named_buffers()).values(),
+                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
+                device=state.compute_device,
+            )
+            # This flag is only set when we cast buffers to full precision, to avoid the
+            # CPU overhead that can stem from retrieving all buffers and their types in the
+            # following else branch.
+            state._needs_buffer_dtype_restore_check = True
+        elif getattr(state, '_needs_buffer_dtype_restore_check', False):
+            # Check if buffers are in full precision and we need to cast them
+            # back down.
+            (
+                buffers,
+                buffer_dtypes_for_computation,
+            ) = _get_buffers_and_dtypes_for_computation(state, module)
+            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
+                if any(buffer.dtype != buffer_dtype_for_computation
+                       for buffer, buffer_dtype_for_computation in zip(buffers, buffer_dtypes_for_computation)):
+                    # Assume we have to cast everything if there is one mismatch
+                    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes_for_computation, state.compute_device)
+            # We don't have to check this again until we cast buffers to full precision again.
+            state._needs_buffer_dtype_restore_check = False
+
+        if state.forward_prefetch:
+            handles = []
+            for fsdp_state in state._all_fsdp_states:
+                if fsdp_state._handle:
+                    handles.append(fsdp_state._handle)
+            for handle in handles:
+                handle._needs_pre_forward_unshard = True
+                handle._prefetched = False
+
+        _wait_for_computation_stream(
+            state._device_handle.current_stream(),
+            state,
+            state._pre_unshard_stream,
+        )
+        _reset_flat_param_grad_info_if_needed(state._all_handles)
+
+        # Prepares the forward inputs by moving them to ``compute_device``
+        # TODO: Do not use the side stream for tensor copies for now; investigate
+        # the perf with/without it.
+        with torch.profiler.record_function('FullyShardedDataParallel._to_kwargs'):
+            args_tuple, kwargs_tuple = _to_kwargs(args, kwargs, state.compute_device, False)
+        args = args_tuple[0]
+        kwargs = kwargs_tuple[0]
+
+        return _root_cast_forward_input(state, module, args, kwargs)
+
+
+def forward(self, *args: Any, **kwargs: Any) -> Any:
+    """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
+    from torch.distributed.fsdp._runtime_utils import (_post_forward, _post_forward_reshard, _pre_forward,
+                                                       _pre_forward_unshard)
+    from torch.distributed.utils import _p_assert
+    handle = self._handle
+    with torch.autograd.profiler.record_function('FullyShardedDataParallel.forward'):
+        args, kwargs = _root_pre_forward(self, self, args, kwargs)
+        unused = None
+        args, kwargs = _pre_forward(
+            self,
+            handle,
+            _pre_forward_unshard,
+            self._fsdp_wrapped_module,
+            args,
+            kwargs,
+        )
+        if handle:
+            _p_assert(
+                handle.flat_param.device == self.compute_device,
+                'Expected `FlatParameter` to be on the compute device '
+                f'{self.compute_device} but got {handle.flat_param.device}',
+            )
+        output = self._fsdp_wrapped_module(*args, **kwargs)
+        return _post_forward(self, handle, _post_forward_reshard, self, unused, output)
+
+
 @no_type_check
 def _share_state_and_init_handle_attrs_t2p1(
     root_state: '_FSDPState',
@@ -801,8 +948,7 @@ def _share_state_and_init_handle_attrs_t2p1(
     been modified to assign a different unshard stream to each process group.
     """
     from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
-                                                       _validate_and_get_hybrid_shard_state,
-                                                       _wait_for_computation_stream)
+                                                       _validate_and_get_hybrid_shard_state)
     from torch.distributed.utils import _p_assert
 
     handle = root_state._handle
@@ -875,13 +1021,6 @@ def _share_state_and_init_handle_attrs_t2p1(
         handle = fsdp_state._handle
         if handle:
             handle.init_flat_param_attributes()
-    # Ensure that all unshard streams wait on the default computation stream
-    for pg_unshard_stream in fsdp_pg_unshard_streams.values():
-        _wait_for_computation_stream(
-            root_state._device_handle.current_stream(),
-            pg_unshard_stream,
-            root_state._pre_unshard_stream,
-        )
     for attr_name, attr_values in attr_name_to_values.items():
         if len(attr_values) != 1:
             raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
@@ -899,8 +1038,7 @@ def _share_state_and_init_handle_attrs_t2p2(
     done together to require a single loop over the states. This function has
     been modified to assign a different unshard stream to each process group.
    """
-    from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state,
-                                                       _wait_for_computation_stream)
+    from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
     from torch.distributed.utils import _p_assert
 
     handle = root_state._handle
@@ -973,13 +1111,6 @@ def _share_state_and_init_handle_attrs_t2p2(
         handle = fsdp_state._handle
         if handle:
             handle.init_flat_param_attributes()
-    # Ensure that all unshard streams wait on the default computation stream
-    for pg_unshard_stream in fsdp_pg_unshard_streams.values():
-        _wait_for_computation_stream(
-            root_state._device_handle.current_stream(),
-            pg_unshard_stream,
-            root_state._pre_unshard_stream,
-        )
     for attr_name, attr_values in attr_name_to_values.items():
         if len(attr_values) != 1:
             raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')