From b7b55b38f4e8ee847518382c5901ab801cad0668 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 17 Oct 2023 14:46:18 -0400
Subject: [PATCH] HSDP Support (#2648)

* add hsdp
* add tuple support
* mod wide
* update
* set default
* disable error validation
* hsdp
* gate import
---
 composer/callbacks/checkpoint_saver.py |  3 +--
 composer/trainer/mosaic_fsdp_utils.py  | 15 +++++++++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/composer/callbacks/checkpoint_saver.py b/composer/callbacks/checkpoint_saver.py
index 6b81cda7ab..34acb0445f 100644
--- a/composer/callbacks/checkpoint_saver.py
+++ b/composer/callbacks/checkpoint_saver.py
@@ -19,9 +19,8 @@
 from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
                             checkpoint, create_interval_scheduler, create_symlink_file, dist,
                             ensure_folder_has_no_conflicting_files, format_name_with_dist,
-                            format_name_with_dist_and_time, is_model_deepspeed, reproducibility)
+                            format_name_with_dist_and_time, is_model_deepspeed, reproducibility, using_torch_2)
 from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
-from composer.utils.misc import using_torch_2

 log = logging.getLogger(__name__)

diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py
index 585490cd2c..ad1c289a7a 100644
--- a/composer/trainer/mosaic_fsdp_utils.py
+++ b/composer/trainer/mosaic_fsdp_utils.py
@@ -7,11 +7,9 @@
 """Utilities for monkey patching FSDP."""

 import functools
-import inspect
 import logging
 import math
 import warnings
-from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check

 import torch
@@ -29,7 +27,7 @@
 from torch.distributed.utils import _replace_by_prefix

 from composer.core import Precision
-from composer.utils import dist
+from composer.utils import dist, using_torch_2

 if TYPE_CHECKING:
     if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
@@ -42,6 +40,10 @@
     'FULL_SHARD': ShardingStrategy.FULL_SHARD,
 }

+if using_torch_2():
+    SHARDING_MAP['_HYBRID_SHARD_ZERO2'] = ShardingStrategy._HYBRID_SHARD_ZERO2
+    SHARDING_MAP['HYBRID_SHARD'] = ShardingStrategy.HYBRID_SHARD
+
 BACKWARD_PREFETCH_MAP = {
     'NONE': None,
     'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
@@ -198,7 +200,12 @@ def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dic
             f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`"
         )
     if 'process_group' in module_kwargs:
-        module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache)
+        # Call on every process group if it is a tuple
+        if isinstance(module_kwargs['process_group'], tuple):
+            module_kwargs['process_group'] = tuple(
+                _get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group'])
+        else:
+            module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache)
     return module_kwargs
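
Usage note (not part of the patch): a minimal, hypothetical sketch of how the hybrid sharding strategies registered in SHARDING_MAP above might be requested. It assumes an fsdp_config dict whose 'sharding_strategy' string is resolved through SHARDING_MAP, and it mirrors the using_torch_2() gate by checking the installed torch version directly.

import torch
from packaging import version

# Hypothetical usage sketch: request HSDP only when torch 2.x is installed,
# mirroring the using_torch_2() gate that adds these keys to SHARDING_MAP;
# otherwise keep the pre-existing FULL_SHARD behavior.
if version.parse(torch.__version__) >= version.parse('2.0.0'):
    fsdp_config = {'sharding_strategy': 'HYBRID_SHARD'}  # or '_HYBRID_SHARD_ZERO2'
else:
    fsdp_config = {'sharding_strategy': 'FULL_SHARD'}

In PyTorch FSDP, HYBRID_SHARD fully shards within a node and replicates parameters across nodes, while _HYBRID_SHARD_ZERO2 only shards gradients and optimizer state within a node; both strategies first appeared in torch 2.x, which is why the patch gates them on using_torch_2().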