Commit: All unshard streams wait on computation every step (mosaicml#2823)

* patched torch

* fixed torch imports

* fixed torch imports

* fixed torch imports

* patching through composer

* patching through composer

* patching typing

* comment added

* don't patch torch 2.1.0

* patch torch 2.1.1 and 2.2.0

* linting fix

* waiting on computation stream from unshard stream

* waiting on computation stream from unshard stream

* less waiting

* no waiting

* all unshard streams wait on computation stream now

* 2.2.0 dev change

* correct waiting on computation stream

* fsdp state typing

* patching root pre forward

* patching root pre forward

* fsdp state typing

* patch forward

* correct waiting

* linting
snarayan21 authored Jan 8, 2024
1 parent 5592e41 commit c22c61a
Showing 2 changed files with 162 additions and 21 deletions.
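
At a high level, the change makes every FSDP unshard (all-gather) stream wait on the default computation stream once per step. A minimal sketch of that idea, assuming CUDA streams (the helper below is illustrative and not part of the patch):

import torch

def wait_all_unshard_streams_on_compute(unshard_streams, pre_unshard_stream):
    # Block each distinct unshard stream behind everything already queued on
    # the current (computation) stream, e.g. the optimizer step.
    computation_stream = torch.cuda.current_stream()
    for unshard_stream in set(unshard_streams):
        unshard_stream.wait_stream(computation_stream)
    pre_unshard_stream.wait_stream(computation_stream)
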
16 changes: 13 additions & 3 deletions composer/trainer/mosaic_fsdp.py
@@ -9,7 +9,6 @@
import torch
from packaging import version
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp import FullyShardedDataParallel

from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
custom_auto_wrap_t1p13p1)
@@ -62,14 +61,25 @@ def patch_pytorch():
_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

# Better overlap communication and computation
from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
_wait_for_computation_stream, forward)
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
_runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
_runtime_utils._root_pre_forward = _root_pre_forward
FullyShardedDataParallel.forward = forward

elif version.parse(torch.__version__) < version.parse('2.2.1'):
# Monkey patch for torch < 2.2.1, i.e. torch == 2.2.0

# Better overlap communication and computation
from torch.distributed.fsdp import _runtime_utils
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
_wait_for_computation_stream, forward)
_runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
_runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
_runtime_utils._root_pre_forward = _root_pre_forward
FullyShardedDataParallel.forward = forward
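
The block above is version-gated monkey patching: for torch releases Composer has validated, private functions in torch.distributed.fsdp._runtime_utils and FullyShardedDataParallel.forward are swapped for Composer's variants. A condensed sketch of that pattern (the version bounds here are illustrative; the authoritative bounds are the checks in patch_pytorch itself):

import torch
from packaging import version
from torch.distributed.fsdp import _runtime_utils
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

def apply_fsdp_overlap_patch(share_state_fn, wait_fn, root_pre_forward_fn, forward_fn):
    # Only replace internals for torch versions the replacements were written against.
    if version.parse('2.1.1') <= version.parse(torch.__version__) < version.parse('2.2.1'):
        _runtime_utils._share_state_and_init_handle_attrs = share_state_fn
        _runtime_utils._wait_for_computation_stream = wait_fn
        _runtime_utils._root_pre_forward = root_pre_forward_fn
        FullyShardedDataParallel.forward = forward_fn
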
167 changes: 149 additions & 18 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -788,6 +788,153 @@ def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
return tuple(get_process_group_ranks(state.process_group))


def _wait_for_computation_stream(
computation_stream: torch.Stream,
root_state: '_FSDPState',
pre_unshard_stream: torch.Stream,
):
"""Unshard and pre-unshard streams wait for computation stream.
Has the unshard and pre-unshard streams wait for the computation stream.
For example, this should be called in the FSDP root's pre-forward to
respect optimizer step computation.
"""
# Tracing does not need to wait
if torch.distributed._functional_collectives.is_torchdynamo_compiling():
return
# Ensure all unshard streams wait for the computation stream.
unshard_streams = set()
for fsdp_state in root_state._all_fsdp_states:
unshard_streams.add(fsdp_state._unshard_stream)
for unshard_stream in unshard_streams:
unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
# Having the pre-all-gather stream wait for the current stream even if we
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
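
For reference, the wait_stream semantics this relies on: work queued on the waiting stream after the call cannot start until everything already queued on the awaited stream has finished. A toy example, assuming a CUDA device (the tensors and stream names are illustrative):

import torch

compute_stream = torch.cuda.current_stream()
unshard_stream = torch.cuda.Stream()

x = torch.randn(1024, 1024, device='cuda')
y = x @ x  # queued on the computation stream

unshard_stream.wait_stream(compute_stream)
with torch.cuda.stream(unshard_stream):
    z = y + 1  # ordered after the matmul, so this read of y is safe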


@no_type_check
def _root_pre_forward(
state: '_FSDPState',
module: nn.Module,
args,
kwargs,
) -> None:
"""Runs pre-forward logic specific to the root FSDP instance.
This should run before any individual module's pre-forward. This starts
with an attempt at lazy initialization (which only runs non-vacuously once).
Otherwise, if this is called on a non-root FSDP instance, then it returns
directly.
"""
from torch.distributed.fsdp._common_utils import _is_composable
from torch.distributed.fsdp._runtime_utils import (_cast_buffers_to_dtype_and_device,
_get_buffers_and_dtypes_for_computation, _lazy_init,
_reset_flat_param_grad_info_if_needed, _root_cast_forward_input)
from torch.distributed.utils import _p_assert, _to_kwargs
with torch.profiler.record_function('FullyShardedDataParallel._root_pre_forward'):
_lazy_init(state, module)
_p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
if not state._is_root:
# Always cast forward inputs in the root of this local FSDP unit for mixed
# precision, as this is where mixed precision could be configured.
# This is more useful for auto wrapping, which is recommended in the composable path.
# For manual wrapping, casting forward inputs on each local FSDP unit root would
# add some overhead, so it is not enabled for the model-wrapper path right now, where
# manual wrapping is more broadly used.
if _is_composable(state):
return _root_cast_forward_input(state, module, args, kwargs)
return args, kwargs

# We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
# are in full precision and if we should cast them back to lower precision, which happens when
# exiting eval() mode.
handle = state._handle
if handle:
should_cast_buffers_to_full_prec = handle._force_full_precision
else:
should_cast_buffers_to_full_prec = True

if should_cast_buffers_to_full_prec:
_cast_buffers_to_dtype_and_device(
buffers=dict(module.named_buffers()).values(),
buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
device=state.compute_device,
)
# This flag is only set when we cast buffers to full precision, to avoid the
# CPU overhead that can stem from retrieving all buffers and their types in the
# following else branch.
state._needs_buffer_dtype_restore_check = True
elif getattr(state, '_needs_buffer_dtype_restore_check', False):
# Check if buffers are in full precision and we need to cast them
# back down.
(
buffers,
buffer_dtypes_for_computation,
) = _get_buffers_and_dtypes_for_computation(state, module)
if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
if any(buffer.dtype != buffer_dtype_for_computation
for buffer, buffer_dtype_for_computation in zip(buffers, buffer_dtypes_for_computation)):
# Assume we have to cast everything if there is one mismatch
_cast_buffers_to_dtype_and_device(buffers, buffer_dtypes_for_computation, state.compute_device)
# We don't have to check this again until we cast buffers to full precision again.
state._needs_buffer_dtype_restore_check = False

if state.forward_prefetch:
handles = []
for fsdp_state in state._all_fsdp_states:
if fsdp_state._handle:
handles.append(fsdp_state._handle)
for handle in handles:
handle._needs_pre_forward_unshard = True
handle._prefetched = False

_wait_for_computation_stream(
state._device_handle.current_stream(),
state,
state._pre_unshard_stream,
)
_reset_flat_param_grad_info_if_needed(state._all_handles)

# Prepares the forward inputs by moving them to ``compute_device``
# TODO: Do not use the side stream for tensor copies for now; investigate
# the perf with/without it.
with torch.profiler.record_function('FullyShardedDataParallel._to_kwargs'):
args_tuple, kwargs_tuple = _to_kwargs(args, kwargs, state.compute_device, False)
args = args_tuple[0]
kwargs = kwargs_tuple[0]

return _root_cast_forward_input(state, module, args, kwargs)
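
The _needs_buffer_dtype_restore_check flag above is a dirty-flag optimization: the per-buffer dtype scan only runs on steps that follow a forced cast to full precision, instead of on every forward. A stripped-down sketch of that pattern (names and dtypes are illustrative, not the FSDP internals):

def maybe_restore_buffer_dtypes(state, buffers, compute_dtype):
    if state.force_full_precision:
        for buf in buffers:
            buf.data = buf.data.float()
        # Buffers may now be in the wrong dtype for low-precision compute.
        state.needs_restore_check = True
    elif getattr(state, 'needs_restore_check', False):
        # Only pay for the scan when the previous step could have changed dtypes.
        if any(buf.dtype != compute_dtype for buf in buffers):
            for buf in buffers:
                buf.data = buf.data.to(compute_dtype)
        state.needs_restore_check = False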


def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
from torch.distributed.fsdp._runtime_utils import (_post_forward, _post_forward_reshard, _pre_forward,
_pre_forward_unshard)
from torch.distributed.utils import _p_assert
handle = self._handle
with torch.autograd.profiler.record_function('FullyShardedDataParallel.forward'):
args, kwargs = _root_pre_forward(self, self, args, kwargs)
unused = None
args, kwargs = _pre_forward(
self,
handle,
_pre_forward_unshard,
self._fsdp_wrapped_module,
args,
kwargs,
)
if handle:
_p_assert(
handle.flat_param.device == self.compute_device,
'Expected `FlatParameter` to be on the compute device '
f'{self.compute_device} but got {handle.flat_param.device}',
)
output = self._fsdp_wrapped_module(*args, **kwargs)
return _post_forward(self, handle, _post_forward_reshard, self, unused, output)
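
Structurally, the patched forward above follows the usual pattern of wrapping the inner module's call with pre- and post-forward logic. A generic sketch of that shape (the wrapper class and hook names are illustrative, not FSDP's API):

import torch.nn as nn

class PrePostWrapper(nn.Module):
    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        args, kwargs = self._pre_forward(args, kwargs)  # e.g. unshard params, cast inputs
        output = self.inner(*args, **kwargs)
        return self._post_forward(output)  # e.g. reshard params, register backward hooks

    def _pre_forward(self, args, kwargs):
        return args, kwargs

    def _post_forward(self, output):
        return output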


@no_type_check
def _share_state_and_init_handle_attrs_t2p1(
root_state: '_FSDPState',
@@ -801,8 +948,7 @@ def _share_state_and_init_handle_attrs_t2p1(
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
_validate_and_get_hybrid_shard_state,
_wait_for_computation_stream)
_validate_and_get_hybrid_shard_state)
from torch.distributed.utils import _p_assert

handle = root_state._handle
@@ -875,13 +1021,6 @@ def _share_state_and_init_handle_attrs_t2p1(
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
# Ensure that all unshard streams wait on the default computation stream
for pg_unshard_stream in fsdp_pg_unshard_streams.values():
_wait_for_computation_stream(
root_state._device_handle.current_stream(),
pg_unshard_stream,
root_state._pre_unshard_stream,
)
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
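
Per the docstring, the modified function assigns a different unshard stream to each process group, which amounts to keying streams by the group's rank tuple so that FSDP states sharing a process group share an unshard stream. A hedged sketch of that bookkeeping, assuming CUDA streams (the helper below is illustrative, not the patched function itself):

import torch
from torch.distributed import get_process_group_ranks

def assign_unshard_streams(all_fsdp_states):
    # One all-gather stream per distinct process group.
    fsdp_pg_unshard_streams = {}
    for fsdp_state in all_fsdp_states:
        pg_ranks = tuple(get_process_group_ranks(fsdp_state.process_group))
        if pg_ranks not in fsdp_pg_unshard_streams:
            fsdp_pg_unshard_streams[pg_ranks] = torch.cuda.Stream()
        fsdp_state._unshard_stream = fsdp_pg_unshard_streams[pg_ranks]
    return fsdp_pg_unshard_streams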
@@ -899,8 +1038,7 @@ def _share_state_and_init_handle_attrs_t2p2(
done together to require a single loop over the states. This function has
been modified to assign a different unshard stream to each process group.
"""
from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state,
_wait_for_computation_stream)
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
from torch.distributed.utils import _p_assert

handle = root_state._handle
@@ -973,13 +1111,6 @@ def _share_state_and_init_handle_attrs_t2p2(
handle = fsdp_state._handle
if handle:
handle.init_flat_param_attributes()
# Ensure that all unshard streams wait on the default computation stream
for pg_unshard_stream in fsdp_pg_unshard_streams.values():
_wait_for_computation_stream(
root_state._device_handle.current_stream(),
pg_unshard_stream,
root_state._pre_unshard_stream,
)
for attr_name, attr_values in attr_name_to_values.items():
if len(attr_values) != 1:
raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
