diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index f8c77e7bf2..cd150ebf29 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -25,6 +25,8 @@ log = logging.getLogger(__name__)
 
 
+process_group_cache = {}
+
 
 class DDPSyncStrategy(StringEnum):
     """How and when gradient synchronization should happen.
@@ -437,36 +439,56 @@ def _param_init_fn(module: torch.nn.Module) -> None:
                         'This leaves parameters without initialization. Please add a ``param_init_fn`` or ``reset_parameters`` '
                         f'to module `{obj_name}`.')
 
-        # Choose which modules to FSDP wrap according to the following priority:
-        # If module has attribute `module._fsdp_wrap = ...`, always respect it
-        # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
-        def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
-            if recurse:
-                return True
-            should_be_wrapped = False
-            if hasattr(module, '_fsdp_wrap'):
-                should_be_wrapped = bool(module._fsdp_wrap)
-            elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
-                should_be_wrapped = obj.fsdp_wrap_fn(module)
-
-            if should_be_wrapped and auto_microbatching:
-                module.register_forward_hook(sync_hook)
-                module.register_full_backward_hook(sync_hook)
-            return should_be_wrapped
-
-        if is_torch_2_0:
-
-            def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
-                return __auto_wrap_policy(module, recurse, nonwrapped_numel)
-
-            _auto_wrap_policy = _auto_wrap_policy_new
-
+        if version.parse(torch.__version__) > version.parse('2.1.0.dev'):
+            # CustomPolicy is only supported in torch v2.1.0-rc1 or higher
+            from torch.distributed.fsdp.wrap import CustomPolicy  # type: ignore
+
+            def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
+                ret = False
+                if hasattr(module, '_fsdp_wrap'):
+                    ret = bool(module._fsdp_wrap)
+                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
+                    ret = obj.fsdp_wrap_fn(module)
+                    from composer.trainer.mosaic_fsdp_utils import _set_custom_fsdp_module_kwargs
+                    if isinstance(ret, dict):
+                        ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache)
+                if ret and auto_microbatching:
+                    module.register_forward_hook(sync_hook)
+                    module.register_full_backward_hook(sync_hook)
+                return ret
+
+            _auto_wrap_policy = CustomPolicy(lambda_fn)
         else:
+            # Choose which modules to FSDP wrap according to the following priority:
+            # If module has attribute `module._fsdp_wrap = ...`, always respect it
+            # Otherwise wrap if root object `obj.fsdp_wrap_fn(module)` is true.
+            def __auto_wrap_policy(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+                if recurse:
+                    return True
+                should_be_wrapped = False
+                if hasattr(module, '_fsdp_wrap'):
+                    should_be_wrapped = bool(module._fsdp_wrap)
+                elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
+                    should_be_wrapped = obj.fsdp_wrap_fn(module)
+
+                if should_be_wrapped and auto_microbatching:
+                    module.register_forward_hook(sync_hook)
+                    module.register_full_backward_hook(sync_hook)
+                return should_be_wrapped
+
+            if is_torch_2_0:
+
+                def _auto_wrap_policy_new(module: torch.nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
+                    return __auto_wrap_policy(module, recurse, nonwrapped_numel)
+
+                _auto_wrap_policy = _auto_wrap_policy_new
+
+            else:
-            def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
-                return __auto_wrap_policy(module, recurse, unwrapped_params)
+                def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_params: int) -> bool:
+                    return __auto_wrap_policy(module, recurse, unwrapped_params)
 
-            _auto_wrap_policy = _auto_wrap_policy_old
+                _auto_wrap_policy = _auto_wrap_policy_old
 
         fsdp_obj = FullyShardedDataParallel(
             obj,
diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index fecdb62c93..adf934b9e6 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -42,10 +42,6 @@ def patch_pytorch():
     elif version.parse(torch.__version__) < version.parse('2.1.1'):
         # Monkey path for torch < 2.1.1 ie torch == 2.1.0
 
-        # Monkey patch __init__ where __init__ calls the custom _auto_wrap fn
-        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p1p0
-        FullyShardedDataParallel.__init__ = init_fn_t2p1p0  # type: ignore
-
         # Monkey patch sharding method
         ChunkShardingSpec.build_metadata = build_metadata
         ChunkShardingSpec.shard = shard
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 68fd0ff93c..3e2eceee17 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -618,322 +618,6 @@ def init_fn_t2p0p1(
     _register_all_state_dict_hooks(self)
 
 
-def _custom_recursive_wrap_t2p1p0(
-    module: nn.Module,
-    auto_wrap_policy: Callable,
-    wrapper_cls: Callable,
-    ignored_modules: Set[nn.Module],
-    ignored_params: Set[nn.Parameter],
-    process_group_cache: Dict[Tuple[int], Any],
-    only_wrap_children: bool = False,
-    **kwargs: Any,
-) -> Tuple[nn.Module, int]:
-    """Supports custom wrapping of modules with FSDP kwargs.
-
-    Torch version must be 2.1.0.
-
-    Modified version of https://github.com/pytorch/pytorch/blob/8292b03c47fd71beb23ae834971e044aef6f4d7c/torch/distributed/fsdp/_wrap_utils.py#L25
-    to support custom FSDP kwargs, e.g. process groups.
-
-    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
-    ``True`` with ``wrapper_cls``.
-
-    Args:
-        module (nn.Module): Module to recursively wrap.
-        auto_wrap_policy (Callable): A callable representing a policy that
-            determines which modules to recursively wrap with ``wrapper_cls``.
-        wrapper_cls: wrapper_cls
-        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
-            wrapping.
-        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
-            wrapping; these should be the parameters contained in the modules
-            in ``ignored_modules``.
-        process_group_cache (Dict[Tuple[int], Any]): a cache of process_group to
-            use instead of potentially instantiating a new process_group
-        only_wrap_children: warp only children
-    Returns:
-        (nn.Module, int):
-            ``module`` after wrapping and the numel recursively wrapped.
-    """
-    from torch.distributed.fsdp.wrap import _wrap
-
-    assert auto_wrap_policy is not None, 'Must specify auto_wrap_policy.'
-    assert wrapper_cls is not None, 'Must specify wrapper_cls'
-    # Make sure no child is already wrapped.
-    for _, child in module.named_modules():
-        if child in ignored_modules:
-            continue
-        try:
-            assert not isinstance(child, cast(type, wrapper_cls))
-        except TypeError:
-            # wrapper_cls is a function as opposed to a class type so we bypass the above check.
-            pass
-
-    # We count all params, assuming none of them are already wrapped.
-    nonwrapped_numel = sum(p.numel() for p in module.parameters() if p not in ignored_params)
-
-    assert auto_wrap_policy is not None
-    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
-        total_wrapped_numel = 0
-        # Iterate through the children, recursively wrap if necessary
-        for name, child in module.named_children():
-            if child in ignored_modules:
-                continue
-            wrapped_child, num_wrapped_params = _custom_recursive_wrap_t2p1p0(
-                module=child,
-                auto_wrap_policy=auto_wrap_policy,
-                wrapper_cls=wrapper_cls,
-                ignored_modules=ignored_modules,
-                ignored_params=ignored_params,
-                process_group_cache=process_group_cache,
-                **kwargs,
-            )
-            setattr(module, name, wrapped_child)
-            # Keep track of how many parameters have been wrapped
-            total_wrapped_numel += num_wrapped_params
-        # decide if we need to wrap the current module,
-        # since the left over parameters exceed the number of params to wrap
-        remainder = nonwrapped_numel - total_wrapped_numel
-
-        module_kwargs = auto_wrap_policy(module=module, recurse=False, nonwrapped_numel=remainder)
-        if not only_wrap_children and module_kwargs:
-            # CHANGE: We modify the original code to support custom FSDP kwargs and add
-            # the process_group_cache to avoid instantiating a new process group.
-            module_kwargs = module_kwargs if isinstance(module_kwargs, dict) else {}
-            module_kwargs = _set_custom_fsdp_module_kwargs(module_kwargs, process_group_cache)
-
-            final_kwargs = {**kwargs, **module_kwargs}
-
-            if final_kwargs.get('process_group', None) is not None:
-                _pg_ranks = distributed.get_process_group_ranks(final_kwargs['process_group'])
-                _meta_init = any(p.device.type == 'meta' for p in module.parameters())
-                if _meta_init and len(_pg_ranks) != dist.get_world_size() and final_kwargs.get('use_orig_params'):
-                    raise NotImplementedError(
-                        f'FSDP with custom process groups cannot use `use_orig_params: True` when using meta init.')
-
-            # Leaf node or final wrapping of the remainder both happen here.
-            return _wrap(module, wrapper_cls, **final_kwargs), nonwrapped_numel
-        else:
-            return module, total_wrapped_numel
-    return module, 0
-
-
-if version.parse(torch.__version__) > version.parse('2.0.2') and version.parse(
-        torch.__version__) < version.parse('2.1.1'):
-    from torch.distributed.fsdp._init_utils import ProcessGroupType
-    from torch.distributed.fsdp.wrap import ModuleWrapPolicy, _Policy
-
-    def _custom_or_policy_t2p1p0(
-        module: nn.Module,
-        recurse: bool,
-        nonwrapped_numel: int,
-        policies,
-    ) -> bool:
-        """Modified version of `_or_policy` from FSDP.
-
-        A policy that wraps ``module`` if any policy in the passed in iterable of
-        ``policies`` returns something truthy.
-        The result does not have to be ``True``
-        and can be, for example, a dictionary of kwargs to override wrapping.
-        """
-        for policy in policies:
-            result = policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel)
-            if policy:
-                return result
-        return False
-
-    def _custom_auto_wrap_t2p1p0(
-        root_module: nn.Module,
-        policy: Union[Callable, _Policy],
-        ignored_modules: Set[nn.Module],
-        ignored_params: Set[nn.Parameter],
-        root_kwargs: Dict[str, Any],
-        fsdp_fn: Callable,  # e.g. `FullyShardedDataParallel` or `fully_shard`
-    ):
-        """Modified version of https://github.com/pytorch/pytorch/blob/f13101640f548f8fa139c03dfa6711677278c391/torch/distributed/fsdp/wrap.py#L487.
-
-        Calls custom _recursive_wrap fn and adds progress group cache.
-
-        Auto wraps modules in ``root_module`` 's tree according to ``policy``
-        following a post-order traversal.
-
-        Precondition: ``root_kwargs`` should contain all arguments except
-        ``module``. This function accepts the kwargs dict directly since it gets
-        forwarded into the post-order traversal function.
-        """
-        from torch.distributed.fsdp._common_utils import _override_module_mixed_precision
-        from torch.distributed.fsdp._wrap_utils import (_check_nested_wrapping, _validate_frozen_params,
-                                                        _warn_on_overridden_mixed_precision)
-        from torch.distributed.fsdp.wrap import (_construct_wrap_fn, _post_order_apply,
-                                                 _run_mixed_precision_override_policy, _wrap_module_cls_individually)
-
-        mixed_precision = root_kwargs['mixed_precision']
-        is_wrapper = inspect.isclass(fsdp_fn)
-        # TODO: We may relax this no-nested-wrapping constraint to support manual
-        # wrapping followed by auto wrapping.
-        _check_nested_wrapping(root_module)
-
-        if isinstance(policy, _Policy):
-            root_kwargs['auto_wrap_policy' if is_wrapper else 'policy'] = None
-            target_module_to_kwargs = policy._run_policy(root_module, ignored_modules, root_kwargs)
-            if mixed_precision is not None:
-                target_module_to_kwargs = _run_mixed_precision_override_policy(
-                    root_module,
-                    mixed_precision._module_classes_to_ignore,
-                    ignored_modules,
-                    root_kwargs,
-                    target_module_to_kwargs,
-                )
-                overridden_module_classes = _override_module_mixed_precision(root_module,
-                                                                             mixed_precision._module_classes_to_ignore)
-                _warn_on_overridden_mixed_precision(overridden_module_classes)
-            use_orig_params = root_kwargs.get('use_orig_params', False)
-            _validate_frozen_params(
-                root_module,
-                set(target_module_to_kwargs.keys()),
-                ignored_params,
-                use_orig_params,
-            )
-            wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn)
-            _post_order_apply(root_module, wrap_fn)
-            return
-
-        recursive_wrap_kwargs = {
-            'module': root_module,
-            'auto_wrap_policy': policy,
-            'wrapper_cls': fsdp_fn,
-            'ignored_modules': ignored_modules,
-            'ignored_params': ignored_params,
-            'only_wrap_children': True,
-        }
-        if mixed_precision is not None:
-            # Wrap modules of the ignored types separately and register forward
-            # hooks to cast to fp32 and back to the original dtype, respectively
-            overridden_module_classes = _override_module_mixed_precision(root_module,
-                                                                         mixed_precision._module_classes_to_ignore)
-            policy = functools.partial(
-                _custom_or_policy_t2p1p0,
-                policies=[
-                    policy,
-                    partial(
-                        _wrap_module_cls_individually,
-                        module_classes=mixed_precision._module_classes_to_ignore,
-                    ),
-                ],
-            )
-            recursive_wrap_kwargs['auto_wrap_policy'] = policy
-            _warn_on_overridden_mixed_precision(overridden_module_classes)
-
-        # CHANGE: Add process group cache and call our custom _recursive_wrap
-        recursive_wrap_kwargs['process_group_cache'] = {}
-
-        _custom_recursive_wrap_t2p1p0(**recursive_wrap_kwargs, **root_kwargs)  # type: ignore[arg-type]
-
-    def init_fn_t2p1p0(
-        self,
-        module: nn.Module,
-        process_group: ProcessGroupType = None,
-        sharding_strategy: Optional[ShardingStrategy] = None,
-        cpu_offload: Optional[CPUOffload] = None,
-        auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy]] = None,
-        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
-        mixed_precision: Optional[MixedPrecision] = None,
-        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
-        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-        device_id: Optional[Union[int, torch.device]] = None,
-        sync_module_states: bool = False,
-        forward_prefetch: bool = False,
-        limit_all_gathers: bool = True,
-        use_orig_params: bool = False,
-        ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None,
-    ):
-        """Modified version of https://github.com/pytorch/pytorch/blob/8ed169b1628285924e10fc98de53dbb75c92c43e/torch/distributed/fsdp/fully_sharded_data_parallel.py#L399C1."""
-        from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
-        from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, _check_orig_params_flattened,
-                                                        _init_buffer_state, _init_core_state, _init_device_handle,
-                                                        _init_ignored_module_states, _init_param_handle_from_module,
-                                                        _init_prefetching_state, _init_process_group_state,
-                                                        _init_runtime_state, _init_state_dict_state)
-        from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
-        from torch.distributed.fsdp._unshard_param_utils import _register_flat_param
-
-        torch._C._log_api_usage_once('torch.distributed.fsdp')
-        super(FullyShardedDataParallel, self).__init__()
-        _init_ignored_module_states(self, module, ignored_modules, ignored_states)
-        _init_device_handle(self, module, self._ignored_params, device_id)
-
-        # Add module annotations for Dynamo support (see function for details)
-        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)
-
-        # Initializes self.process_group, along with rank and world size. This will
-        # also set another attribute, _inter_node_pg, to control the process group
-        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
-        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
-        # the same process group state as the root FSDP module.
-        _init_process_group_state(self, process_group, sharding_strategy, auto_wrap_policy)
-        if auto_wrap_policy is not None:
-            root_kwargs = {
-                'process_group': process_group,
-                'sharding_strategy': sharding_strategy,
-                'cpu_offload': cpu_offload,
-                'backward_prefetch': backward_prefetch,
-                'mixed_precision': mixed_precision,
-                'param_init_fn': param_init_fn,
-                'device_id': device_id,
-                'sync_module_states': sync_module_states,
-                'forward_prefetch': forward_prefetch,
-                'limit_all_gathers': limit_all_gathers,
-                'use_orig_params': use_orig_params,
-                'ignored_states': self._ignored_params,
-            }
-            if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
-                # Share root process groups with children to maintain
-                # the invariant that all FSDP modules will have the same
-                # process groups.
-                root_kwargs['process_group'] = (self.process_group, self._inter_node_pg)
-
-            # CHANGE: Call our custom _auto_wrap function
-            _custom_auto_wrap_t2p1p0(
-                module,
-                auto_wrap_policy,
-                self._ignored_modules,
-                self._ignored_params,
-                root_kwargs,
-                FullyShardedDataParallel,
-            )
-
-        backward_prefetch_limit = 1
-        forward_prefetch_limit = 1
-        _init_core_state(
-            self,
-            sharding_strategy,
-            mixed_precision,
-            cpu_offload,
-            limit_all_gathers,
-            use_orig_params,
-            backward_prefetch_limit,
-            forward_prefetch_limit,
-        )
-        _init_runtime_state(self)
-        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
-        _init_buffer_state(self, module)
-        _init_param_handle_from_module(
-            self,
-            module,
-            device_id,
-            param_init_fn,
-            sync_module_states,
-        )
-        self._fsdp_wrapped_module = module
-        if not use_orig_params:
-            _check_orig_params_flattened(self, self._ignored_params)
-            _register_flat_param(self, self)
-
-        # `_state_dict_type` controls the `state_dict()` behavior, which is
-        # implemented using post-save and pre-load hooks
-        _init_state_dict_state(self)
-        _register_all_state_dict_hooks(self)
-
-
 def get_split_size(dim_size: int, chunks: int) -> int:
     """Gets the minimum size per chunk.
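
Note for reviewers (not part of the patch): the minimal sketch below illustrates the `CustomPolicy` convention that the new `dist_strategy.py` code relies on, assuming a hypothetical `Block` submodule and a trivial wrap decision. Only the `torch.distributed.fsdp.wrap.CustomPolicy` import (torch >= 2.1) and the `_fsdp_wrap` / `fsdp_wrap_fn` convention come from the diff above; everything else is an illustrative assumption.

    from typing import Any, Dict, Union

    import torch.nn as nn


    class Block(nn.Module):
        """Hypothetical submodule; `_fsdp_wrap = True` marks it for FSDP wrapping."""

        def __init__(self) -> None:
            super().__init__()
            self.linear = nn.Linear(8, 8)
            self._fsdp_wrap = True


    def lambda_fn(module: nn.Module) -> Union[bool, Dict[str, Any]]:
        # Mirror the priority used in dist_strategy.py: a per-module `_fsdp_wrap`
        # attribute wins; a model-level `fsdp_wrap_fn` could instead return a dict
        # of FSDP kwargs (e.g. a custom process group) for CustomPolicy to apply.
        if hasattr(module, '_fsdp_wrap'):
            return bool(module._fsdp_wrap)
        return False


    model = nn.Sequential(Block(), nn.Linear(8, 2))
    print([lambda_fn(m) for m in model.modules()])  # [False, True, False, False]

    try:
        # Only importable on torch >= 2.1; older versions keep the callable policy.
        from torch.distributed.fsdp.wrap import CustomPolicy
        auto_wrap_policy = CustomPolicy(lambda_fn)
    except ImportError:
        auto_wrap_policy = None

When the lambda returns a dict rather than a bool, `CustomPolicy` applies those entries as per-module FSDP kwargs, which appears to be why the `init_fn_t2p1p0` monkey patch and the custom `_recursive_wrap` path can be deleted in this PR.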