From 7ed952eff1c71f7652fe4786a16c8d917f9c37d6 Mon Sep 17 00:00:00 2001 From: Kazuki Fujii <68278821+okoge-kaz@users.noreply.github.com> Date: Sun, 8 Oct 2023 04:07:05 +0900 Subject: [PATCH] Fix bug in bfloat16 optimizer related to checkpointing (#4434) * fix: bf16 optimizer if condition * fix: unexpected keyword argument 'load_serial' * fix: add load_serial arg to bf16_optimizer * style: fix indentation --------- Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/bf16_optimizer.py | 3 ++- deepspeed/runtime/engine.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 9c7a84d4841eb..550af8fac0575 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -365,7 +365,8 @@ def load_state_dict(self, state_dict_list, checkpoint_folder, load_optimizer_states=True, - load_from_fp32_weights=False): + load_from_fp32_weights=False, + load_serial=None): if checkpoint_folder: self._load_universal_checkpoint(checkpoint_folder, load_optimizer_states, load_from_fp32_weights) else: diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index b4c8ef56c7018..8a8193ddd8f5d 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -3287,7 +3287,7 @@ def _get_zero_param_shapes(self): # if we don't use it, we get parameters ordered incorrectly if hasattr(self.optimizer, "round_robin_bit16_groups"): bit16_groups = self.optimizer.round_robin_bit16_groups - elif self.bfloat16_enabled() and not self.zero_optimization(): + elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): bit16_groups = self.optimizer.bf16_groups else: bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage(