DeepSpeedZeroOptimizer: refactor bit16 flattening to support more accelerators (microsoft#4833)

Until now, the implementation replaced each torch.nn.Parameter's data
with new CPU storage in order to offload device memory.
All params were then flattened on the host and moved to the device.
On some accelerators, a torch.nn.Parameter that lives on the device
cannot be assigned CPU storage.
This PR instead copies each param's data into a new CPU tensor and
shrinks the device storage.
Later, when the flat buffer is moved to the device, param.data becomes
a view into the flat buffer.
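
For illustration, a minimal standalone sketch of the two patterns (the
tensor below is a CPU stand-in for a real device parameter):

import torch

param = torch.nn.Parameter(torch.randn(1024, 1024))

# Old pattern: rebind the parameter's storage to CPU memory. On some
# accelerators a device parameter cannot take CPU storage, so this can fail:
#   param.data = param.data.cpu()

# New pattern: stash a CPU copy on the parameter object and shrink the
# device storage to a one-element stub, releasing the large allocation
# while keeping the parameter object (and its device) intact.
param.cpu_data = param.data.cpu()
param.data = torch.empty(1).to(param.device)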

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
3 people authored Jan 10, 2024
1 parent ed10cc7 commit ade9836
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -75,11 +75,6 @@ def get_alignment_padding(tensor_list, alignment):
     return (alignment - remainder) if remainder else remainder
 
 
-def move_to_cpu(tensor_list):
-    for tensor in tensor_list:
-        tensor.data = tensor.data.cpu()
-
-
 def print_rank_msg(msg):
     print(f"rank {dist.get_rank()} - {msg}")
 
@@ -294,6 +289,7 @@ def __init__(self,
 
         self.round_robin_bit16_groups = []
         self.round_robin_bit16_indices = []
+        self.round_robin_bit16_meta = []
 
         # Use different parallel to do all_to_all_reduce related things
         # padding on each partition for alignment purposes
@@ -316,7 +312,14 @@
 
             see_memory_usage(f"Before moving param group {i} to CPU")
             # move all the parameters to cpu to free up GPU space for creating flat buffer
-            move_to_cpu(self.bit16_groups[i])
+
+            # Create temp CPU param copies, free accelerator tensors
+            orig_group_numel = 0
+            for param in self.bit16_groups[i]:
+                orig_group_numel += param.numel()
+                param.cpu_data = param.data.cpu()
+                param.data = torch.empty(1).to(param.device)
+
             empty_cache()
             see_memory_usage(f"After moving param group {i} to CPU", force=False)
 
@@ -334,18 +337,31 @@
             self.round_robin_bit16_groups.append(round_robin_tensors)
             self.round_robin_bit16_indices.append(round_robin_indices)
 
-            # create flat buffer in CPU and move to GPU
-            self.bit16_groups_flat.append(
-                self.flatten_dense_tensors_aligned(
-                    self.round_robin_bit16_groups[i],
-                    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
-                        get_accelerator().current_device_name()))
+            # Create meta tensors list, ordered according to round_robin_tensors
+            meta_tensors = []
+            for param in round_robin_tensors:
+                meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
+            self.round_robin_bit16_meta.append(meta_tensors)
+
+            # create flat buffer in CPU
+            flattened_buffer = self.flatten_dense_tensors_aligned(
+                self.round_robin_bit16_groups[i],
+                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]),
+                use_cpu_data=True)
+
+            # free temp CPU params
+            for param in self.bit16_groups[i]:
+                del param.cpu_data
+
+            # Move CPU flat tensor to the accelerator memory.
+            self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
+            del flattened_buffer
 
             see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
 
             # Record padding required for alignment
             if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
-                padding = self.bit16_groups_flat[i].numel() - sum(
-                    [t.numel() for t in self.round_robin_bit16_groups[i]])
+                padding = self.bit16_groups_flat[i].numel() - orig_group_numel
             else:
                 padding = 0
             self.groups_padding.append(padding)
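
As a worked example of the padding bookkeeping above (hypothetical
sizes, assuming world_size=2 and nccl_start_alignment_factor=2):

alignment = 2 * 2                        # nccl_start_alignment_factor * world_size
orig_group_numel = 10                    # sum of the real parameter numels
flat_numel = ((orig_group_numel + alignment - 1) // alignment) * alignment  # 12
padding = flat_numel - orig_group_numel  # 2, recorded only by the last partition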
@@ -596,8 +612,7 @@ def _configure_moe_settings(self):
         assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
 
     def _update_model_bit16_weights(self, group_index):
-        updated_params = self.unflatten(self.bit16_groups_flat[group_index],
-                                        self.round_robin_bit16_groups[group_index])
+        updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_meta[group_index])
         for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
             p.data = q.data
 
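For context, self.unflatten in this optimizer is bound to torch's
_unflatten_dense_tensors, which reads only the shapes of its reference
tensors; meta tensors therefore work as references, and the results are
views into the flat buffer. A small sketch of that behavior:

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

tensors = [torch.randn(2, 3), torch.randn(5)]
flat = _flatten_dense_tensors(tensors)

# Meta tensors carry only shape/dtype metadata; no storage is allocated.
meta = [torch.zeros_like(t, device="meta") for t in tensors]

# Unflattening uses only the reference shapes, so meta references suffice,
# and each returned tensor is a view into `flat`.
views = _unflatten_dense_tensors(flat, meta)
assert views[0].shape == (2, 3)
assert views[0].data_ptr() == flat.data_ptr()  # first view aliases the buffer start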

@@ -887,7 +902,8 @@ def report_ipg_memory_usage(self, tag, param_elems):
         )
 
     # create a flat tensor aligned at the alignment boundary
-    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
+    def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False):
+        tensor_list = [param.cpu_data for param in tensor_list] if use_cpu_data else tensor_list
         return self.flatten(align_dense_tensors(tensor_list, alignment))
 
     ############### Independent Partition Gradient ########################
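
For reference, a simplified sketch of the aligned flattening
(align_dense_tensors lives in deepspeed.runtime.utils; this
re-implementation is illustrative, not the exact library code):

import torch
from torch._utils import _flatten_dense_tensors

def align_dense_tensors(tensor_list, alignment):
    # Pad the list with a zero tensor so total numel is a multiple of `alignment`.
    num_elements = sum(t.numel() for t in tensor_list)
    remaining = num_elements % alignment
    if remaining:
        pad = torch.zeros(alignment - remaining, dtype=tensor_list[0].dtype)
        tensor_list = tensor_list + [pad]
    return tensor_list

tensors = [torch.randn(3), torch.randn(4)]
flat = _flatten_dense_tensors(align_dense_tensors(tensors, alignment=4))
assert flat.numel() == 8  # 7 real elements padded up to a multiple of 4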
