From c66bc4269e9484b6e57d6f5521df02c70d399246 Mon Sep 17 00:00:00 2001
From: inkcherry
Date: Tue, 23 Apr 2024 07:27:09 +0800
Subject: [PATCH] set the default to use set_to_none for clearing gradients in BF16 optimizer. (#5434)

As discussed in #5175, set the default to use set_to_none when clearing
gradients in the BF16 optimizer. Additionally, when gradients are zeroed
in place instead, batch the zeroing with torch._foreach_zero_.

Verified correctness with mega-ds Llama 7B training.

FYI @loadams

---------

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/bf16_optimizer.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py
index f970e582b354..1f3365b20f4e 100644
--- a/deepspeed/runtime/bf16_optimizer.py
+++ b/deepspeed/runtime/bf16_optimizer.py
@@ -341,7 +341,7 @@ def _update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
 
         # clear gradients
         if clear_lp_grads:
-            lp.grad._zero()
+            lp.grad.zero_()
 
     @torch.no_grad()
     def _update_hp_grads_func(self, clear_lp_grads=False):
@@ -441,11 +441,20 @@ def clear_hp_grads(self):
             self.fp32_groups_has_gradients[i] = [False] * len(group)
 
     def clear_lp_grads(self):
+
+        # keep fixed grad memory addresses for graph replay (zero in place); otherwise set grads to None
+        set_to_none = False if self.graph_harvesting else True
+        zero_grads_list = []
         for group in self.bf16_groups:
             for param in group:
-                if param.grad is not None:
-                    # Using zero_() fixed memory address for graph replay
-                    param.grad.zero_()
+                if set_to_none:
+                    param.grad = None
+                elif param.grad is not None:
+                    if param.grad.grad_fn is not None:
+                        param.grad.detach_()
+                    zero_grads_list.append(param.grad)
+        if not set_to_none and len(zero_grads_list) > 0:
+            torch._foreach_zero_(zero_grads_list)
 
     def state_dict(self):
         state_dict = {}
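
For reference, a minimal standalone sketch (not part of the patch) of the clearing logic it introduces. The helper name clear_grads and the plain parameter list are illustrative stand-ins for the optimizer's bf16_groups; the graph_harvesting flag mirrors the attribute used in the diff, and torch._foreach_zero_ is the same (private) foreach primitive the patch calls.

# sketch: two gradient-clearing modes, assuming a flat list of nn.Parameter
import torch

def clear_grads(params, graph_harvesting=False):
    """Clear gradients either by dropping them (set_to_none, the new default)
    or by batch-zeroing them in place when graph replay needs stable buffers."""
    set_to_none = not graph_harvesting
    zero_grads_list = []
    for p in params:
        if set_to_none:
            p.grad = None                       # default: release grad memory
        elif p.grad is not None:
            if p.grad.grad_fn is not None:
                p.grad.detach_()                # drop any autograd history first
            zero_grads_list.append(p.grad)
    if not set_to_none and zero_grads_list:
        torch._foreach_zero_(zero_grads_list)   # single batched zeroing call

if __name__ == "__main__":
    w = torch.nn.Parameter(torch.randn(4, 4))
    (w.sum() * 2).backward()
    clear_grads([w])                            # grad is now None
    (w.sum() * 2).backward()
    clear_grads([w], graph_harvesting=True)     # grad zeroed in place
    print(w.grad)                               # tensor of zeros, same buffer

The trade-off in the patch: setting grads to None frees the buffers and lets the next backward allocate fresh ones, while in-place zeroing keeps the grad addresses stable (which graph replay requires) and torch._foreach_zero_ fuses the per-tensor zeroing into fewer kernel launches.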