diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 090d92e227..68fd0ff93c 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -712,7 +712,7 @@ def _custom_recursive_wrap_t2p1p0( f'FSDP with custom process groups cannot use `use_orig_params: True` when using meta init.') # Leaf node or final wrapping of the remainder both happen here. - return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel + return _wrap(module, wrapper_cls, **final_kwargs), nonwrapped_numel else: return module, total_wrapped_numel return module, 0 @@ -723,6 +723,24 @@ def _custom_recursive_wrap_t2p1p0( from torch.distributed.fsdp._init_utils import ProcessGroupType from torch.distributed.fsdp.wrap import ModuleWrapPolicy, _Policy + def _custom_or_policy_t2p1p0( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + policies, + ) -> bool: + """Modified version of `_or_policy` from FSDP. + + A policy that wraps ``module`` if any policy in the passed in iterable of + ``policies`` returns something truthy. The result does not have to be ``True`` + and can be, for example, a dictionary of kwargs to override wrapping. + """ + for policy in policies: + result = policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) + if policy: + return result + return False + def _custom_auto_wrap_t2p1p0( root_module: nn.Module, policy: Union[Callable, _Policy], @@ -745,7 +763,7 @@ def _custom_auto_wrap_t2p1p0( from torch.distributed.fsdp._common_utils import _override_module_mixed_precision from torch.distributed.fsdp._wrap_utils import (_check_nested_wrapping, _validate_frozen_params, _warn_on_overridden_mixed_precision) - from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _or_policy, _post_order_apply, + from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _post_order_apply, _run_mixed_precision_override_policy, _wrap_module_cls_individually) mixed_precision = root_kwargs['mixed_precision'] @@ -793,7 +811,7 @@ def _custom_auto_wrap_t2p1p0( overridden_module_classes = _override_module_mixed_precision(root_module, mixed_precision._module_classes_to_ignore) policy = functools.partial( - _or_policy, + _custom_or_policy_t2p1p0, policies=[ policy, partial(