From 8949105369ccaacecc7b707601b787e4af730e89 Mon Sep 17 00:00:00 2001 From: Raza Sikander <54884406+raza-sikander@users.noreply.github.com> Date: Tue, 16 Apr 2024 01:28:56 +0530 Subject: [PATCH 1/6] Remove dtype(fp16) condition check for residual_add unit test (#5329) When the dtype is bf16 or fp32 the if condition is not satisfied and it continues execution instead of skipping when triton is not installed. Co-authored-by: Shaik Raza Sikander Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase --- tests/unit/ops/transformer/inference/test_residual_add.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index c2952f74ff2d..91830e25fc81 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -77,7 +77,7 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f @pytest.mark.parametrize("use_triton_ops", [True, False]) def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, pre_attn_norm, use_triton_ops): - if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16: + if not deepspeed.HAS_TRITON and use_triton_ops: pytest.skip("triton has to be installed for the test") ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) From 258e500f3fcd33fa42871ff87612d135399017ac Mon Sep 17 00:00:00 2001 From: YiSheng5 Date: Wed, 17 Apr 2024 01:40:45 +0800 Subject: [PATCH 2/6] [XPU] Use non_daemonic_proc by default on XPU device (#5412) Set non_daemonic_proc=True by default on XPU Device, using non_daemonic_proc for unit test. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- tests/unit/common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/common.py b/tests/unit/common.py index 1fd83de81f02..a2593e703aef 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -248,6 +248,10 @@ def _launch_procs(self, num_procs): f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available" ) + if get_accelerator().device_name() == 'xpu': + self.non_daemonic_procs = True + self.reuse_dist_env = False + # Set start method to `forkserver` (or `fork`) mp.set_start_method('forkserver', force=True) From 0896503e2f4d3b12583dfe267e52db3a1d63b88d Mon Sep 17 00:00:00 2001 From: inkcherry Date: Wed, 17 Apr 2024 02:21:26 +0800 Subject: [PATCH 3/6] Fix a convergence issues in TP topology caused by incorrect grad_norm. (#5411) Some users are concerned that changes in TP topology during MOE training may potentially cause interference with experiments when noticing similar issues https://github.com/microsoft/Megatron-DeepSpeed/issues/151 https://github.com/microsoft/Megatron-DeepSpeed/pull/176/files We found a grad_norm calculation error after enabling TP. This error occurs because flattened grad of a params group is used, where the group contains both non-TP and TP parameters. Therefore, it is not possible to use a single attribute to determine whether flattened grad needs to compute the norm. In the current code logic, all params are assumed to be non-TP, resulting in only tp_rank0 grad participating in grad_norm computation. Other tp_rank grads have grad_norm_sum equal to 0. We tested and found that with TP=1 and TP=4, the difference in grad_norm is approximately twice (sqrt(4)). This aligns with the aforementioned issue. This problem should also affect dense models. Due to the absence of flattening params_group grad in bf16, this problem is avoided. We tested the loss curve on the 1.3B model. In cases where TP size increases the inconsistent gap should be larger. with this change 1.3B with EP=4 TP=4 &1 , fp16,mbs=1,gbs=16 ![image](https://github.com/microsoft/DeepSpeed/assets/27563729/855042c8-ac8a-4192-b465-5fa60c1a7c59) without this change 1.3B with EP=4 TP=4&1 ,fp16,mbs=1,gbs=16 ![image](https://github.com/microsoft/DeepSpeed/assets/27563729/66854d14-7b83-4b09-a669-b452d6157ea0) --------- Co-authored-by: Conglong Li --- deepspeed/runtime/fp16/fused_optimizer.py | 65 +++++++++++++++++++---- deepspeed/runtime/utils.py | 54 ++++++++----------- 2 files changed, 77 insertions(+), 42 deletions(-) diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index af8050c4646a..bf1693307ea7 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -9,15 +9,16 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer -from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers +from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE from deepspeed.utils import logger, log_dist from deepspeed.utils.torch import required_torch_version from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD from deepspeed.accelerator import get_accelerator from deepspeed.moe.utils import is_moe_param_group +from deepspeed.runtime.constants import PIPE_REPLICATED +from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank OVERFLOW_CHECK_TIMER = 'overflow_check' COMPUTE_NORM_TIMER = 'compute_norm' @@ -64,6 +65,8 @@ def __init__(self, self.fp16_groups_flat = [] self.fp32_groups_flat = [] + self.flatten_grad_norm_mask_list = [] + self.has_executed_step = False self._global_grad_norm = 0. # loop to deal with groups @@ -206,6 +209,40 @@ def override_loss_scale(self, loss_scale): self.custom_loss_scaler = True self.external_loss_scale = loss_scale + def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank): + # for filtering replicated tensors from tensor + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + return True + if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p): + return True + + def _get_norm_mask_idx(self, group): + """The function preserves the parallel information for norm + from unflattened gradients. + + Args: + group (Iterable[Tensor] ): params group + + Returns: + torch.Tensor: A 2D tensor containing index ranges for each group, + where each row represents a [start index, end index]. + """ + group_mask_idx_list = [] + grad_flat_st_idx = 0 + grad_flat_en_idx = 0 + + for p in group: + grad_flat_en_idx = grad_flat_st_idx + p.numel() + if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)): + # merge range + if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]: + group_mask_idx_list[-1][-1] = grad_flat_en_idx + else: + group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) + grad_flat_st_idx = grad_flat_en_idx + + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + def step(self, closure=None): """ Not supporting closure. @@ -251,23 +288,32 @@ def step(self, closure=None): for p in group ])) - for p in group: - p.grad = None - self.fp32_groups_flat[i].grad = grads_groups_flat[i] param_group = self.optimizer.param_groups[i] + + # split expert and non_expert grads for norm if self.has_moe_layers and is_moe_param_group(param_group): if param_group['name'] not in expert_grads_for_norm: expert_grads_for_norm[param_group['name']] = [] + expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i]) else: + # retrieves the required mask for calculating the norm of flat_grad + # perform this collect operation only once + if not self.has_executed_step: + cur_flat_grad_norm_mask = self._get_norm_mask_idx(group) + self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask) + non_experts_grads_for_norm.append(self.fp32_groups_flat[i]) - self.timers(COMPUTE_NORM_TIMER).start() + for p in group: + p.grad = None - all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu) + self.timers(COMPUTE_NORM_TIMER).start() - self.timers(COMPUTE_NORM_TIMER).stop() + all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm, + mpu=self.mpu, + grad_norm_mask=self.flatten_grad_norm_mask_list) if self.has_moe_layers: all_groups_norm = get_norm_with_moe_layers(all_groups_norm, @@ -276,6 +322,7 @@ def step(self, closure=None): norm_type=self.norm_type) scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm]) + self.timers(COMPUTE_NORM_TIMER).stop() # Stash unscaled gradient norm self._global_grad_norm = scaled_global_grad_norm / self.cur_scale @@ -298,7 +345,7 @@ def step(self, closure=None): updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i]) for p, q in zip(self.fp16_groups[i], updated_params): p.data.copy_(q.data) - + self.has_executed_step = True self.timers(UPDATE_FP16_TIMER).stop() self.timers.log(STEP_TIMERS) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index c55f8a0e2995..7744b2ee8b98 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -17,7 +17,6 @@ import torch from deepspeed import comm as dist - try: from torch._six import inf except ModuleNotFoundError: @@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None): return total_norm -def get_grad_norm(parameters, norm_type=2, mpu=None): +def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None): """Get grad norm of an iterable of parameters. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and @@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): single Tensor that will have gradients normalized norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. - + grad_norm_mask (List[Tensor]): A list of Tensor, where + each Tensor is a 2D Tensor containing ranges of [start_index, end_index]. Returns: Total norm of the parameters (viewed as a single vector). """ @@ -415,18 +415,25 @@ def get_grad_norm(parameters, norm_type=2, mpu=None): total_norm = total_norm_cuda[0].item() else: total_norm = 0. - tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu) - for p in parameters: - # Pipeline parallelism may replicate parameters. Avoid multi-counting. - if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: - continue - - # Filter to avoid over-counting replicated tensors from tensor - # model parallelism - if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p): - continue + for idx, p in enumerate(parameters): + # Use grad_norm_mask to avoid redundant computation of flattened gradient norm + if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0: + + # A loop-free implementation to create a mask tensor based on a range list + # which is logically equivalent to the following implementation. + # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) + # # for mask_idx in grad_norm_mask[idx]: + # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) + mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), + cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] + + param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type) - param_norm = p.grad.data.float().norm(norm_type) + else: + param_norm = p.grad.data.float().norm(norm_type) total_norm += param_norm.item()**norm_type # Sum across all model parallel GPUs. @@ -814,25 +821,6 @@ def get_only_unique_item(items): return unique_item -def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6): - """Clip the gradient of a list of parameters. - Args: - parameters: List of parameters whose .grad will be clipped. - global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None. - mpu (optional): model parallelism unit. Defaults to None. - eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6 - Returns: - float: the global gradient norm - """ - if global_grad_norm is None: - global_grad_norm = get_grad_norm(parameters, mpu=mpu) - clip_coef = max_norm / (global_grad_norm + eps) - if clip_coef < 1: - for p in parameters: - p.grad.detach().mul_(clip_coef) - return global_grad_norm - - def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None): """Get norm of an iterable of tensors. From 34969d2091fcfd37545361111d7ce7d13229d1a3 Mon Sep 17 00:00:00 2001 From: Logan Adams <114770087+loadams@users.noreply.github.com> Date: Tue, 16 Apr 2024 12:27:45 -0700 Subject: [PATCH 4/6] Update 'create-pr' action in release workflow to latest (#5415) A [warning is shown](https://github.com/microsoft/DeepSpeed/actions/runs/8695213322/job/23845782048#step:10:31) when we do releases: ``` [deploy](https://github.com/microsoft/DeepSpeed/actions/runs/8695213322/job/23845782048) Node.js 16 actions are deprecated. Please update the following actions to use Node.js 20: peter-evans/create-pull-request@v4. For more information see: https://github.blog/changelog/2023-09-22-github-actions-transitioning-from-node-16-to-node-20/. ``` To resolve this we update the create a pull request to `@v6`, see release notes [here](https://github.com/peter-evans/create-pull-request/releases) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5a931125eff6..2f571a14b228 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,7 +35,7 @@ jobs: run: | python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }} - name: Create Pull Request - uses: peter-evans/create-pull-request@v4 + uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.GH_PAT }} add-paths: | From bc0f77472828323fdc0ae67f62123948bc2b12d1 Mon Sep 17 00:00:00 2001 From: "Etienne.bfx" Date: Tue, 16 Apr 2024 21:43:44 +0200 Subject: [PATCH 5/6] Update engine.py to avoid torch warning (#5408) The state_dict function of module.py from torch write a warning if arguments are positional arguments and not keyword arguments --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: ebonnafoux --- deepspeed/runtime/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 992d7877c179..9a2b943b0992 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -2542,7 +2542,7 @@ def all_gather_scalar(self, value, dp_group): return tensor_list def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False): - sd = self.module.state_dict(destination, prefix, keep_vars) + sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) # Remove frozen parameter weights from state_dict if specified if exclude_frozen_parameters: From a9cbd688f01c7742397f53144b2223f9956540ac Mon Sep 17 00:00:00 2001 From: Shafiq Jetha <1066864+fasterinnerlooper@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:49:35 -0600 Subject: [PATCH 6/6] Update _sidebar.scss (#5293) The right sidebar disappears off of the right side of the page. These changes will help bring the content back and place it correctly on the page. Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- docs/_sass/minimal-mistakes/_sidebar.scss | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/_sass/minimal-mistakes/_sidebar.scss b/docs/_sass/minimal-mistakes/_sidebar.scss index 63cef338c583..312a6279b9b0 100644 --- a/docs/_sass/minimal-mistakes/_sidebar.scss +++ b/docs/_sass/minimal-mistakes/_sidebar.scss @@ -76,10 +76,9 @@ @include breakpoint($large) { position: absolute; - top: 0; + top: auto; right: 0; width: $right-sidebar-width-narrow; - margin-right: -1.5 * $right-sidebar-width-narrow; padding-left: 1em; z-index: 10; @@ -94,7 +93,6 @@ @include breakpoint($x-large) { width: $right-sidebar-width; - margin-right: -1.5 * $right-sidebar-width; } }