From 2aa50e7741a077ff21f5743934fbcf4b755d441e Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 26 Dec 2023 12:48:06 -0700 Subject: [PATCH] Add process group as arg to FSDP (#2794) * add test * only cast if PG is specified * add to docstring * filter warning * filter warning * docs * support lists * remove warnings * lint * hsdp monkeypatch * logs * change log * fix patch * typo * clean up logs --- composer/trainer/dist_strategy.py | 10 ++++++-- composer/trainer/mosaic_fsdp.py | 13 +++++++++- composer/trainer/mosaic_fsdp_utils.py | 28 ++++++++++------------ docs/source/notes/distributed_training.rst | 1 + tests/trainer/test_fsdp.py | 27 +++++++++++++++++++++ 5 files changed, 60 insertions(+), 19 deletions(-) diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py index 1875af8314..29d408471b 100644 --- a/composer/trainer/dist_strategy.py +++ b/composer/trainer/dist_strategy.py @@ -18,7 +18,8 @@ from composer.devices import Device from composer.trainer.meta_safe_apply import meta_safe_apply from composer.trainer.mosaic_fsdp import patch_pytorch -from composer.trainer.mosaic_fsdp_utils import BACKWARD_PREFETCH_MAP, SHARDING_MAP, get_cpu_offload, get_mixed_precision +from composer.trainer.mosaic_fsdp_utils import (BACKWARD_PREFETCH_MAP, SHARDING_MAP, _set_custom_fsdp_module_kwargs, + get_cpu_offload, get_mixed_precision) from composer.utils import StringEnum, dist, ensure_tuple, using_torch_2 __all__ = ['DDPSyncStrategy', 'ddp_sync_context', 'prepare_ddp_module', 'prepare_fsdp_module'] @@ -143,6 +144,7 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]): fsdp_config.setdefault('load_monolith_rank0_only', False) fsdp_config.setdefault('load_planner', None) fsdp_config.setdefault('mixed_precision', 'DEFAULT') + fsdp_config.setdefault('process_group', None) fsdp_config.setdefault('save_planner', None) fsdp_config.setdefault('sharded_ckpt_prefix_dir', 'ep{epoch}-ba{batch}') fsdp_config.setdefault('sharding_strategy', 'FULL_SHARD') @@ -347,6 +349,10 @@ def sync_hook(*args): f'Consider using `amp` or `bf16` for precision or setting param_dtype in mixed_precision to `None` ' f'with sharding strategy `{sharding_map_key}.`') + process_group = None + if fsdp_config['process_group'] is not None: + process_group_dict = {'process_group': fsdp_config['process_group']} + process_group = _set_custom_fsdp_module_kwargs(process_group_dict, process_group_cache)['process_group'] backward_prefetch = BACKWARD_PREFETCH_MAP[fsdp_config['backward_prefetch'].upper()] activation_checkpointing = fsdp_config['activation_checkpointing'] activation_cpu_offload = fsdp_config['activation_cpu_offload'] @@ -510,7 +516,6 @@ def lambda_fn(module: torch.nn.Module) -> Union[bool, dict]: ret = bool(module._fsdp_wrap) elif hasattr(obj, 'fsdp_wrap_fn') and isinstance(obj.fsdp_wrap_fn, Callable): ret = obj.fsdp_wrap_fn(module) - from composer.trainer.mosaic_fsdp_utils import _set_custom_fsdp_module_kwargs if isinstance(ret, dict): ret = _set_custom_fsdp_module_kwargs(ret, process_group_cache) if ret and auto_microbatching: @@ -553,6 +558,7 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para fsdp_obj = FullyShardedDataParallel( obj, + process_group=process_group, sharding_strategy=sharding_strategy, auto_wrap_policy=_auto_wrap_policy, # type: ignore FSDP type bug cpu_offload=cpu_offload, diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index 4bbd878c44..bf6ebaa228 100644 --- a/composer/trainer/mosaic_fsdp.py +++ 
b/composer/trainer/mosaic_fsdp.py @@ -42,10 +42,21 @@ def patch_pytorch(): elif version.parse(torch.__version__) < version.parse('2.1.1'): # Monkey path for torch < 2.1.1 ie torch == 2.1.0 - from torch.distributed.fsdp import _state_dict_utils # Monkey patch sharding method ChunkShardingSpec.build_metadata = build_metadata # Monkey patch partial state dict handling + from torch.distributed.fsdp import _state_dict_utils _state_dict_utils._sharded_pre_load_state_dict_hook = (_sharded_pre_load_state_dict_hook) + + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None + + elif version.parse(torch.__version__) < version.parse('2.2.0'): + # Monkey path for torch < 2.2.0 ie torch == 2.1.1, 2.1.2 + + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index debc1195ba..da08772a63 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -34,6 +34,8 @@ torch.__version__) < version.parse('2.0.2'): from torch.distributed.fsdp._common_utils import _FSDPState +log = logging.getLogger(__name__) + SHARDING_MAP = { 'NO_SHARD': ShardingStrategy.NO_SHARD, 'SHARD_GRAD_OP': ShardingStrategy.SHARD_GRAD_OP, @@ -124,10 +126,7 @@ def get_cpu_offload(cpu_offload=False): def _get_process_group(pg, process_group_cache=None): """Helper function for configuring and/or retrieving process groups.""" - warnings.warn(f'Instantiating FSDP with custom process groups is an experimental feature.') - - # Return regular process_groups as is, no cacheing - if pg is None or isinstance(pg, ProcessGroup): + if pg is None or isinstance(pg, ProcessGroup): # Return as is, no caching return pg world_size = dist.get_world_size() @@ -136,13 +135,13 @@ def _get_process_group(pg, process_group_cache=None): # Handle special str process_group cases if pg == 'self': pg = 'set1' - warnings.warn(f"Converting process_group='self' to process_group='{pg}'") + log.info(f"Converting process_group='self' to process_group='{pg}'") elif pg == 'node': pg = f'set{local_world_size}' - warnings.warn(f"Converting process_group='node' to process_group='{pg}'") + log.info(f"Converting process_group='node' to process_group='{pg}'") elif pg == 'local_rank_across_nodes': pg = f'mod{local_world_size}' - warnings.warn(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'") + log.info(f"Converting process_group='local_rank_across_nodes' to process_group='{pg}'") # Handle str and Union[List[int], Tuple[int]] process_group cases if isinstance(pg, str) and pg.startswith('set'): @@ -164,15 +163,10 @@ def _get_process_group(pg, process_group_cache=None): raise ValueError(f'Unsure how to setup process_group={pg}') if process_group_cache is not None and ranks in process_group_cache: - warnings.warn( - f'On rank={dist.get_global_rank()} using cached progress group with {ranks=}. ' + - 'If the intention was to use a new process group, a new process group can be instantiated and passed' + - " in as an arguement (`'process_group': newly_instantiated_process_group_obect,`)") + log.info(f'Using cached progress group with {ranks=} on rank={dist.get_global_rank()}.') return process_group_cache[ranks] - warnings.warn( - f'Composer is instantiating custom process groups with {ranks=} (on rank={dist.get_global_rank()}). 
' + - 'This is an experimental feature.') + log.info(f'Instantiating custom process groups with {ranks=} on rank={dist.get_global_rank()}.') ranks_per_subgroup_list = list(set(dist.all_gather_object(ranks))) ( @@ -200,8 +194,10 @@ def _set_custom_fsdp_module_kwargs(module_kwargs: Dict, process_group_cache: Dic f"Automated setting of custom per module mixed_precision is not implemented, but it can be set if `isinstance(module_kwargs['mixed_precision'], MixedPrecision)`" ) if 'process_group' in module_kwargs: - # Call on every process group if it is a tuple - if isinstance(module_kwargs['process_group'], tuple): + # Call on every process group if it is a tuple/list of non-ints + if type(module_kwargs['process_group']) in [ + list, tuple + ] and not all(isinstance(x, int) for x in module_kwargs['process_group']): module_kwargs['process_group'] = tuple( _get_process_group(pg, process_group_cache) for pg in module_kwargs['process_group']) else: diff --git a/docs/source/notes/distributed_training.rst b/docs/source/notes/distributed_training.rst index 18619be6a9..cab087f3b8 100644 --- a/docs/source/notes/distributed_training.rst +++ b/docs/source/notes/distributed_training.rst @@ -201,6 +201,7 @@ The full spec and defaults for Composer's `fsdp_config` is here: # 'reduce_dtype': 'fp32' | 'fp16' | 'bf16', # 'buffer_dtype': 'fp32' | 'fp16' | 'bf16', # }, + 'process_group': str = 'self' | 'node' | 'local_rank_across_nodes' | 'setK' | 'modK', # Default: None 'save_planner': torch.distributed.checkpoint.planner.SavePlanner, # Default: None 'sharded_ckpt_prefix_dir': str = 'ep{epoch}-ba{batch}', # Default: 'ep{epoch}-ba{batch}' 'sharding_strategy': str = 'FULL_SHARD' | 'SHARD_GRAD_OP' | 'NO_SHARD', # Default: 'FULL_SHARD' diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py index 25b294fea9..95aaf31e97 100644 --- a/tests/trainer/test_fsdp.py +++ b/tests/trainer/test_fsdp.py @@ -197,6 +197,33 @@ def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limi trainer.fit() +@pytest.mark.gpu +@world_size(2) +@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'), + reason='FSDP requires PyTorch 1.13 or higher') +@pytest.mark.filterwarnings('ignore:Instantiating FSDP with custom process groups.*:UserWarning') +@pytest.mark.filterwarnings('ignore:Composer is instantiating custom process groups.*:UserWarning') +def test_fsdp_process_group(world_size: int): + model = SimpleModel() + model.fc1._fsdp_wrap = True + model.fc2._fsdp_wrap = True + dataset = RandomClassificationDataset(size=10) + dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + trainer = Trainer( + model=model, + optimizers=optimizer, + train_dataloader=dataloader, + fsdp_config={ + 'process_group': 'mod1', # all ranks + }, + max_duration='3ba', + ) + + trainer.fit() + + class SimpleMLP(ComposerModel): def __init__(self, num_features: int = 128, device: str = 'cuda'):
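
Usage note (illustrative; not part of the commit): the sketch below shows how the new top-level 'process_group' key introduced by this patch can be passed through `fsdp_config` when constructing a Trainer, mirroring the `test_fsdp_process_group` test added above. The accepted string forms are the ones documented in the `distributed_training.rst` change ('self', 'node', 'local_rank_across_nodes', 'setK', 'modK'). The import paths for the test helpers (`SimpleModel`, `RandomClassificationDataset`) are assumed from Composer's test utilities and are not shown in the diff.

    # Sketch only: configure FSDP with the new 'process_group' option.
    # SimpleModel, RandomClassificationDataset, and '3ba' are illustrative
    # placeholders borrowed from the test added in this patch.
    import torch
    from torch.utils.data import DataLoader

    from composer import Trainer
    from composer.utils import dist
    from tests.common import RandomClassificationDataset, SimpleModel  # assumed import path

    model = SimpleModel()
    model.fc1._fsdp_wrap = True  # opt these submodules into FSDP wrapping
    model.fc2._fsdp_wrap = True
    dataset = RandomClassificationDataset(size=10)
    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    trainer = Trainer(
        model=model,
        optimizers=optimizer,
        train_dataloader=dataloader,
        fsdp_config={
            # Accepted values per the docs change in this patch:
            # 'self' | 'node' | 'local_rank_across_nodes' | 'setK' | 'modK'
            'process_group': 'mod1',  # a single group spanning all ranks
        },
        max_duration='3ba',
    )
    trainer.fit()

As before, per-module process groups can still be supplied by having a module's `fsdp_wrap_fn` return a dict containing a 'process_group' entry; the patch routes both the top-level config value and those per-module overrides through `_set_custom_fsdp_module_kwargs`, so they share the same process-group cache.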