fix: use a pre-pinned buffer for grad D2H copy
xylian86 committed Oct 29, 2024
1 parent 926451b commit 109bb6d
Showing 1 changed file with 11 additions and 3 deletions.
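For context on the change: the old path allocated and page-locked a fresh host tensor on every gradient copy (grad_buffer.to(...).pin_memory()), while the new path stages each partition through self.pinned_grad_buffer, which is pinned once in _setup_for_real_optimizer and reused for every device-to-host copy. Below is a minimal standalone sketch of that pattern in plain PyTorch; the buffer size, the stage_grad_to_cpu helper, and the direct torch.cuda calls (instead of DeepSpeed's get_accelerator() abstraction) are illustrative assumptions, not code taken from this commit.

import torch

# Illustrative staging-buffer size; the commit sizes it to the largest
# gradient partition (max_partition_numel), which is not known here.
MAX_PARTITION_NUMEL = 1 << 20

# One page-locked (pinned) host buffer, allocated once and reused for every
# device-to-host gradient copy -- the role self.pinned_grad_buffer plays in
# the diff below.
pinned_buffer = torch.empty(MAX_PARTITION_NUMEL, dtype=torch.float32, pin_memory=True)

def stage_grad_to_cpu(grad_gpu: torch.Tensor, fp32_dst_cpu: torch.Tensor) -> None:
    # Hypothetical helper: copy one GPU gradient partition into a pageable
    # CPU fp32 tensor by staging it through the reusable pinned buffer.
    n = grad_gpu.numel()
    # Cast on the GPU, then copy device-to-host into pinned memory;
    # no fresh allocation or pin_memory() call per gradient.
    pinned_buffer[:n].copy_(grad_gpu.to(dtype=torch.float32, non_blocking=True))
    # Make sure the transfer has landed before the pinned data is consumed.
    torch.cuda.synchronize()
    # Pinned -> pageable host copy; non_blocking lets later GPU work overlap.
    fp32_dst_cpu.copy_(pinned_buffer[:n], non_blocking=True)

if torch.cuda.is_available():
    grad = torch.randn(4096, device="cuda", dtype=torch.bfloat16)
    dst = torch.zeros(4096, dtype=torch.float32)
    stage_grad_to_cpu(grad, dst)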
14 changes: 11 additions & 3 deletions deepspeed/runtime/zero/stage3.py
@@ -538,10 +538,15 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
 
         offset = 0
+        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
+            max_partition_numel = max(max_partition_numel, param.partition_numel())
+        if self.offload_optimizer:
+            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
+                torch.empty(max_partition_numel, device=self.device))
 
     def _link_all_hp_params(self):
         for p in self.module.parameters():
@@ -1498,10 +1503,13 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None:
                         offload_fp32_gradients[i].append(grad_buffer.float())
                         offload_fp32_offsets[i].append(dest_offset)
                     else:
+                        buffer_numel = grad_buffer.numel()
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                            0, dest_offset, grad_buffer.numel())
-                        fp32_grad_tensor.copy_(
-                            grad_buffer.to(dtype=torch.float32, device=self.device, non_blocking=True).pin_memory())
+                            0, dest_offset, buffer_numel)
+                        self.pinned_grad_buffer[:buffer_numel].copy_(
+                            grad_buffer.to(dtype=torch.float32, non_blocking=True))
+                        get_accelerator().synchronize()
+                        fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
 
             # free the gradient
             if not get_accelerator().is_synchronized_device():
