improve buffer allocation
tohtana committed May 1, 2024
1 parent 3e15fae commit b3c0750
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions deepspeed/runtime/pipe/engine.py
@@ -5,6 +5,8 @@

from types import MethodType
from collections import OrderedDict
from functools import reduce
from operator import mul

import torch
from deepspeed import comm as dist
@@ -180,6 +182,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
}
self.pipe_recv_buf = None
self.grad_layer = None
self._grad_layer_buf = []

self.meta_buffer = None

@@ -319,6 +322,7 @@ def reset_activation_shape(self):
self.first_output_send = True
self.pipe_recv_buf = None
self.grad_layer = None
self._grad_layer_buf = []
self.meta_buffer = None

self.pipe_partition_input_meta_cache = None
@@ -928,10 +932,12 @@ def _send_tensor_meta(self, buffer, recv_stage):
"""
send_bytes = 0
if isinstance(buffer, torch.Tensor):
send_dtype = torch.LongTensor(data=[self.DTYPE_TO_ID[buffer.dtype]]).to(self.device)
type_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.send(type_tensor, recv_stage)
send_shape = torch.LongTensor(data=buffer.size()).to(self.device)
send_ndims = torch.LongTensor(data=[len(buffer.size())]).to(self.device)
p2p.send(send_dtype, recv_stage)
p2p.send(send_ndims, recv_stage)
p2p.send(send_shape, recv_stage)
send_bytes += _tensor_bytes(buffer)
@@ -1000,20 +1006,24 @@ def _recv_tensor_meta(self, send_stage):

# A single tensor will be sent.
if recv_type == 0:
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
recv_dtype = self.ID_TO_DTYPE[recv_dtype.item()]
recv_ndims = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_ndims, send_stage)
recv_ndims = recv_ndims.item()
recv_shape = torch.LongTensor([1] * recv_ndims).to(self.device)
p2p.recv(recv_shape, send_stage)
recv_shape = recv_shape.tolist()
return self._allocate_buffer(recv_shape, num_buffers=1)[0]
return self._allocate_or_extend_buffers(0, recv_shape, recv_dtype)

# List or tuple of tensors
elif recv_type == 1 or recv_type == 2:
count_tensor = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(count_tensor, send_stage)
num_tensors = count_tensor.item()
recv_shapes_and_dtypes = []
buffers = []
for idx in range(num_tensors):
recv_dtype = torch.LongTensor(data=[0]).to(self.device)
p2p.recv(recv_dtype, send_stage)
@@ -1025,7 +1035,8 @@ def _recv_tensor_meta(self, send_stage):
p2p.recv(recv_shape, send_stage)
recv_shapes_and_dtypes.append((recv_shape.tolist(), recv_dtype))

buffers = self._allocate_buffers(recv_shapes_and_dtypes, num_buffers=1)[0]
buffers.append(self._allocate_or_extend_buffers(idx, recv_shape.tolist(), recv_dtype))

# Convert to tuples if requested.
if recv_type == 2:
buffers = tuple(buffers)
@@ -1190,8 +1201,7 @@ def _exec_recv_grads(self, buffer_id):
# Allocate gradient if necessary
if self.dynamic_shape or self.grad_layer is None:
if isinstance(outputs, torch.Tensor):
s = list(outputs.size())
self.grad_layer = self._allocate_buffer(s, dtype=outputs.dtype, num_buffers=1)[0]
self.grad_layer = self._allocate_or_extend_buffers(0, list(outputs.size()), outputs.dtype)
else:
# XXX This is a HACK
# When we exchange activations/gradients, the two pipe stages
@@ -1213,7 +1223,11 @@ def _exec_recv_grads(self, buffer_id):
for t in outputs[2:] if t.is_floating_point()]
else:
sizes_and_dtypes = [(list(t.size()), t.dtype) for t in outputs if t.is_floating_point()]
self.grad_layer = self._allocate_buffers(sizes_and_dtypes, num_buffers=1)[0]

self.grad_layer = [
self._allocate_or_extend_buffers(i, size, dtype)
for i, (size, dtype) in enumerate(sizes_and_dtypes)
]

if isinstance(self.grad_layer, torch.Tensor):
p2p.recv(self.grad_layer, self.next_stage)
@@ -1305,6 +1319,18 @@ def _allocate_buffers(self, shapes_and_dtypes, requires_grad=False, num_buffers=
buffers.append(buffer)
return buffers

def _allocate_or_extend_buffers(self, idx, shape, dtype):
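# Reuse the cached flat buffer for this index when it already holds enough elements; otherwise allocate a replacement.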
numel = reduce(mul, shape) if len(shape) > 0 else 1
if len(self._grad_layer_buf) <= idx or self._grad_layer_buf[idx].numel() < numel:
new_buf = self._allocate_buffer(shape, dtype=dtype, num_buffers=1)[0]
if len(self._grad_layer_buf) <= idx:
self._grad_layer_buf.append(new_buf)
else:
self._grad_layer_buf[idx] = new_buf
return self._grad_layer_buf[idx]
else:
return self._grad_layer_buf[idx].flatten()[:numel].view(shape)

def forward(self, *args, **kwargs):
"""Disabled for pipeline parallel training. See ``train_batch()``. """
raise PipelineError("Only train_batch() is accessible in pipeline mode.")

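For reference, below is a minimal, self-contained sketch of the allocate-or-reuse pattern this commit introduces. It is not DeepSpeed code: BufferCache and its get method are illustrative names, and torch.empty stands in for the engine's _allocate_buffer helper.

from functools import reduce
from operator import mul

import torch


class BufferCache:
    """Caches one flat tensor per index and hands out shaped views of it."""

    def __init__(self):
        self._bufs = []

    def get(self, idx, shape, dtype):
        numel = reduce(mul, shape) if len(shape) > 0 else 1
        # Allocate when the slot is empty or the cached tensor is too small.
        if len(self._bufs) <= idx or self._bufs[idx].numel() < numel:
            new_buf = torch.empty(shape, dtype=dtype)
            if len(self._bufs) <= idx:
                self._bufs.append(new_buf)
            else:
                self._bufs[idx] = new_buf
            return new_buf
        # Otherwise reuse the cached storage: take the first numel elements and
        # view them with the requested shape, avoiding a fresh allocation.
        return self._bufs[idx].flatten()[:numel].view(shape)


cache = BufferCache()
a = cache.get(0, [4, 8], torch.float32)  # first request allocates
b = cache.get(0, [2, 8], torch.float32)  # smaller request reuses the same storage
assert b.data_ptr() == a.data_ptr()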