Updating FSDP monkeypatch (mosaicml#2571)
* debug wrap

* logs

* kwargs

* more logs

* add more logs

* fix logs

* custom or policy

* final kwargs

* remove prints and cleanup

* remove two prints

* remove prints

* lint
mvpatel2000 authored Sep 27, 2023
1 parent 51dd023 commit d56d4ad
Showing 1 changed file with 21 additions and 3 deletions.
24 changes: 21 additions & 3 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -712,7 +712,7 @@ def _custom_recursive_wrap_t2p1p0(
                     f'FSDP with custom process groups cannot use `use_orig_params: True` when using meta init.')
 
             # Leaf node or final wrapping of the remainder both happen here.
-            return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
+            return _wrap(module, wrapper_cls, **final_kwargs), nonwrapped_numel
         else:
             return module, total_wrapped_numel
     return module, 0
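
The one-line change in this hunk passes `final_kwargs` rather than `kwargs` to `_wrap`. In the custom recursive wrap, `final_kwargs` is the per-module kwargs dict assembled along the way (for example, with a custom process group merged in), so wrapping with the original `kwargs` would silently drop those overrides. A minimal sketch of the pattern, assuming the override arrives as a dict returned by a policy (`wrap_leaf` and `policy_result` are illustrative names, not composer's API):

```python
from typing import Any, Dict, Optional


def wrap_leaf(module, wrapper_cls, policy_result: Optional[Dict[str, Any]], **kwargs):
    """Illustrative reduction of the leaf branch of _custom_recursive_wrap_t2p1p0."""
    final_kwargs = dict(kwargs)
    if isinstance(policy_result, dict):
        # A policy may return a dict of kwargs (e.g. a custom process group)
        # meant to override the defaults for this one module.
        final_kwargs.update(policy_result)
    # The fix: pass final_kwargs, not kwargs, so per-module overrides
    # actually reach the wrapper.
    return wrapper_cls(module, **final_kwargs)
```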
@@ -723,6 +723,24 @@ def _custom_recursive_wrap_t2p1p0(
 from torch.distributed.fsdp._init_utils import ProcessGroupType
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy, _Policy
 
+def _custom_or_policy_t2p1p0(
+    module: nn.Module,
+    recurse: bool,
+    nonwrapped_numel: int,
+    policies,
+) -> bool:
+    """Modified version of `_or_policy` from FSDP.
+
+    A policy that wraps ``module`` if any policy in the passed-in iterable of
+    ``policies`` returns something truthy. The result does not have to be ``True``
+    and can be, for example, a dictionary of kwargs to override wrapping.
+    """
+    for policy in policies:
+        result = policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
+        if result:
+            return result
+    return False
+
 def _custom_auto_wrap_t2p1p0(
     root_module: nn.Module,
     policy: Union[Callable, _Policy],
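
Upstream's `_or_policy` reduces its policies with `any(...)`, which coerces a dict result to a bare `True` and discards the payload; the variant added here returns the first truthy `result` as-is. A hedged sketch of the two policy styles this supports (the `_custom_process_group` attribute is a made-up convention for illustration):

```python
import torch.nn as nn


def size_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int):
    # Ordinary boolean policy: always recurse; wrap leftovers past a threshold.
    return recurse or nonwrapped_numel >= 1_000_000


def custom_pg_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int):
    # Hypothetical policy that returns a kwargs dict instead of True.
    pg = getattr(module, '_custom_process_group', None)
    return {'process_group': pg} if pg is not None else False


# _custom_or_policy_t2p1p0 returns the dict from custom_pg_policy verbatim,
# so the caller can merge it into final_kwargs; stock _or_policy would have
# reported only "wrap this module with default kwargs".
```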
@@ -745,7 +763,7 @@ def _custom_auto_wrap_t2p1p0(
     from torch.distributed.fsdp._common_utils import _override_module_mixed_precision
     from torch.distributed.fsdp._wrap_utils import (_check_nested_wrapping, _validate_frozen_params,
                                                     _warn_on_overridden_mixed_precision)
-    from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _or_policy, _post_order_apply,
+    from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _post_order_apply,
                                              _run_mixed_precision_override_policy, _wrap_module_cls_individually)
 
     mixed_precision = root_kwargs['mixed_precision']
@@ -793,7 +811,7 @@ def _custom_auto_wrap_t2p1p0(
         overridden_module_classes = _override_module_mixed_precision(root_module,
                                                                      mixed_precision._module_classes_to_ignore)
         policy = functools.partial(
-            _or_policy,
+            _custom_or_policy_t2p1p0,
             policies=[
                 policy,
                 partial(
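
With the import change above, the mixed-precision branch now composes its policies through `_custom_or_policy_t2p1p0` instead of `_or_policy`, so a kwargs dict returned by any sub-policy survives the composition. A self-contained toy run of the composed partial (the two stub policies are stand-ins, not composer code; the `use_orig_params` override is hypothetical):

```python
import functools

import torch.nn as nn


def _custom_or_policy_t2p1p0(module, recurse, nonwrapped_numel, policies):
    # Body as in the diff: return the first truthy result unmodified.
    for policy in policies:
        result = policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
        if result:
            return result
    return False


def never_wrap(module, recurse, nonwrapped_numel):
    return False


def override_wrap(module, recurse, nonwrapped_numel):
    return {'use_orig_params': False}  # hypothetical per-module override


composed = functools.partial(_custom_or_policy_t2p1p0, policies=[never_wrap, override_wrap])
print(composed(module=nn.Linear(2, 2), recurse=False, nonwrapped_numel=4))
# prints: {'use_orig_params': False}
```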