diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index 61d088bdda..db6ea1a240 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -72,6 +72,10 @@ def patch_pytorch(): elif version.parse(torch.__version__) < version.parse('2.2.1'): # Monkey patch for torch < 2.2.1 ie torch == 2.2.0 + # Allow 2D HSDP + from torch.distributed.fsdp import _runtime_utils + _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None + # # Better overlap communication and computation # from torch.distributed.fsdp import _runtime_utils