Add process group as arg to FSDP (mosaicml#2794)
* add test

* only cast if PG is specified

* add to docstring

* filter warning

* filter warning

* docs

* support lists

* remove warnings

* lint

* hsdp monkeypatch

* logs

* change log

* fix patch

* typo

* clean up logs
mvpatel2000 authored Dec 26, 2023
1 parent a3ea7a4 commit 2aa50e7
Showing 5 changed files with 60 additions and 19 deletions.
10 changes: 8 additions & 2 deletions composer/trainer/dist_strategy.py
@@ -18,7 +18,8 @@
from composer.devices import Device
from composer.trainer.meta_safe_apply import meta_safe_apply
from composer.trainer.mosaic_fsdp import patch_pytorch
-from composer.trainer.mosaic_fsdp_utils import BACKWARD_PREFETCH_MAP, SHARDING_MAP, get_cpu_offload, get_mixed_precision
+from composer.trainer.mosaic_fsdp_utils import (BACKWARD_PREFETCH_MAP, SHARDING_MAP, _set_custom_fsdp_module_kwargs,
+                                                get_cpu_offload, get_mixed_precision)
from composer.utils import StringEnum, dist, ensure_tuple, using_torch_2

__all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module']
@@ -143,6 +144,7 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
fsdp_config.setdefault('load_monolith_rank0_only', False)
fsdp_config.setdefault('load_planner', None)
fsdp_config.setdefault('mixed_precision', 'DEFAULT')
+fsdp_config.setdefault('process_group', None)
fsdp_config.setdefault('save_planner', None)
fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}')
fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD')
@@ -347,6 +349,10 @@ def sync_hook(*args):
f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` '
f'with sharding strategy `{sharding_map_key}.`')

+process_group = None
+if fsdp_config['process_group'] is not None:
+    process_group_dict = {'process_group': fsdp_config['process_group']}
+    process_group = _set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group']
backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()]
activation_checkpointing = fsdp_config['activation_checkpointing']
activation_cpu_offload = fsdp_config['activation_cpu_offload']
@@ -510,7 +516,6 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]:
ret = bool(module._fsdp_wrap)
elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable):
ret = obj.fsdp_wrap_fn(module)
-from composer.trainer.mosaic_fsdp_utils import _set_custom_fsdp_module_kwargs
if isinstance(ret, dict):
ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache)
if ret and auto_microbatching:
@@ -553,6 +558,7 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para

fsdp_obj = FullyShardedDataParallel(
obj,
+process_group=process_group,
sharding_strategy=sharding_strategy,
auto_wrap_policy=_auto_wrap_policy, # type: ignore FSDP type bug
cpu_offload=cpu_offload,
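
The top-level `process_group` entry is resolved through the same `_set_custom_fsdp_module_kwargs` helper used for per-module wrap kwargs, which caches groups by the ranks they cover (see `process_group_cache` above). Below is a minimal standalone sketch of that caching idea, with `FakeGroup` standing in for a real `torch.distributed` ProcessGroup; it is an illustration, not Composer's API.

from typing import Dict, Tuple

class FakeGroup:
    """Stand-in for a torch.distributed ProcessGroup (illustration only)."""

    def __init__(self, ranks: Tuple[int, ...]):
        self.ranks = ranks

def get_group(ranks: Tuple[int, ...], cache: Dict[Tuple[int, ...], FakeGroup]) -> FakeGroup:
    """Return a cached group for these ranks, creating one on first request."""
    if ranks in cache:
        return cache[ranks]  # cache hit: reuse the group already built for these ranks
    group = FakeGroup(ranks)  # the real helper creates a torch.distributed group here
    cache[ranks] = group
    return group

cache: Dict[Tuple[int, ...], FakeGroup] = {}
assert get_group((0, 1), cache) is get_group((0, 1), cache)  # same ranks -> one shared group
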
13 changes: 12 additions & 1 deletion composer/trainer/mosaic_fsdp.py
@@ -42,10 +42,21 @@ def patch_pytorch():

    elif version.parse(torch.__version__) < version.parse('2.1.1'):
        # Monkey patch for torch < 2.1.1, i.e. torch == 2.1.0
-        from torch.distributed.fsdp import _state_dict_utils

        # Monkey patch sharding method
        ChunkShardingSpec.build_metadata = build_metadata

        # Monkey patch partial state dict handling
+        from torch.distributed.fsdp import _state_dict_utils
        _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook)
+
+        # Allow 2D HSDP
+        from torch.distributed.fsdp import _runtime_utils
+        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
+
+    elif version.parse(torch.__version__) < version.parse('2.2.0'):
+        # Monkey patch for torch < 2.2.0, i.e. torch == 2.1.1 or 2.1.2
+
+        # Allow 2D HSDP
+        from torch.distributed.fsdp import _runtime_utils
+        _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
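
The `# Allow 2D HSDP` lines replace a private validator with a no-op so the hybrid-shard path is not blocked. A standalone sketch of that monkey-patching pattern follows; `fake_runtime_utils` stands in for `torch.distributed.fsdp._runtime_utils`, and `_strict_validator` is an assumed placeholder for the original check, used only for illustration.

import types

# Stand-in module namespace (not the real torch.distributed.fsdp._runtime_utils).
fake_runtime_utils = types.SimpleNamespace()

def _strict_validator(*args, **kwargs):
    # Assumed placeholder for the original _validate_and_get_hybrid_shard_state behavior.
    raise ValueError('hybrid shard state rejected')

fake_runtime_utils._validate_and_get_hybrid_shard_state = _strict_validator

# The patch swaps in a no-op lambda, as in the hunk above, so callers proceed instead of erroring.
fake_runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
assert fake_runtime_utils._validate_and_get_hybrid_shard_state('any', k=1) is None
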
28 changes: 12 additions & 16 deletions composer/trainer/mosaic_fsdp_utils.py
@@ -34,6 +34,8 @@
torch.__version__) < version.parse('2.0.2'):
from torch.distributed.fsdp._common_utils import _FSDPState

+log = logging.getLogger(__name__)
+
SHARDING_MAP = {
'NO_SHARD': ShardingStrategy.NO_SHARD,
'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP,
@@ -124,10 +126,7 @@ def get_cpu_offload(cpu_offload=False):

def _get_process_group(pg, process_group_cache=None):
"""Helper function for configuring and/or retrieving process groups."""
-warnings.warn(f'Instantiating FSDP with custom process groups is an experimental feature.')
-
-# Return regular process_groups as is, no cacheing
-if pg is None or isinstance(pg, ProcessGroup):
+if pg is None or isinstance(pg, ProcessGroup):  # Return as is, no caching
return pg

world_size = dist.get_world_size()
@@ -136,13 +135,13 @@ def _get_process_group(pg, process_group_cache=None):
# Handle special str process_group cases
if pg == 'self':
pg = 'set1'
warnings.warn(f"Converting process_group='self' to process_group='{pg}'")
log.info(f"Converting process_group='self' to process_group='{pg}'")
elif pg == 'node':
pg = f'set{local_world_size}'
warnings.warn(f"Converting process_group='node' to process_group='{pg}'")
log.info(f"Converting process_group='node' to process_group='{pg}'")
elif pg == 'local_rank_across_nodes':
pg = f'mod{local_world_size}'
warnings.warn(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'")
log.info(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'")

# Handle str and Union[List[int], Tuple[int]] process_group cases
if isinstance(pg, str) and pg.startswith('set'):
@@ -164,15 +163,10 @@
raise ValueError(f'Unsure how to setup process_group={pg}')

if process_group_cache is not None and ranks in process_group_cache:
-warnings.warn(
-    f'On rank={dist.get_global_rank()} using cached progress group with {ranks=}. ' +
-    'If the intention was to use a new process group, a new process group can be instantiated and passed' +
-    " in as an arguement (`'process_group': newly_instantiated_process_group_obect,`)")
+log.info(f'Using cached process group with {ranks=} on rank={dist.get_global_rank()}.')
return process_group_cache[ranks]

-warnings.warn(
-    f'Composer is instantiating custom process groups with {ranks=} (on rank={dist.get_global_rank()}). ' +
-    'This is an experimental feature.')
+log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.')

ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks)))
(
@@ -200,8 +194,10 @@ def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dic
f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`"
)
if 'process_group' in module_kwargs:
-# Call on every process group if it is a tuple
-if isinstance(module_kwargs['process_group'], tuple):
+# Call on every process group if it is a tuple/list of non-ints
+if type(module_kwargs['process_group']) in [
+        list, tuple
+] and not all(isinstance(x, int) for x in module_kwargs['process_group']):
module_kwargs['process_group'] = tuple(
_get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group'])
else:
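
For reference, here is how the `'setK'` / `'modK'` strings appear to map ranks to groups, based on my reading of `_get_process_group` (a standalone sketch, not a copy of Composer's code): `'setK'` places each rank in a contiguous block of K ranks, while `'modK'` groups ranks that share the same `rank % K`, e.g. the same local rank across nodes. This is also why `'mod1'` in the new test below spans all ranks.

from typing import Tuple

def ranks_for(pg: str, global_rank: int, world_size: int) -> Tuple[int, ...]:
    """Sketch of the rank grouping implied by 'setK' / 'modK' strings (illustration only)."""
    if pg.startswith('set'):
        k = int(pg[len('set'):])
        start = (global_rank // k) * k
        return tuple(range(start, start + k))  # contiguous block of k ranks
    if pg.startswith('mod'):
        k = int(pg[len('mod'):])
        return tuple(range(global_rank % k, world_size, k))  # every k-th rank
    raise ValueError(f'Unsure how to setup process_group={pg}')

# With 8 ranks across 2 nodes of 4 GPUs each:
assert ranks_for('set4', global_rank=5, world_size=8) == (4, 5, 6, 7)     # 'node'-style group
assert ranks_for('mod4', global_rank=5, world_size=8) == (1, 5)           # same local rank across nodes
assert ranks_for('mod1', global_rank=5, world_size=8) == tuple(range(8))  # one group with every rank
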
1 change: 1 addition & 0 deletions docs/source/notes/distributed_training.rst
@@ -201,6 +201,7 @@ The full spec and defaults for Composer's `fsdp_config` is here:
# 'reduce_dtype': 'fp32' | 'fp16' | 'bf16',
# 'buffer_dtype': 'fp32' | 'fp16' | 'bf16',
# },
+'process_group': str = 'self' | 'node' | 'local_rank_across_nodes' | 'setK' | 'modK', # Default: None
'save_planner': torch.distributed.checkpoint.planner.SavePlanner, # Default: None
'sharded_ckpt_prefix_dir': str = 'ep{epoch}-ba{batch}', # Default: 'ep{epoch}-ba{batch}'
'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD'
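
As a usage sketch (not part of the docs change): a config that keeps the defaults listed above but shards only within each node. The dict would be passed to the Trainer as `fsdp_config`, in the same way as the new test below; the exact values here are illustrative.

# Hypothetical example: 'node' resolves to a process group of the ranks on the local node.
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'mixed_precision': 'DEFAULT',
    'process_group': 'node',  # new key added in this PR; default is None
}
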
27 changes: 27 additions & 0 deletions tests/trainer/test_fsdp.py
@@ -197,6 +197,33 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi
trainer.fit()


@pytest.mark.gpu
@world_size(2)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
reason='FSDP requires PyTorch 1.13 or higher')
@pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning')
@pytest.mark.filterwarnings('ignore:Composer is instantiating custom process groups.*:UserWarning')
def test_fsdp_process_group(world_size: int):
model = SimpleModel()
model.fc1._fsdp_wrap = True
model.fc2._fsdp_wrap = True
dataset = RandomClassificationDataset(size=10)
dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

trainer = Trainer(
model=model,
optimizers=optimizer,
train_dataloader=dataloader,
fsdp_config={
'process_group': 'mod1', # all ranks
},
max_duration='3ba',
)

trainer.fit()


class SimpleMLP(ComposerModel):

def __init__(self, num_features: int = 128, device: str = 'cuda'):
