diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 989b4ded43..55fbefcfe6 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -27,6 +27,11 @@ jobs: markers: 'not daily and not remote and not gpu and not vision and not doctest' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' + # - name: 'cpu-3.10-2.2' + # container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + # markers: 'not daily and not remote and not gpu and not vision and not doctest' + # pytest_command: 'coverage run -m pytest' + # composer_package_name: 'mosaicml' - name: 'cpu-vision' container: mosaicml/pytorch_vision:1.13.1_cpu-python3.10-ubuntu20.04 markers: 'not daily and not remote and not gpu and vision and not doctest' diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2c818b7229..acd7b4266a 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -17,6 +17,11 @@ jobs: markers: 'not daily and not remote and gpu and (doctest or not doctest)' pytest_command: 'coverage run -m pytest' composer_package_name: 'mosaicml' + # - name: 'gpu-3.10-2.2' + # container: mosaicml/pytorch:2.2.0_cu121-nightly20231213-python3.10-ubuntu20.04 + # markers: 'not daily and not remote and gpu and (doctest or not doctest)' + # pytest_command: 'coverage run -m pytest' + # composer_package_name: 'mosaicml' name: ${{ matrix.name }} if: github.repository_owner == 'mosaicml' with: diff --git a/composer/algorithms/alibi/attention_surgery_functions/_bert.py b/composer/algorithms/alibi/attention_surgery_functions/_bert.py index 915e940cad..c2a7bb3bd5 100644 --- a/composer/algorithms/alibi/attention_surgery_functions/_bert.py +++ b/composer/algorithms/alibi/attention_surgery_functions/_bert.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import copy import math from types import MethodType from typing import Optional, Tuple @@ -20,13 +21,14 @@ def bert_embedding_converter(module: torch.nn.Module, module_index: int, max_seq """ assert isinstance(module, (BertEmbeddings, RobertaEmbeddings)) del module_index # unused - zero_and_freeze_expand_position_embeddings(module, + new_module = copy.deepcopy(module) + zero_and_freeze_expand_position_embeddings(new_module, max_sequence_length, position_embedding_attribute='position_embeddings') - module_device = next(module.parameters()).device - module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device)) - return module + module_device = next(new_module.parameters()).device + new_module.register_buffer('position_ids', torch.arange(max_sequence_length).expand((1, -1)).to(module_device)) + return new_module @policy_registry.register(BertSelfAttention, RobertaSelfAttention) diff --git a/composer/core/state.py b/composer/core/state.py index 1ba5a193db..be92a799ce 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -10,7 +10,7 @@ import warnings from collections import OrderedDict from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast import numpy as np import torch @@ -792,6 +792,15 @@ def fsdp_state_dict_type(self): def fsdp_sharded_state_dict_enabled(self): return self.fsdp_config is not None and self.fsdp_enabled and self.fsdp_state_dict_type in 
['sharded', 'local'] + @property + def fsdp_device_mesh(self): + if self.fsdp_enabled: + if not hasattr(self.model, 'model'): + return None + return self.model.model._device_mesh + else: + return None + @property def load_fsdp_monolith_rank0_only(self): return self.fsdp_config is not None and self.fsdp_auto_wrap and self.fsdp_config[ @@ -864,6 +873,9 @@ def get_model_state_dict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The state dict for the model. """ + return self.get_model_and_optimizer_state_dict(model_only=True)[0] + + def _legacy_get_model_state_dict(self) -> Dict[str, Any]: if self.fsdp_enabled and self.fsdp_state_dict_type is not None: with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): model_state_dict = self.model.state_dict() @@ -876,6 +888,43 @@ def get_model_state_dict(self) -> Dict[str, Any]: torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(model_state_dict, 'module.') return model_state_dict + def _legacy_get_optim_state_dict(self) -> Dict[str, Any]: + optimizer = ensure_tuple(self.optimizers)[0] # Let's stop pretending. We don't support more than one optimizer. + if self.fsdp_enabled and self.fsdp_state_dict_type is not None: + optim_state_dict = { + type(optimizer).__qualname__: + fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type) + } + else: + optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} + return optim_state_dict + + def get_model_and_optimizer_state_dict(self, model_only=False) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if version.parse(torch.__version__) > version.parse('2.1.3'): + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict + if self.fsdp_state_dict_type not in [None, 'full', 'sharded']: + raise NotImplementedError( + textwrap.dedent(f'fsdp_state_dict_type={self.fsdp_state_dict_type} is not supported for ' + f'torch version {{version.parse(torch.__version__)}} > 2.1.3. Please set ' + 'fsdp_state_dict_type to None, "full", or "sharded".')) + + optimizer = ensure_tuple(self.optimizers)[0] + model_state_dict, optim_state_dict = get_state_dict( + model=self.model, + optimizers=([] if model_only else optimizer), + submodules=None, + options=StateDictOptions( + full_state_dict=self.fsdp_state_dict_type != 'sharded', + cpu_offload=True, + ), + ) + optim_state_dict = {type(optimizer).__qualname__: optim_state_dict} + else: + model_state_dict = self._legacy_get_model_state_dict() + optim_state_dict = self._legacy_get_optim_state_dict() + + return model_state_dict, optim_state_dict + def state_dict(self) -> Dict[str, Any]: """Collect the state dicts of our serializable attributes. @@ -883,23 +932,16 @@ def state_dict(self) -> Dict[str, Any]: Dict[str, Any]: The state dict. """ state_dict = {} - + model_state_dict, optim_state_dict = None, None + if 'model' in self.serialized_attributes or 'optimizers' in self.serialized_attributes: + model_state_dict, optim_state_dict = self.get_model_and_optimizer_state_dict() for attribute_name in self.serialized_attributes: attribute_value = getattr(self, attribute_name) if attribute_name == 'dataset_state': serialized_value = self._dataset_state_dict() elif attribute_name == 'model': - serialized_value = self.get_model_state_dict() + serialized_value = model_state_dict elif attribute_name == 'optimizers': - optimizer = ensure_tuple(attribute_value)[ - 0] # Let's stop pretending. We don't support more than one optimizer. 
- if self.fsdp_enabled and self.fsdp_state_dict_type is not None: - optim_state_dict = { - type(optimizer).__qualname__: - fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type) - } - else: - optim_state_dict = {type(optimizer).__qualname__: optimizer.state_dict()} serialized_value = optim_state_dict elif attribute_name == 'algorithms': # Store as list to preserve order in which algorithms were applied @@ -1058,49 +1100,34 @@ def _apply_required_algorithms( 'have undergone surgery, the following algorithms may be excluded using ' f'`load_exclude_algorithms`, e.g. `load_exclude_algorithms=[{missing_algo_names}]`.')) from e - def load_model_state( + def _legacy_load_model_state( self, state_dict: Dict[str, Any], - logger: Logger, strict: bool, - exclude_algorithms: Optional[List[str]] = None, - algorithm_passes: Optional[List[AlgorithmPass]] = None, ): """Loads the model's state from a ``state_dict``. Args: state_dict (Dict[str, Any]): The state dict, generated from a previous call to :meth:`state_dict`. - logger (Logger): The logger. strict (bool): Whether the keys (i.e., model parameter names) in the model state dict should perfectly match the keys in the model instance. - exclude_algorithms (List[str], optional): List of algorithm names to exclude from autoloading. (default: ``None``) - algorithm_passes (List[AlgorithmPass], optional): A list of algorithm passes to apply to autoloaded algorithms - to sort them into the correct order. (default: ``None``) """ - if 'algorithms' in state_dict: - self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes) - - if state_dict.get('is_model_ddp', False) and not self.is_model_ddp: - # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state - # with the `module.` prefix - torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') - # For FSDP monolith checkpoints, the model does not exist on ranks > 0 - model_on_rank = state_dict['model'] is not None + if state_dict['model'] is None: + return missing_keys, unexpected_keys = [], [] try: - # Load model if it exists. 
For FSDP monolith checkpoints, the model does not exist on ranks > 0 - if model_on_rank: - if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only: - log.debug( - f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}' - ) - with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): - missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) - else: - log.debug(f'Loading model state dict with strict={strict}') + # Load model if it exists + if self.fsdp_enabled and self.fsdp_state_dict_type is not None and not self.load_fsdp_monolith_rank0_only: + log.debug( + f'Loading model state dict with strict={strict} and FSDP state_dict_type={self.fsdp_state_dict_type}' + ) + with fsdp_state_dict_type_context(self.model, state_dict_type=self.fsdp_state_dict_type): missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) + else: + log.debug(f'Loading model state dict with strict={strict}') + missing_keys, unexpected_keys = self.model.load_state_dict(state_dict['model'], strict=strict) except RuntimeError as e: if 'Missing key(s) in state_dict' in str(e) or 'Unexpected key(s) in state_dict' in str(e): raise RuntimeError( @@ -1110,9 +1137,9 @@ def load_model_state( else: raise e - if model_on_rank and len(missing_keys) > 0: + if len(missing_keys) > 0: log.warning(f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}") - if model_on_rank and len(unexpected_keys) > 0: + if len(unexpected_keys) > 0: if self.fsdp_config is not None and self.fsdp_config[ 'use_orig_params'] and self.fsdp_state_dict_type == 'local': log.warning( @@ -1122,16 +1149,7 @@ def load_model_state( 'was still loaded correctly.') log.warning(f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}") - # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading - if self.load_fsdp_monolith_rank0_only: - assert self.fsdp_config is not None - log.info('Wrapping model with FSDP after loading model_state.') - from composer.trainer.dist_strategy import prepare_fsdp_module - prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device, - self.auto_microbatching) - log.debug('Finished wrapping model with FSDP.') - - def load_optim_state(self, state_dict: Dict[str, Any]): + def _legacy_load_optim_state(self, state_dict: Dict[str, Any]): """Load the optimizer state. Args: @@ -1205,6 +1223,55 @@ def _load_dataset_state(self, obj: Dict[str, Any]) -> None: # starts. This avoids "CUDA error: initialization error" -- its not clear why. 
# self.dataset_resumption['eval'][evaluator.label] = True + def load_model_and_optimizer_state( + self, + state_dict: Dict[str, Any], + logger: Logger, + strict: bool, + exclude_algorithms: Optional[List[str]] = None, + algorithm_passes: Optional[List[AlgorithmPass]] = None, + load_model_only: bool = False, + ): + if 'algorithms' in state_dict: + self._apply_required_algorithms(state_dict, logger, exclude_algorithms, algorithm_passes) + + if state_dict.get('is_model_ddp', False) and not self.is_model_ddp: + # This check is for backwards compatibility, as pre-v0.6.0 checkpoints serialized the state + # with the `module.` prefix + torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') + + # Load model and optimizer state + use_state_dict_fns = version.parse(torch.__version__) > version.parse('2.1.3') + if use_state_dict_fns: + from torch.distributed.checkpoint.state_dict import StateDictOptions, set_state_dict + model_state_dict = state_dict.get('model', {}) + optimizer, optim_state_dict = [], {} + if not load_model_only: + optimizer = ensure_tuple(self.optimizers)[0] + optim_state_dict = state_dict['optimizers'].get(type(optimizer).__qualname__, {}) + set_state_dict( + self.model, + optimizers=optimizer, + model_state_dict=model_state_dict, + optim_state_dict=optim_state_dict, + options=StateDictOptions(strict=strict, cpu_offload=True), + ) + else: + self._legacy_load_model_state(state_dict, strict) + + # If loading FSDP monolith checkpoint on rank 0 only, the model must be wrapped after loading + if self.load_fsdp_monolith_rank0_only: + assert self.fsdp_config is not None + log.info('Wrapping model with FSDP after loading model_state.') + from composer.trainer.dist_strategy import prepare_fsdp_module + prepare_fsdp_module(self.model, self.optimizers, self.fsdp_config, self.precision, self.device, + self.auto_microbatching) + log.debug('Finished wrapping model with FSDP.') + + # Legacy optimizer state load must happen after FSDP monolith + if not use_state_dict_fns and not load_model_only: + self._legacy_load_optim_state(state_dict) + def load_state_dict( self, state: Dict[str, Any], @@ -1228,28 +1295,26 @@ def load_state_dict( # Call load_model_state first since it applies required algorithms if 'model' in state: - self.load_model_state( + self.load_model_and_optimizer_state( state, logger, strict=strict, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, + load_model_only=(not 'optimizers' in state), ) for attribute_name in sorted(state.keys()): # Sort so all ranks load in the same order serialized_value = state[attribute_name] # Skip removed attributes as well as algorithms and model, which was already loaded - if attribute_name not in self.serialized_attributes or attribute_name == 'model': + if attribute_name not in self.serialized_attributes or attribute_name in ['model', 'optimizers']: continue - # Integrations are extra information about other libraries (e.g. 
huggingface) and not attributes to be loaded here if attribute_name == 'integrations': continue - # Skip metadata, which is not an attribute on State if attribute_name == 'metadata': continue - log.debug(f'Loading {attribute_name} into state.') # Restructure algorithms serialized_value from list to dict @@ -1258,8 +1323,6 @@ def load_state_dict( if attribute_name == 'dataset_state': self._load_dataset_state(serialized_value) - elif attribute_name == 'optimizers': - self.load_optim_state(state) elif attribute_name == 'train_metrics': # Get current metrics object and populate each metric present # in serialization with serialized data via load_state_dict() diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index f2c8c615b4..66e5f7a509 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -243,10 +243,11 @@ def prepare_fsdp_module( 'gpu and some ranks are on meta. Either keep all ranks on the same ' "device or set fsdp_config['sync_module_states'] = True. Otherwise, " 'some weights may be randomly initialized when loading a checkpoint.') - if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'): - raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires ' - 'fsdp_config["sync_module_states"] = True or different replicas will ' - 'have different weights.') + # Comment out while we debug deadlock + # if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'): + # raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires ' + # 'fsdp_config["sync_module_states"] = True or different replicas will ' + # 'have different weights.') # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This # may happen when close to memory limit or with uneven memory usage across ranks. Since we @@ -273,6 +274,13 @@ def sync_hook(*args): # `nn.Module.named_parameters`. # Setting it to `True` is mandatory when using `torch.compile()`. kwargs['use_orig_params'] = fsdp_config['use_orig_params'] + if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.2.0'): + if 'device_mesh' in fsdp_config: + from torch.distributed._tensor import init_device_mesh + kwargs['device_mesh'] = init_device_mesh( + 'cuda', + tuple([int(x) for x in fsdp_config['device_mesh']]), + ) # necessary variables for optimizers with multiple param groups in FSDP num_param_groups = None diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index ad0fd0904c..bec718b8c1 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -81,3 +81,7 @@ def patch_pytorch(): _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream _runtime_utils._root_pre_forward = _root_pre_forward FullyShardedDataParallel.forward = forward + + # Monkeypatch dtensor support + from composer.trainer.mosaic_fsdp_utils import init_fn_t2p2p0 + FullyShardedDataParallel.__init__ = init_fn_t2p2p0 # type: ignore diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 3cf26d79ec..6dbfbd9bdc 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -4,6 +4,9 @@ # Released under BSD 3-Clause License, # Copyright (c) Facebook, Inc. and its affiliates. 
+# yapf: disable +# isort: skip_file + """Utilities for monkey patching FSDP.""" import functools @@ -756,6 +759,318 @@ def _sharded_pre_load_state_dict_hook( _enter_unshard_params_ctx(module, fsdp_state, writeback=True) +if version.parse(torch.__version__) > version.parse('2.1.3') and version.parse( + torch.__version__) < version.parse('2.2.1'): + import copy + + from torch.distributed._tensor import DeviceMesh, DTensor, Replicate + from torch.distributed._tensor import Shard as DShard + from torch.distributed.algorithms._comm_hooks import default_hooks + from torch.distributed.device_mesh import _mesh_resources + from torch.distributed.distributed_c10d import _get_default_group + from torch.distributed.fsdp._common_utils import _FSDPState + from torch.distributed.fsdp._init_utils import (HYBRID_SHARDING_STRATEGIES, ProcessGroupType, + _get_default_comm_hook_state, _init_intra_and_inter_node_groups, + _is_valid_hybrid_shard_pg_type) + from torch.distributed.fsdp.fully_sharded_data_parallel import (_annotate_modules_for_dynamo, _auto_wrap, + _check_orig_params_flattened, _init_buffer_state, + _init_core_state, _init_device_handle, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, _init_runtime_state, + _init_state_dict_state, + _register_all_state_dict_hooks, + _register_flat_param) + from torch.distributed.fsdp.wrap import CustomPolicy, ModuleWrapPolicy, _Policy + from torch.distributed.tensor.parallel.fsdp import DTensorExtensions + + def all_gather_dtensor_t2p2p0( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + """All gather a DTensor in its FSDP dimension and return the local tensor.""" + assert parent_mesh == tensor.device_mesh + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement] + # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] + for i in range(0, len(placements) - 1): + placements[i] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + return tensor.to_local() + + def chunk_dtensor_t2p2p0( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> DTensor: + """Shard a tensor to chunks along the first dimension. + + The local rank will gets its corresponding chunk as the local tensor to create a DTensor. + """ + parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) + if parent_mesh is None: + raise RuntimeError('No parent device_mesh is found for FSDP device_mesh.') + # if parent_mesh.ndim != 2: + # raise RuntimeError( + # f"Found parent device_mesh of ndim={parent_mesh.ndim},", + # "but only 2D meshes are currently supported.", + # ) + + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.clone().detach() + + # When a layer is not involved in TP, then the tensor will not be a DTensor. + # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. + # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): + + # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), Replicate()). 
+ replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)] + shard_placements = [Replicate() for _ in range(parent_mesh.ndim)] + shard_placements[0] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute( + device_mesh=parent_mesh, + placements=shard_placements, + ) + + else: + tp_placements = tensor.placements + tp_placement = tp_placements[0] + + tensor = tensor.to_local() + + if parent_mesh.ndim <= 2: + # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), tp_placement). + replicate_placements = [Replicate() for _ in range(parent_mesh.ndim)] + replicate_placements[-1] = tp_placement # type: ignore[call-overload] + shard_placements = [DShard(0) for _ in range(parent_mesh.ndim)] # type: ignore[misc] + shard_placements[-1] = tp_placement # type: ignore[call-overload] + + elif parent_mesh.ndim == 3: + replicate_placements = [Replicate(), Replicate(), tp_placement] + shard_placements = [Replicate(), DShard(0), tp_placement] # type: ignore[misc] + + return DTensor.from_local(tensor, parent_mesh, replicate_placements).redistribute( + device_mesh=parent_mesh, + placements=shard_placements, + ) + + DTensorExtensions.all_gather_dtensor = all_gather_dtensor_t2p2p0 + DTensorExtensions.chunk_dtensor = chunk_dtensor_t2p2p0 + + def _init_extension_t2p2p0(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: + # TODO: we need to add additional check once we support FSDP + PiPPy. + # This check is currently sufficient, since we only support FSDP + TP. + if device_mesh and _mesh_resources.get_parent_mesh(state._device_mesh) is not None: + state._fsdp_extension = DTensorExtensions() + else: + # We need to explicilty set _fsdp_extension to None. + # Otherwise, we will run into an infinite recursion when getting the attribute. + state._fsdp_extension = None + return state + + def _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh: DeviceMesh) -> bool: + #parent_mesh = _mesh_resources.get_parent_mesh(device_mesh) + #if parent_mesh is not None: + # raise RuntimeError( + # f"Found device_mesh {device_mesh} passed in has a parent device_mesh {parent_mesh}.", + # "Hybrid sharding + TP is not supported yet.", + # ) + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 + + def _init_process_group_state_for_hybrid_shard_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, + ) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh_t2p2p0(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. + state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError('Expected device_mesh to have ndim=2 ' + f'but got {len(device_mesh.get_group())}') + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(default_group, + state._device_handle.device_count()) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. 
+ if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. + state.process_group, state._inter_node_pg = process_group + else: + raise ValueError('Expected process_group to be passed in as either None or ' + f'Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}') + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state(process_group=state._inter_node_pg,) + return state + + def _init_process_group_state_t2p2p0( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, + ) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError('Cannot pass both process_group and device_mesh at the ' + 'same time. Please just pass only one of them.') + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f'Manual wrapping with {sharding_strategy}', + 'requires explicit specification of process group or device_mesh.', + ) + else: + state = _init_process_group_state_for_hybrid_shard_t2p2p0(state, process_group, device_mesh) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = (process_group if process_group is not None else _get_default_group()) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor(data_parallel_world_size)) + state._gradient_postdivide_factor = (data_parallel_world_size / state._gradient_predivide_factor) + return state + + def init_fn_t2p2p0( + self, + module: nn.Module, + process_group: ProcessGroupType = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[Union[Callable, ModuleWrapPolicy, CustomPolicy]] = None, + backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states: Union[Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]] = None, + device_mesh: Optional[DeviceMesh] = None, + ): + """Docstring for lint.""" + torch._C._log_api_usage_once('torch.distributed.fsdp') + super(FullyShardedDataParallel, self).__init__() + _init_ignored_module_states(self, module, ignored_modules, ignored_states) + _init_device_handle(self, module, self._ignored_params, device_id) + + # Add module annotations for Dynamo support (see function for details) + _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params) + + # Initializes 
self.process_group, along with rank and world size. This will + # also set another attribute, _inter_node_pg, to control the process group + # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}. + # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up + # the same process group state as the root FSDP module. + self._device_mesh = device_mesh + _init_process_group_state_t2p2p0( + self, + process_group, + sharding_strategy, + auto_wrap_policy, + device_mesh, + ) + if auto_wrap_policy is not None: + root_kwargs = { + 'process_group': process_group, + 'sharding_strategy': sharding_strategy, + 'cpu_offload': cpu_offload, + 'backward_prefetch': backward_prefetch, + 'mixed_precision': mixed_precision, + 'param_init_fn': param_init_fn, + 'device_id': device_id, + 'sync_module_states': sync_module_states, + 'forward_prefetch': forward_prefetch, + 'limit_all_gathers': limit_all_gathers, + 'use_orig_params': use_orig_params, + 'ignored_states': self._ignored_params, + 'device_mesh': device_mesh, + } + if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: + # Share root process groups with children to maintain + # the invariant that all FSDP modules will have the same + # process groups. + root_kwargs['process_group'] = (self.process_group, self._inter_node_pg) + + _auto_wrap( + module, + auto_wrap_policy, + self._ignored_modules, + self._ignored_params, + root_kwargs, + FullyShardedDataParallel, + ) + + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + _init_core_state( + self, + sharding_strategy, + mixed_precision, + cpu_offload, + limit_all_gathers, + use_orig_params, + backward_prefetch_limit, + forward_prefetch_limit, + ) + _init_runtime_state(self) + _init_prefetching_state(self, backward_prefetch, forward_prefetch) + _init_buffer_state(self, module) + # extension needs to be set before `_init_param_handle_from_module()` + _init_extension_t2p2p0(self, device_mesh) + _init_param_handle_from_module( + self, + module, + device_id, + param_init_fn, + sync_module_states, + ) + self._fsdp_wrapped_module = module + if not use_orig_params: + _check_orig_params_flattened(self, self._ignored_params) + _register_flat_param(self, self) + + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + _init_state_dict_state(self) + _register_all_state_dict_hooks(self) + + def fsdp_state_has_default_pg(state: '_FSDPState') -> bool: """Indicates whether FlatParamHandle has the default process group. 
@@ -1044,7 +1359,7 @@ def _share_state_and_init_handle_attrs_t2p2( handle = root_state._handle if handle: handle.init_flat_param_attributes() - _validate_and_get_hybrid_shard_state(root_module) + # _validate_and_get_hybrid_shard_state(root_module) attr_name_to_values: Dict[str, Set[Any]] = {} for attr_name in HOMOGENEOUS_ATTR_NAMES: attr_name_to_values[attr_name] = set() diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index c8c6d325e0..0655be3004 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -960,10 +960,7 @@ def __init__( assert not isinstance(device_train_microbatch_size, str) # Distributed - if deepspeed_config is not None or fsdp_config is not None or dist.get_world_size() > 1: - # Deepspeed and FSDP both require torch.distributed to be initialized, even if the world size is 1 - # And torch.distributed is always required for multi-rank training - dist.initialize_dist(device, dist_timeout) + dist.initialize_dist(device, dist_timeout) # Reproducibility rank_zero_seed, seed = _distribute_and_get_random_seed(seed, device) diff --git a/composer/utils/checkpoint.py b/composer/utils/checkpoint.py index 4e1ee7777f..978e59330a 100644 --- a/composer/utils/checkpoint.py +++ b/composer/utils/checkpoint.py @@ -496,29 +496,37 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): # We need no_grad because we overwrite tensor values with set_() when we do elastic loading and we don't want the set_ op recorded in the computation graph. with torch.no_grad(): # 1. Load model and metadata first - model_state_dict = None if load_weights_only: - model_state_dict = {'state': {'model': state.get_model_state_dict()}} + state_dict = {'state': {'model': state.get_model_state_dict()}} else: cur_state_dict = state.state_dict() - cur_state_dict.pop('optimizers') - model_state_dict = {'state': cur_state_dict} + # For older versions of torch, we load optimizer separately. + if version.parse(torch.__version__) < version.parse('2.1.3'): + cur_state_dict.pop('optimizers') + state_dict = {'state': cur_state_dict} if ignore_keys: # Filter provided list of key paths if not callable(ignore_keys): ignore_keys = glob_filter(ignore_keys) # Call function to modify state_dict - ignore_keys(model_state_dict) + ignore_keys(state_dict) - dist_cp.load_state_dict( - state_dict=model_state_dict, - storage_reader=storage_reader, - planner=load_planner, - ) + if version.parse(torch.__version__) > version.parse('2.1.3'): + dist_cp.load( # type: ignore + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) + else: + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=storage_reader, + planner=load_planner, + ) state.load_state_dict( - model_state_dict['state'], + state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, @@ -526,11 +534,12 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner): ) # 2. Optionally load optimizer - if not load_weights_only: + # if we are using later than 2.1.0 then optimizer will already be loaded + if version.parse(torch.__version__) < version.parse('2.1.3') and not load_weights_only: optim_state = load_sharded_optimizer_state_dict(model_state_dict=state.state_dict()['model'], optimizer_key='optimizers', storage_reader=storage_reader) - state.load_optim_state(optim_state) + state._legacy_load_optim_state(optim_state) # 3. 
Optionally load RNG rng_state_dicts = reproducibility.get_rng_state() @@ -859,12 +868,13 @@ def _restore_checkpoint( if load_path is None: raise RuntimeError(f'Failed to load DeepSpeed checkpoint') elif load_weights_only: - state.load_model_state( + state.load_model_and_optimizer_state( state_dict['state'], logger, strict=strict_model_weights, exclude_algorithms=exclude_algorithms, algorithm_passes=algorithm_passes, + load_model_only=True, ) if not load_weights_only: state.load_state_dict( @@ -921,12 +931,12 @@ def _save_checkpoint( } if state.fsdp_sharded_state_dict_enabled: - # To load optimizer states with torch 2.0, the optimizer state must be at the top + # To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top # level of the state dict because the load_sharded_optimizer_state_dict function # requires a top level state dict key for the optimizer. # See https://github.com/pytorch/pytorch/blob/v2.0.1/torch/distributed/checkpoint/optimizer.py#L271 # for more info. - if using_torch_2(): + if using_torch_2() and version.parse(torch.__version__) < version.parse('2.1.3'): if not weights_only: state_dict['optimizers'] = state_dict['state'].pop('optimizers') log.debug('State dict created.') @@ -935,8 +945,12 @@ def _save_checkpoint( if dirname: os.makedirs(dirname, exist_ok=True) + # Only some ranks are meant to save checkpoint and produce a file + expect_file = False + # All ranks save for deepspeed if is_deepspeed: + expect_file = True log.debug('Saving deepspeed checkpoints to %s...', save_filename) if dist.get_global_rank() == 0: with open(save_filename, 'wb') as f: @@ -954,16 +968,41 @@ def _save_checkpoint( _validate_save_planner(save_planner) import torch.distributed.checkpoint as dist_cp - - log.debug('Saving sharded checkpoints to %s...', save_filename) - dist_cp.save_state_dict( - state_dict=state_dict, - storage_writer=dist_cp.FileSystemWriter(dirname), - planner=save_planner, - ) + from torch.distributed import get_process_group_ranks + + log.debug(f'Saving sharded checkpoints to {save_filename}...') + process_group = None + device_mesh = state.fsdp_device_mesh + if device_mesh is not None and device_mesh.ndim == 2: + expect_file = (device_mesh.get_local_rank(mesh_dim=0) == 0) + if expect_file: + process_group = device_mesh.get_group(1) # Only save on first replica + log.debug( + f'global_rank={dist.get_global_rank()}, {expect_file=}, process_group={get_process_group_ranks(process_group)}' + ) + else: + expect_file = True + + if expect_file: + if version.parse(torch.__version__) > version.parse('2.1.3'): + dist_cp.save( # type: ignore + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) + else: + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(dirname), + planner=save_planner, + process_group=process_group, + ) + log.debug('Finished pytorch save state dict') # Only rank 0 saves the state_dict unless you are using sharded checkpointing with torch <2.0 elif dist.get_global_rank() == 0 or state.fsdp_sharded_state_dict_enabled: + expect_file = True log_msg = f'Saving sharded checkpoints to {save_filename}...' 
if state.fsdp_sharded_state_dict_enabled else f'Saving monolithic checkpoint to {save_filename}' with open(save_filename, 'wb') as f: log.debug(log_msg) @@ -979,7 +1018,7 @@ def _save_checkpoint( dist.barrier() # ensure all ranks saved their files - if dist.get_global_rank() == 0 or is_deepspeed or state.fsdp_sharded_state_dict_enabled: + if expect_file: assert os.path.exists(save_filename), 'Expected file to have been saved.' return save_filename else: diff --git a/pyproject.toml b/pyproject.toml index 342c9b3d7e..f4155e23ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,6 +151,10 @@ filterwarnings = [ '''ignore:torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead:UserWarning''', # Ignore torch sharded tensor deprecated warnings '''ignore:Please use DTensor instead and we are deprecating ShardedTensor.:UserWarning''', + # Ignore torch pytree deprecated warnings + '''ignore:torch.utils._pytree._register_pytree_node is deprecated.*:UserWarning''', + # Ignore autograd kernel warning inside DeepSpeed + '''ignore:.*an autograd kernel was not registered to the Autograd key.*:UserWarning''' ] # Coverage diff --git a/tests/algorithms/test_algorithm_resumption.py b/tests/algorithms/test_algorithm_resumption.py index 1f89f551d5..9f243caeae 100644 --- a/tests/algorithms/test_algorithm_resumption.py +++ b/tests/algorithms/test_algorithm_resumption.py @@ -57,7 +57,7 @@ def test_algorithm_resumption( 'save_filename': 'ep{epoch}-rank{rank}', 'save_interval': '1ep', 'train_subset_num_batches': 2, - 'precision': 'amp_fp16', + 'precision': 'amp_bf16', } train_dataloader = get_alg_dataloader(alg_cls) if world_size == 1 else get_alg_dataloader(alg_cls, multigpu=True) # train model once, saving checkpoints every epoch diff --git a/tests/algorithms/test_required_on_load.py b/tests/algorithms/test_required_on_load.py index 3844a57084..ddb05a0c3c 100644 --- a/tests/algorithms/test_required_on_load.py +++ b/tests/algorithms/test_required_on_load.py @@ -9,6 +9,7 @@ import pytest import torch +from packaging import version from composer import Trainer, algorithms from composer.callbacks import CheckpointSaver @@ -163,14 +164,20 @@ def test_autoload(algo_name: str, load_weights_only: bool, already_added: bool, context = pytest.warns(UserWarning, match='Automatically adding required_on_load algorithm*') # Excluding some algorithms leads to errors when loading elif exclude: - if algo_name in ['Factorize', 'SqueezeExcite']: - context = pytest.raises( - ValueError, - match= - "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", - ) - elif algo_name == 'Alibi': - context = pytest.raises(RuntimeError) + if version.parse(torch.__version__) > version.parse('2.1.3'): + if algo_name in [ + 'Alibi', 'BlurPool', 'Factorize', 'GatedLinearUnits', 'GhostBatchNorm', 'SqueezeExcite' + ]: + context = pytest.raises(KeyError) # Optimizer loading is strict + else: + if algo_name in ['Factorize', 'SqueezeExcite']: + context = pytest.raises( + ValueError, + match= + "loaded state dict contains a parameter group that doesn't match the size of optimizer's group", + ) + elif algo_name == 'Alibi': + context = pytest.raises(RuntimeError) with context: trainer2 = Trainer( diff --git a/tests/callbacks/test_memory_monitor.py b/tests/callbacks/test_memory_monitor.py index f40a04eeb3..f2badc638c 100644 --- a/tests/callbacks/test_memory_monitor.py +++ b/tests/callbacks/test_memory_monitor.py @@ -7,13 +7,10 @@ from composer.callbacks import MemoryMonitor from 
composer.loggers import InMemoryLogger from composer.trainer import Trainer -from tests.common import RandomClassificationDataset, SimpleModel, device +from tests.common import RandomClassificationDataset, SimpleModel -@device('cpu', 'gpu') -def test_memory_monitor_warnings_on_cpu_models(device: str): - # Error if the user sets device=cpu even when cuda is available - del device # unused. always using cpu +def test_memory_monitor_warnings_on_cpu_models(): with pytest.warns(UserWarning, match='The memory monitor only works on CUDA devices'): Trainer( model=SimpleModel(), diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index e1658cf62e..4a471c9b38 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -694,7 +694,11 @@ def test_strict_errors(self, missing_key: bool, unexpected_key: bool): last_checkpoint = os.path.join('first', 'ep2.pt') if missing_key or unexpected_key: - error_context = pytest.raises(RuntimeError, match='Failed to load checkpoint due to') + message = r'Error\(s\) in loading state_dict' + if version.parse(torch.__version__) < version.parse('2.1.3'): + # Composer implements strict for older torch versions + message = 'Failed to load checkpoint due to' + error_context = pytest.raises(RuntimeError, match=message) else: error_context = contextlib.nullcontext() @@ -977,8 +981,10 @@ def test_autoload_algorithm_old_checkpoint(self): old_init, old_repr = NoOpModel.__init__, NoOpModel.__repr__ NoOpModel.__init__ = lambda self, x: None # type: ignore NoOpModel.__repr__ = lambda self: 'NoOpModel(3)' - with pytest.warns(UserWarning, match='required_on_load algorithm.*'), pytest.raises( - ValueError, match='loaded state dict contains a parameter group.*'): + error_context = pytest.raises(KeyError, match='module.0.weight') + if version.parse(torch.__version__) < version.parse('2.1.3'): + error_context = pytest.raises(ValueError, match='loaded state dict contains a parameter group.*') + with pytest.warns(UserWarning, match='required_on_load algorithm.*'), error_context: trainer_3 = self.get_trainer(load_path=os.path.join('first', 'ep1.pt'),) trainer_3.fit(duration='1ba') # Restore algorithm @@ -1310,6 +1316,7 @@ def test_rotate_checkpoints( dataset=train_dataset, sampler=dist.get_sampler(train_dataset), ), + precision='fp32', save_folder=str(save_folder), save_filename='checkpoint_{rank}_{batch}.pt', save_interval='1ba',
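Editor's note (not part of the patch): the changes above thread a `device_mesh` entry through `fsdp_config` (converted to a 2-D mesh via `init_device_mesh` in `prepare_fsdp_module` on torch >= 2.2), expose it as `State.fsdp_device_mesh`, and use it in `_save_checkpoint` so that only ranks on the first replica (`device_mesh.get_local_rank(mesh_dim=0) == 0`) write sharded checkpoint files, with that replica's process group handed to the distributed checkpoint writer. A minimal usage sketch follows, assuming torch > 2.1.3 (the version gate used throughout the patch), an 8-GPU job started with the composer launcher, and the repo's test helpers; the mesh shape [2, 4] and the save settings are illustrative assumptions, not values taken from the patch.

from torch.utils.data import DataLoader

from composer import Trainer
from composer.utils import dist
from tests.common import RandomClassificationDataset, SimpleModel

dataset = RandomClassificationDataset()
trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(dataset, sampler=dist.get_sampler(dataset)),
    max_duration='1ba',
    fsdp_config={
        'sharding_strategy': 'HYBRID_SHARD',  # replicate across mesh dim 0, shard across mesh dim 1
        'device_mesh': [2, 4],                # prepare_fsdp_module converts this to init_device_mesh('cuda', (2, 4))
        'state_dict_type': 'sharded',         # exercises the new dist_cp.save(...) branch in _save_checkpoint
        'sync_module_states': True,           # keeps replicas identical; the patch temporarily disables the hard error for HSDP without this
    },
    save_folder='checkpoints',
    save_interval='1ba',
)
trainer.fit()

With a 2-D mesh, `_save_checkpoint` derives `expect_file` from the replica index rather than from global rank 0, which is why the final existence assertion is keyed off `expect_file`.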
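A second sketch, for the new state-dict path: on torch > 2.1.3, `State.get_model_and_optimizer_state_dict` and `load_model_and_optimizer_state` delegate to `torch.distributed.checkpoint.state_dict` instead of the FSDP-specific context managers. The helper below mirrors those calls outside of Composer; `model` and `optimizer` stand for an already-wrapped FSDP module and its optimizer and are assumptions for illustration, not definitions taken from the patch.

from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict


def roundtrip_state_dicts(model, optimizer):
    """Sketch of the save/load calls the patch uses for torch > 2.1.3."""
    # Gather sharded (non-full) state dicts offloaded to CPU, matching
    # fsdp_state_dict_type == 'sharded' in get_model_and_optimizer_state_dict.
    options = StateDictOptions(full_state_dict=False, cpu_offload=True)
    model_sd, optim_sd = get_state_dict(model, optimizers=optimizer, options=options)

    # ... persist / reload model_sd and optim_sd with torch.distributed.checkpoint ...

    # Load both back in one call, as load_model_and_optimizer_state does; strict=True
    # matches Composer's default for model weights.
    set_state_dict(
        model,
        optimizers=optimizer,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
        options=StateDictOptions(strict=True, cpu_offload=True),
    )
    return model_sd, optim_sd

One consequence visible in the test changes: because optimizer state now loads strictly through `set_state_dict`, excluding surgery algorithms on load surfaces as a `KeyError` rather than the older parameter-group `ValueError`.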