
Commit

Merge branch 'master' into master
loadams authored Apr 16, 2024
2 parents 1acac5a + a9cbd68 commit 2113f27
Showing 7 changed files with 85 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
@@ -35,7 +35,7 @@ jobs:
run: |
python release/bump_patch_version.py --current_version ${{ env.RELEASE_VERSION }}
- name: Create Pull Request
uses: peter-evans/create-pull-request@v4
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GH_PAT }}
add-paths: |
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
@@ -2542,7 +2542,7 @@ def all_gather_scalar(self, value, dp_group):
return tensor_list

def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False):
sd = self.module.state_dict(destination, prefix, keep_vars)
sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

# Remove frozen parameter weights from state_dict if specified
if exclude_frozen_parameters:
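
For context, a minimal standalone sketch (plain PyTorch, not DeepSpeed code) of why the keyword form above is preferable: recent PyTorch releases warn when destination/prefix/keep_vars are passed to Module.state_dict() positionally, so the keyword call keeps module_state_dict future-proof.

import torch

model = torch.nn.Linear(4, 2)

# Keyword arguments mirror the updated call in module_state_dict above.
sd = model.state_dict(prefix="module.", keep_vars=False)
print(sorted(sd.keys()))  # ['module.bias', 'module.weight']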
65 changes: 56 additions & 9 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,15 +9,16 @@

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.base_optimizer import DeepSpeedOptimizer
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers
from deepspeed.runtime.utils import get_global_norm, get_flattened_grad_norm, CheckOverflow, get_weight_norm, get_norm_with_moe_layers, is_model_parallel_parameter
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger, log_dist
from deepspeed.utils.torch import required_torch_version
from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, CLIP_GRAD
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import is_moe_param_group
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank

OVERFLOW_CHECK_TIMER = 'overflow_check'
COMPUTE_NORM_TIMER = 'compute_norm'
@@ -64,6 +65,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self.flatten_grad_norm_mask_list = []
self.has_executed_step = False
self._global_grad_norm = 0.

# loop to deal with groups
@@ -206,6 +209,40 @@ def override_loss_scale(self, loss_scale):
self.custom_loss_scaler = True
self.external_loss_scale = loss_scale

def _require_avoid_recompute_norm(self, p, tensor_model_parallel_rank):
# filter out tensors replicated by pipeline or tensor model parallelism
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
return True
if (tensor_model_parallel_rank > 0) and not is_model_parallel_parameter(p):
return True

def _get_norm_mask_idx(self, group):
"""The function preserves the parallel information for norm
from unflattened gradients.
Args:
group (Iterable[Tensor] ): params group
Returns:
torch.Tensor: A 2D tensor containing index ranges for each group,
where each row represents a [start index, end index].
"""
group_mask_idx_list = []
grad_flat_st_idx = 0
grad_flat_en_idx = 0

for p in group:
grad_flat_en_idx = grad_flat_st_idx + p.numel()
if p.grad is not None and self._require_avoid_recompute_norm(p, bwc_tensor_model_parallel_rank(self.mpu)):
# merge with the previous range if contiguous
if len(group_mask_idx_list) > 0 and grad_flat_st_idx == group_mask_idx_list[-1][-1]:
group_mask_idx_list[-1][-1] = grad_flat_en_idx
else:
group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
grad_flat_st_idx = grad_flat_en_idx

return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())

def step(self, closure=None):
"""
Not supporting closure.
@@ -251,23 +288,32 @@ def step(self, closure=None):
for p in group
]))

for p in group:
p.grad = None

self.fp32_groups_flat[i].grad = grads_groups_flat[i]
param_group = self.optimizer.param_groups[i]

# split expert and non_expert grads for norm
if self.has_moe_layers and is_moe_param_group(param_group):
if param_group['name'] not in expert_grads_for_norm:
expert_grads_for_norm[param_group['name']] = []

expert_grads_for_norm[param_group['name']].append(self.fp32_groups_flat[i])
else:
# retrieve the mask required for computing the norm of the flattened gradient;
# perform this collection only once, on the first step
if not self.has_executed_step:
cur_flat_grad_norm_mask = self._get_norm_mask_idx(group)
self.flatten_grad_norm_mask_list.append(cur_flat_grad_norm_mask)

non_experts_grads_for_norm.append(self.fp32_groups_flat[i])

self.timers(COMPUTE_NORM_TIMER).start()
for p in group:
p.grad = None

all_groups_norm = get_grad_norm(non_experts_grads_for_norm, mpu=self.mpu)
self.timers(COMPUTE_NORM_TIMER).start()

self.timers(COMPUTE_NORM_TIMER).stop()
all_groups_norm = get_flattened_grad_norm(non_experts_grads_for_norm,
mpu=self.mpu,
grad_norm_mask=self.flatten_grad_norm_mask_list)

if self.has_moe_layers:
all_groups_norm = get_norm_with_moe_layers(all_groups_norm,
@@ -276,6 +322,7 @@
norm_type=self.norm_type)

scaled_global_grad_norm = get_global_norm(norm_list=[all_groups_norm])
self.timers(COMPUTE_NORM_TIMER).stop()

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.cur_scale
@@ -298,7 +345,7 @@
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i], self.fp16_groups[i])
for p, q in zip(self.fp16_groups[i], updated_params):
p.data.copy_(q.data)

self.has_executed_step = True
self.timers(UPDATE_FP16_TIMER).stop()

self.timers.log(STEP_TIMERS)
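
For reference, a minimal standalone sketch (toy helper, not DeepSpeed API) of the bookkeeping _get_norm_mask_idx performs above: walk a parameter group in flattening order, record the [start, end) slices whose gradients should be excluded from the norm, and merge a slice into the previous one when they are contiguous.

import torch

def collect_mask_ranges(numels, exclude_flags):
    """numels[i]: element count of parameter i in flattening order;
    exclude_flags[i]: whether that parameter's slice should be masked out."""
    ranges = []
    start = 0
    for n, exclude in zip(numels, exclude_flags):
        end = start + n
        if exclude:
            if ranges and ranges[-1][1] == start:  # contiguous with previous -> merge
                ranges[-1][1] = end
            else:
                ranges.append([start, end])
        start = end
    return torch.tensor(ranges)

print(collect_mask_ranges([4, 3, 5, 2], [True, True, False, True]))
# tensor([[ 0,  7],
#         [12, 14]])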
54 changes: 21 additions & 33 deletions deepspeed/runtime/utils.py
@@ -17,7 +17,6 @@

import torch
from deepspeed import comm as dist

try:
from torch._six import inf
except ModuleNotFoundError:
@@ -385,7 +384,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
return total_norm


def get_grad_norm(parameters, norm_type=2, mpu=None):
def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=None):
"""Get grad norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
@@ -397,7 +396,8 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
single Tensor that will have gradients normalized
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
grad_norm_mask (List[Tensor]): A list of 2D Tensors, where each row
holds a [start_index, end_index] range to be excluded from the norm.
Returns:
Total norm of the parameters (viewed as a single vector).
"""
@@ -415,18 +415,25 @@ def get_grad_norm(parameters, norm_type=2, mpu=None):
total_norm = total_norm_cuda[0].item()
else:
total_norm = 0.
tensor_mp_rank = bwc_tensor_model_parallel_rank(mpu=mpu)
for p in parameters:
# Pipeline parallelism may replicate parameters. Avoid multi-counting.
if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated:
continue

# Filter to avoid over-counting replicated tensors from tensor
# model parallelism
if (tensor_mp_rank > 0) and not is_model_parallel_parameter(p):
continue
for idx, p in enumerate(parameters):
# Use grad_norm_mask to avoid redundant computation of flattened gradient norm
if grad_norm_mask is not None and len(grad_norm_mask[idx]) > 0:

# A loop-free implementation to create a mask tensor based on a range list
# which is logically equivalent to the following implementation.
# # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
# # for mask_idx in grad_norm_mask[idx]:
# # mask_tensor_[mask_idx[0]:mask_idx[1]] = True
cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]

param_norm = torch.masked_fill(p.grad.data, mask_tensor, 0).float().norm(norm_type)

param_norm = p.grad.data.float().norm(norm_type)
else:
param_norm = p.grad.data.float().norm(norm_type)
total_norm += param_norm.item()**norm_type

# Sum across all model parallel GPUs.
@@ -814,25 +821,6 @@ def get_only_unique_item(items):
return unique_item


def clip_gradients(parameters, max_norm=1.0, global_grad_norm=None, mpu=None, eps=1e-6):
"""Clip the gradient of a list of parameters.
Args:
parameters: List of parameters whose .grad will be clipped.
global_grad_norm (float, optional): Precomputed gradient norm. Defaults to None.
mpu (optional): model parallelism unit. Defaults to None.
eps (float, optional): epsilon value added to grad norm. Defaults to 1e-6
Returns:
float: the global gradient norm
"""
if global_grad_norm is None:
global_grad_norm = get_grad_norm(parameters, mpu=mpu)
clip_coef = max_norm / (global_grad_norm + eps)
if clip_coef < 1:
for p in parameters:
p.grad.detach().mul_(clip_coef)
return global_grad_norm


def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=False, moe_ep_group=None):
"""Get norm of an iterable of tensors.
4 changes: 1 addition & 3 deletions docs/_sass/minimal-mistakes/_sidebar.scss
@@ -76,10 +76,9 @@

@include breakpoint($large) {
position: absolute;
top: 0;
top: auto;
right: 0;
width: $right-sidebar-width-narrow;
margin-right: -1.5 * $right-sidebar-width-narrow;
padding-left: 1em;
z-index: 10;

@@ -94,7 +93,6 @@

@include breakpoint($x-large) {
width: $right-sidebar-width;
margin-right: -1.5 * $right-sidebar-width;
}
}

4 changes: 4 additions & 0 deletions tests/unit/common.py
@@ -248,6 +248,10 @@ def _launch_procs(self, num_procs):
f"Skipping test because not enough GPUs are available: {num_procs} required, {get_accelerator().device_count()} available"
)

if get_accelerator().device_name() == 'xpu':
self.non_daemonic_procs = True
self.reuse_dist_env = False

# Set start method to `forkserver` (or `fork`)
mp.set_start_method('forkserver', force=True)

2 changes: 1 addition & 1 deletion tests/unit/ops/transformer/inference/test_residual_add.py
@@ -77,7 +77,7 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size,
pre_attn_norm, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16:
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
