From 2ab9423d48166f5dbb163ab6aa11258533b4d47e Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 17 Apr 2024 02:21:26 +0800 Subject: [PATCH] Fix a convergence issues in TP topology caused by incorrect grad_norm. (#5411) Some users are concerned that changes in TP topology during MOE training may potentially cause interference with experiments when noticing similar issues https://github.com/microsoft/Megatron-DeepSpeed/issues/151 https://github.com/microsoft/Megatron-DeepSpeed/pull/176/files We found a grad_norm calculation error after enabling TP. This error occurs because flattened grad of a params group is used, where the group contains both non-TP and TP parameters. Therefore, it is not possible to use a single attribute to determine whether flattened grad needs to compute the norm. In the current code logic, all params are assumed to be non-TP, resulting in only tp_rank0 grad participating in grad_norm computation. Other tp_rank grads have grad_norm_sum equal to 0. We tested and found that with TP=1 and TP=4, the difference in grad_norm is approximately twice (sqrt(4)). This aligns with the aforementioned issue. This problem should also affect dense models. Due to the absence of flattening params_group grad in bf16, this problem is avoided. We tested the loss curve on the 1.3B model. In cases where TP size increases the inconsistent gap should be larger. with this change 1.3B with EP=4 TP=4 &1 , fp16,mbs=1,gbs=16 ![image](https://github.com/microsoft/DeepSpeed/assets/27563729/855042c8-ac8a-4192-b465-5fa60c1a7c59) without this change 1.3B with EP=4 TP=4&1 ,fp16,mbs=1,gbs=16 ![image](https://github.com/microsoft/DeepSpeed/assets/27563729/66854d14-7b83-4b09-a669-b452d6157ea0) --------- Co-authored-by: Conglong Li --- deepspeed/runtime/fp16/fused_optimizer.py | 65 +++++++++++++++++++---- deepspeed/runtime/utils.py | 54 ++++++++----------- 2 files changed, 77 insertions(+), 42 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index af8050c4646a..bf1693307ea7 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -9,15 +9,16 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer -from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers +from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger, log_dist from deepspeed.utils.torch import required_torch_version from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD from deepspeed.accelerator import get_accelerator from deepspeed.moe.utils import is_moe_param_group +from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank OVERFLOW_CHECK_TIMER = 'overflow_check' COMPUTE_NORM_TIMER = 'compute_norm' @@ -64,6 +65,8 @@ def __init__(self, self.fp16_groups_flat = [] self.fp32_groups_flat = [] + self.flatten_grad_norm_mask_list = [] + self.has_executed_step = False self._global_grad_norm = 0. # loop to deal with groups @@ -206,6 +209,40 @@ def override_loss_scale(self, loss_scale): self.custom_loss_scaler = True self.external_loss_scale = loss_scale + def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): + # for filtering replicated tensors from tensor + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + return True + if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): + return True + + def _get_norm_mask_idx(self, group): + """The function preserves the parallel information for norm + from unflattened gradients. + + Args: + group (Iterable[Tensor] ): params group + + Returns: + torch.Tensor: A 2D tensor containing index ranges for each group, + where each row represents a [start index, end index]. + """ + group_mask_idx_list = [] + grad_flat_st_idx = 0 + grad_flat_en_idx = 0 + + for p in group: + grad_flat_en_idx = grad_flat_st_idx + p.numel() + if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): + # merge range + if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]: + group_mask_idx_list[-1][-1] = grad_flat_en_idx + else: + group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) + grad_flat_st_idx = grad_flat_en_idx + + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + def step(self, closure=None): """ Not supporting closure. @@ -251,23 +288,32 @@ def step(self, closure=None): for p in group ])) - for p in group: - p.grad = None - self.fp32_groups_flat[i].grad = grads_groups_flat[i] param_group = self.optimizer.param_groups[i] + + # split expert and non_expert grads for norm if self.has_moe_layers and is_moe_param_group(param_group): if param_group['name'] not in expert_grads_for_norm: expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i]) else: + # retrieves the required mask for calculating the norm of flat_grad + # perform this collect operation only once + if not self.has_executed_step: + cur_flat_grad_norm_mask = self._get_norm_mask_idx(group) + self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) + non_experts_grads_for_norm.append(self.fp32_groups_flat[i]) - self.timers(COMPUTE_NORM_TIMER).start() + for p in group: + p.grad = None - all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu) + self.timers(COMPUTE_NORM_TIMER).start() - self.timers(COMPUTE_NORM_TIMER).stop() + all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm, + mpu=self.mpu, + grad_norm_mask=self.flatten_grad_norm_mask_list) if self.has_moe_layers: all_groups_norm = get_norm_with_moe_layers(all_groups_norm, @@ -276,6 +322,7 @@ def step(self, closure=None): norm_type=self.norm_type) scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + self.timers(COMPUTE_NORM_TIMER).stop() # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.cur_scale @@ -298,7 +345,7 @@ def step(self, closure=None): updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i]) for p, q in zip(self.fp16_groups[i], updated_params): p.data.copy_(q.data) - + self.has_executed_step = True self.timers(UPDATE_FP16_TIMER).stop() self.timers.log(STEP_TIMERS) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index c55f8a0e2995..7744b2ee8b98 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -17,7 +17,6 @@ import torch from deepspeed import comm as dist - try: from torch._six import inf except ModuleNotFoundError: @@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): return total_norm -def get_grad_norm(parameters, norm_type=2, mpu=None): +def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None): """Get grad norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): single Tensor that will have gradients normalized norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. - + grad_norm_mask (List[Tensor]): A list of Tensor, where + each Tensor is a 2D Tensor containing ranges of [start_index, end_index]. Returns: Total norm of the parameters (viewed as a single vector). """ @@ -415,18 +415,25 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. - tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) - for p in parameters: - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: - continue - - # Filter to avoid over-counting replicated tensors from tensor - # model parallelism - if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): - continue + for idx, p in enumerate(parameters): + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm + if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: + + # A loop-free implementation to create a mask tensor based on a range list + # which is logically equivalent to the following implementation. + # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) + # # for mask_idx in grad_norm_mask[idx]: + # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) + mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), + cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] + + param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) - param_norm = p.grad.data.float().norm(norm_type) + else: + param_norm = p.grad.data.float().norm(norm_type) total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. @@ -814,25 +821,6 @@ def get_only_unique_item(items): return unique_item -def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6): - """Clip the gradient of a list of parameters. - Args: - parameters: List of parameters whose .grad will be clipped. - global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None. - mpu (optional): model parallelism unit. Defaults to None. - eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 - Returns: - float: the global gradient norm - """ - if global_grad_norm is None: - global_grad_norm = get_grad_norm(parameters, mpu=mpu) - clip_coef = max_norm / (global_grad_norm + eps) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - return global_grad_norm - - def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors.