diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py
index e068f4a48b4a..c1c2b6c61cfd 100755
--- a/deepspeed/runtime/utils.py
+++ b/deepspeed/runtime/utils.py
@@ -380,8 +380,7 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
         for p in parameters:
             all_norms.append(p.grad.data.abs().max().float())
         total_norm = torch.stack(all_norms).max()
-        origin_device = total_norm.device.type
-        total_norm = total_norm.to(get_accelerator().device_name())
+        total_norm = total_norm.to(get_accelerator().current_device_name())
         # Take max across all GPUs.
         if mpu is not None:
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
@@ -398,9 +397,8 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
         if len(all_norms) > 0:
             total_norm = torch.stack(all_norms).square().sum().float()
         else:
-            total_norm = torch.FloatTensor([0.0]).to(parameters[0].device)
-        origin_device = total_norm.device.type
-        total_norm = total_norm.to(get_accelerator().device_name())
+            total_norm = get_accelerator().FloatTensor([0.0])
+        total_norm = total_norm.to(get_accelerator().current_device_name())
         # Sum across all model parallel GPUs.
         if mpu is not None:
             dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
@@ -413,11 +411,11 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2, mpu=None):
 
     dist.all_reduce(scaled_norm_tensor, group=pg)
     total_norm = scaled_norm_tensor
-    total_norm = total_norm.to(origin_device)
+    total_norm = total_norm.to(parameters[0].device)
 
-    max_norm = torch.tensor([float(max_norm)], device=parameters[0].device)
+    max_norm = torch.tensor([float(max_norm)], device=total_norm.device)
     clip_coef = max_norm / (total_norm + 1e-6)
-    tmp_tensor = torch.tensor([1.0], device=parameters[0].device)
+    tmp_tensor = torch.tensor([1.0], device=clip_coef.device)
     clip_coef = torch.min(tmp_tensor, clip_coef)
     for p in parameters:
         p.grad.data.mul_(clip_coef)
@@ -890,42 +888,48 @@ def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None, use_graph=F
     assert all([torch.is_tensor(t) for t in input_tensors]), f'expected list of only tensors'
 
     norm_type = float(norm_type)
+    all_norms = []
     if norm_type == inf:
-        total_norm = max(t.data.abs().max() for t in input_tensors)
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
+        for t in input_tensors:
+            all_norms.append(t.data.abs().max().float())
+        total_norm = torch.stack(all_norms).max()
+        device_total_norm = total_norm.to(get_accelerator().current_device_name())
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=mpu.get_model_parallel_group())
         if moe_ep_group is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=moe_ep_group)
-        total_norm = total_norm_cuda[0].item()
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.MAX, group=moe_ep_group)
+        total_norm = device_total_norm.to(input_tensors[0].device)
     else:
-        if use_graph:
-            if 'norm_tensors_compute_buffer' not in graph_cache:
-                graph_cache['norm_tensors_compute_buffer'] = [t.data.float().norm(norm_type) for t in input_tensors]
-            compute_buffer = graph_cache['norm_tensors_compute_buffer']
-            def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
-                for i, t in enumerate(tensor_list):
-                    _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
-                    if i != 0:
-                        _compute_buffer[0].data.add_(_compute_buffer[i].data)
+        if 'norm_tensors_compute_buffer' not in graph_cache or len(
+                graph_cache['norm_tensors_compute_buffer']) != len(input_tensors):
+            graph_cache['norm_tensors_compute_buffer'] = [
+                torch.empty([], dtype=torch.float, device=get_accelerator().current_device_name())
+                for t in input_tensors
+            ]
+        compute_buffer = graph_cache['norm_tensors_compute_buffer']
 
-            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
+        def _norm_tensors(tensor_list, _compute_buffer, _norm_type):
+            for i, t in enumerate(tensor_list):
+                _compute_buffer[i].data.copy_(t.data.float().norm(_norm_type)**_norm_type)
+                if i != 0:
+                    _compute_buffer[0].data.add_(_compute_buffer[i].data)
 
-            total_norm = compute_buffer[0]
+        if use_graph:
+            graph_process(False, _norm_tensors, input_tensors, compute_buffer, norm_type)
         else:
-            total_norm = sum([t.data.float().norm(norm_type).item()**norm_type for t in input_tensors])
+            _norm_tensors(input_tensors, compute_buffer, norm_type)
+
+        device_total_norm = compute_buffer[0].float().detach()
-        total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]).detach()
         if mpu is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=mpu.get_model_parallel_group())
         if moe_ep_group is not None:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=moe_ep_group)
+            dist.all_reduce(device_total_norm, op=dist.ReduceOp.SUM, group=moe_ep_group)
+        total_norm = device_total_norm.to(input_tensors[0].device).pow(1. / norm_type)
-        total_norm = total_norm_cuda[0].item()**(1. / norm_type)
-
-    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
-        total_norm = -1
+    inf_or_nan = total_norm.isinf().logical_or(total_norm.isnan())
+    total_norm.masked_fill_(inf_or_nan, -1)
 
     return total_norm
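
Note for reviewers (illustrative only, not part of the patch): both functions now keep the norm as a tensor, all-reduce it on the accelerator, and move it back to the parameters' device at the end, instead of round-tripping through Python floats via .item(), which forces a host-device synchronization. A minimal standalone sketch of that pattern, assuming plain PyTorch rather than DeepSpeed's accelerator/graph helpers, with the hypothetical helper name _global_norm_sketch:

    # Illustrative sketch, assuming plain PyTorch (not DeepSpeed's actual helpers):
    # compute a global p-norm while keeping the result as a tensor, so no .item()
    # call (and therefore no host-device synchronization) is needed.
    import torch

    def _global_norm_sketch(tensors, norm_type=2.0):
        device = tensors[0].device
        # Sum of per-tensor norms raised to norm_type, accumulated on the tensors' device.
        total = torch.zeros([], dtype=torch.float, device=device)
        for t in tensors:
            total += t.detach().float().norm(norm_type) ** norm_type
        total_norm = total.pow(1.0 / norm_type)
        # Map inf/NaN to -1 on device, mirroring the masked_fill_ used in the patch
        # instead of the old float('inf') comparisons on a Python scalar.
        return total_norm.masked_fill(total_norm.isinf() | total_norm.isnan(), -1)

    if __name__ == "__main__":
        grads = [torch.randn(4, 4), torch.randn(8)]
        print(_global_norm_sketch(grads))  # a 0-dim tensor, never a Python float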