Skip to content

Commit

Permalink
HSDP Support (mosaicml#2648)
Browse files Browse the repository at this point in the history
* add hsdp

* add tuple support

* mod wide

* update

* set default

* disable error validation

* hsdp

* gate import
  • Loading branch information
mvpatel2000 authored Oct 17, 2023
1 parent 1caacc3 commit b7b55b3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
3 changes: 1 addition & 2 deletions composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_interval_scheduler, create_symlink_file, dist,
ensure_folder_has_no_conflicting_files, format_name_with_dist,
-                            format_name_with_dist_and_time, is_model_deepspeed, reproducibility)
+                            format_name_with_dist_and_time, is_model_deepspeed, reproducibility, using_torch_2)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME
-from composer.utils.misc import using_torch_2

log = logging.getLogger(__name__)

Expand Down
15 changes: 11 additions & 4 deletions composer/trainer/mosaic_fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
"""Utilities for monkey patching FSDP."""

import functools
-import inspect
import logging
import math
import warnings
-from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Set, Tuple, Union, cast, no_type_check

import torch
Expand All @@ -29,7 +27,7 @@
from torch.distributed.utils import _replace_by_prefix

from composer.core import Precision
-from composer.utils import dist
+from composer.utils import dist, using_torch_2

if TYPE_CHECKING:
if version.parse(torch.__version__) >= version.parse('2.0.1') and version.parse(
Expand All @@ -42,6 +40,10 @@
'FULL_SHARD': ShardingStrategy.FULL_SHARD,
}

+if using_torch_2():
+    SHARDING_MAP['_HYBRID_SHARD_ZERO2'] = ShardingStrategy._HYBRID_SHARD_ZERO2
+    SHARDING_MAP['HYBRID_SHARD'] = ShardingStrategy.HYBRID_SHARD

BACKWARD_PREFETCH_MAP = {
'NONE': None,
'BACKWARD_PRE': BackwardPrefetch.BACKWARD_PRE,
Expand Down Expand Up @@ -198,7 +200,12 @@ def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dic
f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`"
)
if 'process_group' in module_kwargs:
-        module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache)
+        # Call on every process group if it is a tuple
+        if isinstance(module_kwargs['process_group'], tuple):
+            module_kwargs['process_group'] = tuple(
+                _get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group'])
+        else:
+            module_kwargs['process_group'] = _get_process_group(module_kwargs['process_group'], process_group_cache)

return module_kwargs

Expand Down

0 comments on commit b7b55b3

Please sign in to comment.