diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py
index ed838e7c2c..c9ab559ab9 100644
--- a/composer/loggers/mlflow_logger.py
+++ b/composer/loggers/mlflow_logger.py
@@ -94,7 +94,7 @@ def __init__(
         self._last_flush_time = time.time()
         self._flush_interval = flush_interval
 
-        self._experiment_id = None
+        self._experiment_id: Optional[str] = None
         self._run_id = None
 
         if self._enabled:
@@ -150,6 +150,7 @@ def init(self, state: State, logger: Logger) -> None:
                 self._run_id = env_run_id
             else:
                 # Search for an existing run tagged with this Composer run.
+                assert self._experiment_id is not None
                 existing_runs = mlflow.search_runs(experiment_ids=[self._experiment_id],
                                                    filter_string=f'tags.composer_run_name = "{state.run_name}"',
                                                    output_format='list')
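For context on the mlflow_logger.py hunks above: annotating _experiment_id as Optional[str] means a static checker treats it as possibly None at the mlflow.search_runs call, and the new assert narrows it back to str. A minimal standalone sketch of that narrowing pattern (the function name below is illustrative, not Composer's):

    from typing import List, Optional

    def experiment_filter(experiment_id: Optional[str]) -> List[str]:
        # Without narrowing, a type checker flags passing Optional[str] where str is required.
        assert experiment_id is not None
        return [experiment_id]  # experiment_id is now narrowed to str
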
diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 90f15771a6..3ab36ef0bc 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -90,3 +90,38 @@ def patch_pytorch():
         # Monkeypatch dtensor support
         from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
         FullyShardedDataParallel.__init__ = init_fn_t2p2p0  # type: ignore
+
+        # Monkeypatch state_dict
+        from torch.distributed.checkpoint import state_dict  # type: ignore
+
+        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
+        state_dict._verify_options = _verify_options_t2p2p0
+
+    elif version.parse(torch.__version__) < version.parse('2.3.1'):
+        # Monkey patch for torch < 2.3.1, i.e. torch == 2.3.0
+        # Note: this is the same patch as for 2.2.0; we are just making a new if branch
+        # for clarity and modularity of changes.
+
+        # Allow 2D HSDP
+        from torch.distributed.fsdp import _runtime_utils
+        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+        # # Better overlap communication and computation
+        # from torch.distributed.fsdp import _runtime_utils
+
+        # from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
+        #                                                 _wait_for_computation_stream, forward)
+        # _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
+        # _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+        # _runtime_utils._root_pre_forward = _root_pre_forward
+        # FullyShardedDataParallel.forward = forward
+
+        # Monkeypatch dtensor support
+        from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0
+        FullyShardedDataParallel.__init__ = init_fn_t2p2p0
+
+        # Monkeypatch state_dict
+        from torch.distributed.checkpoint import state_dict  # type: ignore
+
+        from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0
+        state_dict._verify_options = _verify_options_t2p2p0
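As a quick sanity check of the new torch 2.3.0 branch above, the patched attributes can be inspected after patch_pytorch() runs. A minimal sketch, not part of the patch, assuming torch 2.3.0 is installed:

    import torch
    from packaging import version
    from torch.distributed.checkpoint import state_dict as dcp_state_dict
    from torch.distributed.fsdp import FullyShardedDataParallel

    from composer.trainer.mosaic_fsdp import patch_pytorch
    from composer.trainer.mosaic_fsdp_utils import _verify_options_t2p2p0, init_fn_t2p2p0

    if version.parse('2.3.0') <= version.parse(torch.__version__) < version.parse('2.3.1'):
        patch_pytorch()
        # The FSDP constructor and the distributed-checkpoint option verifier should now
        # point at the Composer replacements assigned in the branch above.
        assert FullyShardedDataParallel.__init__ is init_fn_t2p2p0
        assert dcp_state_dict._verify_options is _verify_options_t2p2p0
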
diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 6dbfbd9bdc..4c0a1b5a81 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -13,7 +13,9 @@
 import logging
 import math
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check
+import contextlib
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, cast, no_type_check
 
 import torch
 import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
@@ -38,6 +40,7 @@
         torch.__version__) < version.parse('2.2.0'):
     from torch.distributed.fsdp._common_utils import _FSDPState
 
+
 log = logging.getLogger(__name__)
 
 SHARDING_MAP = {
@@ -227,7 +230,7 @@ def _custom_recursive_wrap_t1p13p1(
     modified version of
     https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/wrap.py#L353
     which recursively wraps modules as FSDP modules for parameter sharding.
-    This modification enables the user to pass custom FSDP arguements for every wrapped module.
+    This modification enables the user to pass custom FSDP arguments for every wrapped module.
     The added process_group_cache enables different FSDP modules to, when appropriate, use the same process group
     instead of instantiating a new process group.
 
@@ -318,7 +321,7 @@ def custom_auto_wrap_t1p13p1(
     modified version of
     https://github.com/pytorch/pytorch/blob/d922c29a22e4bf0fba49526f7536395eb8cd66f4/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1252
     FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
-    This modification enables the user to pass custom FSDP arguements for every wrapped module.
+    This modification enables the user to pass custom FSDP arguments for every wrapped module.
     The added process_group_cache enables different FSDP modules to, when appropriate, use the same process group
     instead of instantiating a new process group.
 
@@ -373,7 +376,7 @@ def _custom_recursive_wrap_t2p0p1(
     modified version of
     https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/wrap.py#L320
     which recursively wraps modules as FSDP modules for parameter sharding.
-    This modification enables the user to pass custom FSDP arguements for every wrapped module.
+    This modification enables the user to pass custom FSDP arguments for every wrapped module.
     The added process_group_cache enables different FSDP modules to, when appropriate, use the same process group
     instead of instantiating a new process group.
 
@@ -471,7 +474,7 @@ def _custom_auto_wrap_t2p0p1(
     modified version of
     https://github.com/pytorch/pytorch/blob/96ca226a7332be0d8f3d6159d0c797e032ab0721/torch/distributed/fsdp/_wrap_utils.py#L31
     FSDP's _auto_wrap recursively wraps modules as FSDP modules for parameter sharding.
-    This modification enables the user to pass custom FSDP arguements for every wrapped module.
+    This modification enables the user to pass custom FSDP arguments for every wrapped module.
     The added process_group_cache enables different FSDP modules to, when appropriate, use the same process group
     instead of instantiating a new process group.
 
@@ -760,7 +763,7 @@ def _sharded_pre_load_state_dict_hook(
 
 
 if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse(
-        torch.__version__) < version.parse('2.2.1'):
+        torch.__version__) < version.parse('2.3.1'):
     import copy
 
     from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
@@ -771,7 +774,7 @@ def _sharded_pre_load_state_dict_hook(
     from torch.distributed.fsdp._common_utils import _FSDPState
     from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType,
                                                     _get_default_comm_hook_state, _init_intra_and_inter_node_groups,
-                                                    _is_valid_hybrid_shard_pg_type)
+                                                    _is_valid_hybrid_shard_pg_type, _init_extension)
     from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap,
                                                                     _check_orig_params_flattened, _init_buffer_state,
                                                                     _init_core_state, _init_device_handle,
@@ -826,7 +829,7 @@ def chunk_dtensor_t2p2p0(
         tensor = tensor.clone().detach()
 
         # When a layer is not involved in TP, then the tensor will not be a DTensor.
-        # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer.
+        # e.g. When a layer is not specified in the parallelize_plan, TP will have no effect on the layer.
         # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer.
 
         if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor):
@@ -869,17 +872,6 @@ def chunk_dtensor_t2p2p0(
     DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0
     DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0
 
-    def _init_extension_t2p2p0(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
-        # TODO: we need to add additional check once we support FSDP + PiPPy.
-        # This check is currently sufficient, since we only support FSDP + TP.
-        if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None:
-            state._fsdp_extension = DTensorExtensions()
-        else:
-            # We need to explicilty set _fsdp_extension to None.
-            # Otherwise, we will run into an infinite recursion when getting the attribute.
-            state._fsdp_extension = None
-        return state
-
     def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool:
         #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh)
         #if parent_mesh is not None:
@@ -1052,7 +1044,7 @@ def init_fn_t2p2p0(
         _init_prefetching_state(self, backward_prefetch, forward_prefetch)
         _init_buffer_state(self, module)
         # extension needs to be set before `_init_param_handle_from_module()`
-        _init_extension_t2p2p0(self, device_mesh)
+        _init_extension(self, device_mesh)
         _init_param_handle_from_module(
             self,
             module,
@@ -1070,6 +1062,85 @@ def init_fn_t2p2p0(
         _init_state_dict_state(self)
         _register_all_state_dict_hooks(self)
 
+    from torch.distributed.checkpoint.state_dict import StateDictOptions, _StateDictInfo
+
+    def _verify_options_t2p2p0(
+        model: nn.Module,
+        optims: Tuple[torch.optim.Optimizer, ...],
+        optim_only: bool,
+        *,
+        submodules: Optional[Set[nn.Module]] = None,
+        options: Optional[StateDictOptions] = None,
+    ) -> _StateDictInfo:
+        """Verify the model and options passed by the user and generate _StateDictInfo."""
+        from torch.distributed.checkpoint.state_dict import StateDictOptions, _get_fqns, _StateDictInfo
+        from torch.distributed.fsdp import FullOptimStateDictConfig, FullStateDictConfig
+        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+        from torch.distributed.fsdp import (OptimStateDictConfig, ShardedOptimStateDictConfig, ShardedStateDictConfig,
+                                            StateDictConfig, StateDictType)
+
+        if optim_only and not optims:
+            raise RuntimeError('Optimizers are not passed in but optim_only is set to True.')
+
+        options = options or StateDictOptions()
+        assert options is not None  # pyright
+
+        fqn_param_mapping: Dict[Union[str, torch.Tensor], Union[Set[str], torch.Tensor]] = {}
+        all_fqns = set()
+        for name, param in model.named_parameters():
+            fqns = _get_fqns(model, name)
+            fqns = {fqn.replace('_checkpoint_wrapped_module.', '') for fqn in fqns}
+            fqn_param_mapping[param] = fqns
+            for fqn in fqns:
+                fqn_param_mapping[fqn] = param
+                all_fqns.add(fqn)
+
+        submodule_prefixes = set()
+        if submodules:
+            submodules = set(submodules)
+            for name, module in model.named_modules():
+                if module not in submodules:
+                    continue
+                fqns = _get_fqns(model, name)
+                assert len(fqns) == 1, 'Submodule FQN should only have 1 instance'
+                for fqn in fqns:
+                    submodule_prefixes.add(f'{fqn}.')
+        fsdp_modules = FSDP.fsdp_modules(model)
+        state_dict_config: StateDictConfig
+        optim_state_dict_config: OptimStateDictConfig
+        fsdp_context: Callable
+        if fsdp_modules:
+            # FSDP APIs only work if at least one FSDP instance exists.
+            if options.full_state_dict:
+                state_dict_config = FullStateDictConfig(offload_to_cpu=options.cpu_offload, rank0_only=options.cpu_offload)
+                optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=options.cpu_offload,
+                                                                   rank0_only=options.cpu_offload)
+                state_dict_type = StateDictType.FULL_STATE_DICT
+            else:
+                state_dict_config = ShardedStateDictConfig()
+                optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=options.cpu_offload,)
+                state_dict_type = StateDictType.SHARDED_STATE_DICT
+
+            fsdp_context = functools.partial(
+                FSDP.state_dict_type,
+                module=model,
+                state_dict_type=state_dict_type,
+                state_dict_config=state_dict_config,
+                optim_state_dict_config=optim_state_dict_config,
+            )
+        else:
+            fsdp_context = contextlib.nullcontext
+        return _StateDictInfo(
+            **asdict(options),
+            fqn_param_mapping=fqn_param_mapping,
+            all_fqns=all_fqns,
+            submodule_prefixes=submodule_prefixes,
+            fsdp_context=fsdp_context,
+            fsdp_modules=cast(List[nn.Module], fsdp_modules),
+            handle_model=not optim_only,
+            handle_optim=(len(optims) > 0),
+        )
+
     def fsdp_state_has_default_pg(state: '_FSDPState') -> bool:
         """Indicates whether FlatParamHandle has the default process group.
 
@@ -1153,7 +1224,7 @@ def _root_pre_forward(
         _p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
         if not state._is_root:
             # Always cast forward inputs in the root of this local FSDP unit for mixed
-            # precision, as this is where mixed precision could be configed.
+            # precision, as this is where mixed precision could be configured.
             # This is more useful for auto wrapping that is recommended in composable path.
             # For manual wrapping, cast forward inputs on each local FSDP unit root will
             # increase some overhead, so not turned on for model wrapper path right now where
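One detail worth noting in _verify_options_t2p2p0 above: parameter FQNs have the activation-checkpointing wrapper prefix '_checkpoint_wrapped_module.' stripped before the FQN-to-parameter mapping is built, so checkpoint keys line up with the unwrapped module names. A standalone illustration of that normalization (the parameter name below is hypothetical):

    # Hypothetical FQN as produced when a block is wrapped for activation checkpointing.
    fqns = {'model.blocks.0._checkpoint_wrapped_module.mlp.weight'}

    # Same normalization as in _verify_options_t2p2p0 above.
    normalized = {fqn.replace('_checkpoint_wrapped_module.', '') for fqn in fqns}
    assert normalized == {'model.blocks.0.mlp.weight'}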