diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 29d408471b..f2c8c615b4 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -231,7 +231,7 @@ def prepare_fsdp_module(
     set_fsdp_default(fsdp_config)
 
-    # Check sync_module_states is True for mixed initialization
+    # Check sync_module_states is True for mixed initialization or HSDP
     if fsdp_config['sync_module_states'] == False:
         rank_on_meta = 1 if next(model.parameters()).device.type == 'meta' else 0
         all_ranks_meta = device.tensor_to_device(torch.tensor([rank_on_meta], dtype=torch.uint8))
@@ -243,6 +243,10 @@ def prepare_fsdp_module(
                 'gpu and some ranks are on meta. Either keep all ranks on the same '
                 "device or set fsdp_config['sync_module_states'] = True. Otherwise, "
                 'some weights may be randomly initialized when loading a checkpoint.')
+        if fsdp_config['sharding_strategy'] in ('HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'):
+            raise ValueError('HSDP (HYBRID_SHARD or _HYBRID_SHARD_ZERO2) requires '
+                             'fsdp_config["sync_module_states"] = True or different replicas will '
+                             'have different weights.')
 
     # Check if other ranks OOMed after forward/backward pass when using auto microbatching. This
     # may happen when close to memory limit or with uneven memory usage across ranks. Since we