diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py
index 0987656947..ad0fd0904c 100644
--- a/composer/trainer/mosaic_fsdp.py
+++ b/composer/trainer/mosaic_fsdp.py
@@ -9,6 +9,7 @@
 import torch
 from packaging import version
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed.fsdp import FullyShardedDataParallel
 
 from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
                                                 custom_auto_wrap_t1p13p1)
@@ -61,8 +62,6 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None
 
         # Better overlap communication and computation
-        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
-
         from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
                                                         _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
@@ -75,7 +74,6 @@ def patch_pytorch():
 
         # Better overlap communication and computation
         from torch.distributed.fsdp import _runtime_utils
-        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
 
         from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
                                                         _wait_for_computation_stream, forward)
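
Note: below is a minimal sketch of the version-gated monkey-patching pattern this diff streamlines (hoist the FullyShardedDataParallel import to module scope, keep the runtime patching inside the version branches). The names _custom_share_state and patch_fsdp_runtime are hypothetical stand-ins for the real replacement hooks (_share_state_and_init_handle_attrs_t2p1 / _t2p2) in composer/trainer/mosaic_fsdp_utils.py, and the version bound is illustrative rather than the exact branch condition used by patch_pytorch().

# Sketch only -- not part of the diff above.
import torch
from packaging import version

# Imported once at module scope instead of inside each version branch, as in the diff.
from torch.distributed.fsdp import FullyShardedDataParallel


def _custom_share_state(*args, **kwargs):
    """Hypothetical replacement for FSDP's _share_state_and_init_handle_attrs hook."""
    raise NotImplementedError('stand-in for a Composer-provided hook')


def patch_fsdp_runtime():
    """Swap an internal FSDP runtime function for a custom one, gated on the torch version."""
    if version.parse(torch.__version__) >= version.parse('2.1.0'):  # illustrative bound
        from torch.distributed.fsdp import _runtime_utils

        # Monkey-patch: subsequent FSDP calls resolve to the custom hook.
        _runtime_utils._share_state_and_init_handle_attrs = _custom_share_state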