
Commit

Code review fixes
oelayan7 committed Dec 17, 2024
1 parent e2f8932 commit 7d39f60
Showing 2 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion deepspeed/inference/config.py
@@ -176,7 +176,7 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
 
     keep_module_on_host: bool = False
     """
-    When loading checkpoints to model parameters, they are moved to the device. In large very models
+    When loading checkpoints to model parameters, they are moved to the device. In very large models
     this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on
    host and not move them directly to the device (giving an option to quantize checkpoint data before
    moving it to the device for example).
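
For context, this flag is consumed through DeepSpeedInferenceConfig; below is a minimal sketch of enabling it, assuming the usual deepspeed.init_inference entry point builds the config from these keyword arguments. The model choice, dtype, and tensor-parallel size are placeholder assumptions, not part of this commit.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

# keep_module_on_host=True keeps the sharded checkpoint tensors on the CPU so the
# accelerator is not filled during loading (e.g. to quantize them before moving).
engine = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},
    dtype=torch.float16,
    keep_module_on_host=True,
)
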
30 changes: 14 additions & 16 deletions deepspeed/module_inject/auto_tp.py
@@ -17,16 +17,14 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 
 
-def move(tensor, device, keep_module_on_host=False):
+def move(tensor, device, copy=True):
     if tensor.is_meta:
-        return torch.empty_like(tensor, device='cpu' if keep_module_on_host else device)
-    elif keep_module_on_host:
-        return tensor.to('cpu') if device != 'cpu' else tensor
+        return torch.empty_like(tensor, device=device)
     else:
         # Using new tensors help in freeing memory (after split for example) was done before by calling clone().
         # Using copy=True instead of clone() will help in case of cpu --> cpu.
         # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced.
-        return tensor.to(device, copy=True)
+        return tensor.to(device, copy=copy)
 
 
 class ReplaceWithTensorSlicing:
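
The comments inside move() explain why a forced copy matters for views of a larger tensor; here is a small self-contained PyTorch illustration of that behavior (shapes are arbitrary and purely for demonstration):

import torch

big = torch.randn(4096, 4096)              # stands in for a full, unsharded weight
shard = big.split(2048, dim=1)[0]          # a view sharing storage with `big`

same = shard.to('cpu')                     # already on CPU: returns the same tensor, no copy
fresh = shard.to('cpu', copy=True)         # forces a new tensor with its own storage

print(same.data_ptr() == big.data_ptr())   # True  -> still pins all of `big`'s memory
print(fresh.data_ptr() == big.data_ptr())  # False -> `big` can be freed once dropped
del big, shard, same                       # only the small copy remains alive
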
@@ -340,6 +338,10 @@ def set_tensor_parallel_config(self, mp_size, mp_group):
     def _replace(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
             return
+        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
+        # keep_module_on_host is used to keep the module on the host. Checkpoints are loaded to the host first (in some
+        # cases it can be done from the disk even to prevent filling host's memory), thus no need to create a new copy.
+        return_new_copy = not self.keep_module_on_host
         weight_shape = child.weight.shape
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         # For mixtral-7x8b, need to skip MoE gate linear replace.
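
Condensed into a standalone sketch (the device string is a hard-coded stand-in for get_accelerator().current_device_name()), the placement decision added above amounts to:

def pick_placement(keep_module_on_host: bool, accelerator_device: str = "cuda:0"):
    # Shards either stay on the host (the checkpoint is already there, so no extra
    # copy is needed) or are copied onto the accelerator device.
    device_name = "cpu" if keep_module_on_host else accelerator_device
    return_new_copy = not keep_module_on_host
    return device_name, return_new_copy

print(pick_placement(True))   # ('cpu', False)   -> reuse the host-resident tensor
print(pick_placement(False))  # ('cuda:0', True) -> make a fresh copy on the device
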
@@ -369,19 +371,17 @@ def _replace(self, child, name, conv_linear_layer):
             data = child.weight.data.split(get_shard_size_list(
                 weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
                                            dim=1)
-            data_dc = move(data[mp_replace.gpu_index],
-                           get_accelerator().current_device_name(), self.keep_module_on_host).detach()
+            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
             del data
 
             setattr(child, "replaced", True)
             if name == "lm_head" or name == 'embed_out':
                 return LmHeadLinearAllreduce(
                     torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
                     child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias,
-                             get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group)
+                        move(child.bias, device_name, return_new_copy)), self.mp_group)
             return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name(), self.keep_module_on_host)), self.mp_group)
+                        torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
         else:
 
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -394,24 +394,22 @@ def _replace(self, child, name, conv_linear_layer):
                 #The copy is a regular copy, The shape of dst and src is the same
                 data_dc = move(
                     prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name(), self.keep_module_on_host)
+                    device_name, return_new_copy)
 
                 bias_data_dc = None if child.bias is None else move(
                     prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    get_accelerator().current_device_name(), self.keep_module_on_host)
+                    device_name, return_new_copy)
             else:
                 data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
                                                dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index],
-                               get_accelerator().current_device_name(), self.keep_module_on_host).detach()
+                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
                 del data
 
                 if child.bias is not None:
                     bias_data = child.bias.data.split(get_shard_size_list(
                         weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
                                                       dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index],
-                                     get_accelerator().current_device_name(), self.keep_module_on_host)
+                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
                     bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
                     del bias_data
                 else:
