compute global norm on device (#5125)
Avoid host synchronization by keeping data on device

---------

Co-authored-by: Logan Adams <[email protected]>
3 people authored Apr 3, 2024
1 parent 548d37b commit 9f0e213
Showing 1 changed file with 36 additions and 32 deletions.
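To make "avoid host synchronization by keeping data on device" concrete, here is a minimal standalone sketch (not part of the diff): calling .item() or copying a result to the CPU blocks the host until the accelerator has drained its queued kernels, while keeping the running norm as a device tensor lets the launches stay asynchronous. The function names and the plain CUDA/CPU device selection below are illustrative assumptions, not DeepSpeed APIs.

# Illustrative sketch only -- not part of the commit.
import torch

def global_norm_with_host_sync(tensors):
    # .item() copies each partial norm to the host and waits for the
    # device to finish, i.e. one synchronization per tensor.
    return sum(t.float().norm(2).item()**2 for t in tensors)**0.5

def global_norm_on_device(tensors):
    # Every intermediate stays a device tensor, so nothing here forces
    # the host to block; the result is a 0-dim tensor on the device.
    norms = torch.stack([t.float().norm(2) for t in tensors])
    return norms.pow(2).sum().sqrt()

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    grads = [torch.randn(1024, device=device) for _ in range(4)]
    print(global_norm_with_host_sync(grads))  # Python float, after syncs
    print(global_norm_on_device(grads))       # device tensor, no forced sync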
deepspeed/runtime/utils.py (36 additions, 32 deletions)
@@ -380,8 +380,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
         for p in parameters:
             all_norms.append(p.grad.data.abs().max().float())
         total_norm = torch.stack(all_norms).max()
-        origin_device = total_norm.device.type
-        total_norm = total_norm.to(get_accelerator().device_name())
+        total_norm = total_norm.to(get_accelerator().current_device_name())
         # Take max across all GPUs.
         if mpu is not None:
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
@@ -398,9 +397,8 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
         if len(all_norms) > 0:
             total_norm = torch.stack(all_norms).square().sum().float()
         else:
-            total_norm = torch.FloatTensor([0.0]).to(parameters[0].device)
-        origin_device = total_norm.device.type
-        total_norm = total_norm.to(get_accelerator().device_name())
+            total_norm = get_accelerator().FloatTensor([0.0])
+        total_norm = total_norm.to(get_accelerator().current_device_name())
         # Sum across all model parallel GPUs.
         if mpu is not None:
             dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
@@ -413,11 +411,11 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
 
         dist.all_reduce(scaled_norm_tensor, group=pg)
         total_norm = scaled_norm_tensor
-        total_norm = total_norm.to(origin_device)
+        total_norm = total_norm.to(parameters[0].device)
 
-    max_norm = torch.tensor([float(max_norm)], device=parameters[0].device)
+    max_norm = torch.tensor([float(max_norm)], device=total_norm.device)
     clip_coef = max_norm / (total_norm + 1e-6)
-    tmp_tensor = torch.tensor([1.0], device=parameters[0].device)
+    tmp_tensor = torch.tensor([1.0], device=clip_coef.device)
     clip_coef = torch.min(tmp_tensor, clip_coef)
     for p in parameters:
         p.grad.data.mul_(clip_coef)
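One way to check that paths like the clipping code above no longer force host synchronizations is PyTorch's CUDA sync debug mode, which warns or raises whenever an operation blocks on the GPU (for example .item() or a blocking device-to-host copy). This is an illustrative debugging aid, not something added by this commit; it assumes a CUDA build of PyTorch and a visible GPU.

# Illustrative debugging aid -- not part of the commit.
import torch

if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("warn")  # "error" raises instead of warning

    grads = [torch.randn(1024, device="cuda") for _ in range(4)]
    total_norm = torch.stack([g.float().norm(2) for g in grads]).pow(2).sum().sqrt()
    clip_coef = torch.tensor([1.0], device="cuda") / (total_norm + 1e-6)  # stays on device

    _ = total_norm.item()  # reported: forces a host synchronization

    torch.cuda.set_sync_debug_mode("default")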
@@ -890,42 +888,48 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
     assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'
 
     norm_type = float(norm_type)
+    all_norms = []
     if norm_type == inf:
-        total_norm = max(t.data.abs().max() for t in input_tensors)
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
+        for t in input_tensors:
+            all_norms.append(t.data.abs().max().float())
+        total_norm = torch.stack(all_norms).max()
+        device_total_norm = total_norm.to(get_accelerator().current_device_name())
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
         if moe_ep_group is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
-        total_norm = total_norm_cuda[0].item()
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
+        total_norm = device_total_norm.to(input_tensors[0].device)
     else:
-        if use_graph:
-            if 'norm_tensors_compute_buffer' not in graph_cache:
-                graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]
-            compute_buffer = graph_cache['norm_tensors_compute_buffer']
-
-            def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
-                for i, t in enumerate(tensor_list):
-                    _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
-                    if i != 0:
-                        _compute_buffer[0].data.add_(_compute_buffer[i].data)
-
-            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
-
-            total_norm = compute_buffer[0]
-        else:
-            total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
-
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
+        if 'norm_tensors_compute_buffer' not in graph_cache or len(
+                graph_cache['norm_tensors_compute_buffer']) != len(input_tensors):
+            graph_cache['norm_tensors_compute_buffer'] = [
+                torch.empty([], dtype=torch.float, device=get_accelerator().current_device_name())
+                for t in input_tensors
+            ]
+        compute_buffer = graph_cache['norm_tensors_compute_buffer']
+
+        def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
+            for i, t in enumerate(tensor_list):
+                _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
+                if i != 0:
+                    _compute_buffer[0].data.add_(_compute_buffer[i].data)
+
+        if use_graph:
+            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
+        else:
+            _norm_tensors(input_tensors, compute_buffer, norm_type)
+
+        device_total_norm = compute_buffer[0].float().detach()
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
         if moe_ep_group is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group)
-
-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
+        total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
 
-    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-        total_norm = -1
+    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
+    total_norm.masked_fill_(inf_or_nan, -1)
 
     return total_norm
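Two consequences of the rewritten get_global_norm_of_tensors above are worth noting: the inf/NaN guard now runs as tensor operations (isinf/isnan plus masked_fill_) instead of Python float comparisons that would require a host-side value, and the function now returns a tensor on the input tensors' device rather than a Python float, so callers that need a plain number must call .item() themselves and accept that synchronization. A small standalone illustration of the sentinel logic (not part of the diff):

# Illustrative only: the on-device inf/NaN sentinel in isolation.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
for value in (3.0, float("inf"), float("nan")):
    total_norm = torch.tensor(value, device=device)
    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
    total_norm.masked_fill_(inf_or_nan, -1)  # overflow/NaN collapses to -1
    print(total_norm)  # tensor(3.), tensor(-1.), tensor(-1.) (plus device info on GPU)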
