From fcc731f09d8e09b04a816b3ea0f83ab1d15169b3 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:01:35 -0700
Subject: [PATCH] Fix torch.compile error for PyTorch v2.3 (#5463)

PyTorch v2.3 throws an error when it tries to compile `iter_params` used
for ZeRO3. This PR excludes the function from the compilation targets.

After this PR is merged, we can [unpin the torch version for unit
tests](https://github.com/microsoft/DeepSpeed/pull/5459).
---
 deepspeed/runtime/zero/partitioned_param_coordinator.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py
index 8fc962c4f2a7..bdec8a55fcbc 100644
--- a/deepspeed/runtime/zero/partitioned_param_coordinator.py
+++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -34,6 +34,7 @@ def get_all_parameters(sub_module, recurse=False):
     return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters())
 
 
+@compiler.disable
 def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
     return map(lambda pair: pair[1], get_all_parameters(module, recurse))
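
For context, here is a minimal standalone sketch of the mechanism the patch relies on. The `compiler.disable` decorator added above comes from DeepSpeed's runtime compiler module, which presumably wraps the public `torch.compiler.disable` API on PyTorch 2.x; decorating a function this way makes `torch.compile` execute it eagerly instead of tracing it. The names `iter_values` and `total` below are illustrative only, not part of DeepSpeed.

```python
# Minimal sketch, assuming PyTorch >= 2.1 (where torch.compiler.disable
# is available). It illustrates the technique used in the patch: a
# function marked with the disable decorator is skipped by torch.compile
# and runs eagerly, avoiding the tracing error seen with iter_params on
# PyTorch v2.3. iter_values/total are hypothetical example names.
import torch


@torch.compiler.disable
def iter_values(pairs):
    # Like iter_params, returns a lazy map over (name, value) pairs;
    # torch.compile will not attempt to trace this body.
    return map(lambda pair: pair[1], pairs)


@torch.compile
def total(pairs):
    # Calling the disabled function forces a graph break: torch.compile
    # traces around the call and falls back to eager execution for it.
    return sum(iter_values(pairs))


print(total([("a", torch.tensor(1.0)), ("b", torch.tensor(2.0))]))  # tensor(3.)
```

Graph breaks like this trade some compiled-graph coverage for correctness, which is why excluding just the one problematic helper is a lighter fix than pinning the torch version.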