From db98cc3ad1e0a20807e0c2513f0eee40f626860e Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:05:55 -0800 Subject: [PATCH] Fix assertion for offloading states (#6855) This PR fixes the assertions in `offload_states` method mentioned in #6833. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/engine.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 0aad018528d3..5f023d87f375 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3738,6 +3738,11 @@ def offload_states(self, assert self.zero_optimization_stage( ) == ZeroStageEnum.weights, "Moving buffers across devices is supported only for ZeRO stage 3." + opt_offload_config = self.zero_offload_optimizer() + assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states." + param_offload_config = self.zero_offload_param() + assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters." + assert not self.zero_offload_param(), "Moving states across devices is not supported for offloaded parameters." if device == OffloadDeviceEnum.none: