
[AutoTP] Make AutoTP work when num_heads not divisible by number of workers #4011

Merged
merged 55 commits on Oct 25, 2023
Changes from 30 commits
Commits (55)
0706acd
allow number of heads not divisible by number of ranks
delock Jul 20, 2023
0bf785f
get num_heads from model config, more robust
delock Jul 21, 2023
72b9e1a
simplify logic where num_head itself is sharded
delock Jul 21, 2023
5ed9a56
name tweaks
delock Jul 21, 2023
73f499d
make code more robust where num_attention_heads may not be defined in…
delock Jul 21, 2023
48322c7
Merge branch 'master' into gma/uneven_heads
delock Jul 21, 2023
f14e290
Merge branch 'master' into gma/uneven_heads
delock Jul 24, 2023
b62317c
Merge branch 'master' into gma/uneven_heads
loadams Jul 24, 2023
12c0628
support num_key_value_heads < num_attention_heads which is used by ll…
delock Jul 25, 2023
8f23d9b
add test for 5 ranks
delock Jul 25, 2023
9c53bd7
change odd rank # to 3 to avoid test skip
delock Jul 25, 2023
413224b
Merge branch 'master' into gma/uneven_heads
tjruwase Jul 25, 2023
78d6667
Merge branch 'master' into gma/uneven_heads
delock Aug 9, 2023
27fde30
add get_shard_size function
delock Aug 9, 2023
8e1fd27
modify sharding mechanism according to latest auto TP
delock Aug 10, 2023
9a6bc12
Merge branch 'master' into gma/uneven_heads
delock Aug 16, 2023
2dac94f
fix accuracy issue
delock Aug 17, 2023
885f6a3
Merge branch 'master' into gma/uneven_heads
delock Aug 17, 2023
7ffd811
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 18, 2023
40659ba
Merge branch 'master' into gma/uneven_heads
tjruwase Aug 22, 2023
71f9f40
fix format
delock Aug 21, 2023
db9db6b
skip tests with fusedqkv
delock Aug 23, 2023
72531c0
Merge branch 'master' into gma/uneven_heads
delock Aug 23, 2023
9d5eae3
remove skip of fusedqkv tests
delock Aug 23, 2023
25e656d
skip test fusedqkv with odd number of ranks
delock Aug 23, 2023
7f6d7f6
support model with n_heads in model_config
delock Aug 24, 2023
e3a5b77
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 24, 2023
c9ec881
Merge branch 'master' into gma/uneven_heads
delock Aug 26, 2023
f5be257
fix TestInjectionPolicy::test[fp32-t5]
delock Aug 27, 2023
b671040
fix uneven_heads on some fusedqkv types (#12)
inkcherry Aug 28, 2023
d59ff22
better fix when activation size cannot be divided by number of heads
delock Aug 30, 2023
6c3c841
Merge branch 'master' into gma/uneven_heads_rebase
delock Aug 30, 2023
58e8b24
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 1, 2023
4c6b7fa
move tp_shard.py under module_inject
delock Sep 6, 2023
18e1c5d
Merge branch 'master' into gma/uneven_heads
delock Sep 6, 2023
8ef01e2
Add get_num_kv_heads in tp_shard.py
delock Sep 7, 2023
9a61fc2
Merge branch 'master' into gma/uneven_heads
delock Sep 11, 2023
74870db
Merge branch 'master' into gma/uneven_heads
delock Sep 13, 2023
115cc20
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 13, 2023
0781c41
Refine according to comments
delock Sep 14, 2023
194337f
remove old comment
mrwyattii Sep 14, 2023
47d84ca
Merge branch 'master' into gma/uneven_heads
delock Sep 18, 2023
369eb3e
Merge branch 'master' into gma/uneven_heads
mrwyattii Sep 19, 2023
567fb9a
fix bug in getting num_kv_heads
delock Sep 20, 2023
47c83ca
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 20, 2023
d194ab0
Merge branch 'master' into gma/uneven_heads
tjruwase Sep 27, 2023
6db5ddd
Merge branch 'master' into gma/uneven_heads
delock Oct 7, 2023
698b62a
Merge branch 'up-master' into gma/uneven_heads
delock Oct 10, 2023
d75149f
support uneven sharding of lm_head tensor parallel
delock Oct 10, 2023
248532d
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
a9056fd
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
81bd29f
Merge branch 'master' into gma/uneven_heads
delock Oct 12, 2023
693a9fe
Merge branch 'master' into gma/uneven_heads
delock Oct 18, 2023
4c45a5b
Merge branch 'master' into gma/uneven_heads
delock Oct 19, 2023
a7513e1
Merge branch 'master' into gma/uneven_heads
delock Oct 24, 2023
23 changes: 13 additions & 10 deletions deepspeed/module_inject/auto_tp.py
@@ -14,6 +14,7 @@
from .layers import LinearAllreduce, LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list


class ReplaceWithTensorSlicing:
@@ -308,8 +309,9 @@ def _replace(self, child, name, conv_linear_layer):

if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(
(weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size),
dim=1)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

setattr(child, "replaced", True)
@@ -332,13 +334,14 @@ def _replace(self, child, name, conv_linear_layer):
module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to(
get_accelerator().current_device_name())
else:
data = child.weight.data.split((weight_shape[0]) // self.mp_size,
data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size),
dim=1 if self.conv_linear_layer else 0)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())

if child.bias is not None:
bias_data = child.bias.data.split(
(weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0)
bias_data = child.bias.data.split(get_shard_size_list(
weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size),
dim=0)
bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
bias_data = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
else:
@@ -354,13 +357,13 @@ def _slice_embedding(self, child, name, conv_linear_layer):
mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)

if hasattr(child.weight, 'ds_tensor'):
data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
else:
data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1)
data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1)
data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name())
data = torch.nn.parameter.Parameter(data, requires_grad=False)

new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size)
new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size))
new_embedding.weight.data.copy_(data)
setattr(child, "replaced", True)
return new_embedding
@@ -374,8 +377,8 @@ def update_mp_params(self, child):
]:
if hasattr(child, param):
param_val = getattr(child, param)
assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, param_val // self.mp_size)
#assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})"
setattr(child, param, get_shard_size(param_val, self.mp_size))
setattr(child, "replaced", True)

def update_linear_policies(self):
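The split changes above in auto_tp.py replace an integer chunk size with an explicit list of shard sizes. The difference matters when the dimension is not divisible by the world size: an integer split silently yields an extra remainder chunk. A minimal sketch with made-up shapes and shard sizes (assuming 5 KV heads of width 64 on 3 ranks):

import torch

w = torch.randn(320, 320)                # hypothetical [out_features, in_features] weight
even = w.split(320 // 3, dim=0)          # old path: chunks of 106, 106, 106, 2 -> 4 shards for 3 ranks
uneven = w.split([128, 128, 64], dim=0)  # new path: sizes from get_shard_size_list(320, 3)
print([t.shape[0] for t in uneven])      # [128, 128, 64]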
9 changes: 5 additions & 4 deletions deepspeed/module_inject/auto_tp_model_utils.py
@@ -6,6 +6,7 @@
from deepspeed import comm as dist
import torch
from typing import Optional
from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list


def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -51,8 +52,8 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
if dist.is_initialized():
num_heads_per_rank = int(num_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size())
offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()])
alibi = alibi.view(batch_size, num_heads, 1, seq_length)
alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
@@ -72,8 +73,8 @@ def build_mpt_atten_bias_tensor(self,
prefix_mask=prefix_mask,
sequence_id=sequence_id)
if dist.is_initialized():
num_heads_per_rank = int(self.config.n_heads / dist.get_world_size())
offset = dist.get_rank() * num_heads_per_rank
num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size())
offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()])
attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
return attn_bias, attention_mask

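Because the per-rank head counts are no longer equal, the rank offset above is computed as a prefix sum of the shard-size list rather than rank times a fixed heads-per-rank. A small sketch with illustrative numbers (5 heads on 3 ranks, shard sizes [2, 2, 1] in head units):

shard_sizes = [2, 2, 1]                 # e.g. get_shard_size_list(5, 3)
rank = 2
num_heads_per_rank = shard_sizes[rank]  # 1
offset = sum(shard_sizes[:rank])        # 4; the old int(5 / 3) * rank would give 2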
23 changes: 13 additions & 10 deletions deepspeed/module_inject/fusedqkv_utils.py
@@ -4,6 +4,8 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list
import deepspeed.utils.tp_shard as tp_shard
import re


@@ -39,18 +41,19 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
# codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py
#TODO: assert num_heads % (mp_size*codegen_mp_num) == 0
assert tp_shard.num_kv_heads % (
mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0"
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
dst_shape = get_shard_size(shape[0], mp_size)
num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1])

#num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :]
src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1))
src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split]

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1)
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
@@ -59,16 +62,16 @@ def _glm_type_transpose(input, mp_size):
#input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias)

shape = input.shape
dst_shape = shape[0] // mp_size
src_split = torch.split(input, shape[0] // 3, dim=0)

split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size)
tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0)

return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape]
split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size))
return split_fusedqkv[gpu_index]

def _bloom_type_transpose(input, mp_size):
return input
shape = input.shape

split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
return split_fusedqkv[gpu_index]

def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):

@@ -91,4 +94,4 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
return _transpose_fused_qkvw(src, mp_size, fused_type)
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return src
return _bloom_type_transpose(src, mp_size)
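A rough sketch of the bloom-type path above (hidden size and head count are hypothetical; BLOOM keeps Q, K and V interleaved per head, so a single row-wise split per rank suffices):

import torch

hidden = 320                                 # pretend hidden size: 5 heads of width 64
fused_qkv = torch.randn(3 * hidden, hidden)  # fused [3*hidden, hidden] weight
shard_sizes = [384, 384, 192]                # e.g. get_shard_size_list(3 * hidden, 3) with 5 KV heads
shards = fused_qkv.split(shard_sizes, dim=0)
print([tuple(s.shape) for s in shards])      # [(384, 320), (384, 320), (192, 320)]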
17 changes: 15 additions & 2 deletions deepspeed/module_inject/replace_module.py
@@ -16,6 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.utils.tp_shard import set_num_kv_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -271,10 +272,22 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
# 2. Set the tensor parallelism config
_autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group)

# 3. Set linear policies
# 3. Try to get num_key_heads from model_config.num_key_value_heads
num_kv_heads = None
kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads']
for name in kv_head_names:
if hasattr(model_config, name):
num_kv_heads = getattr(model_config, name)
if num_kv_heads != None:
break

# 5. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
set_num_kv_heads(num_kv_heads)

# 6. Set linear policies
_autotp.update_linear_policies()

# 4. Replace modules
# 7. Replace modules
return _autotp._replace_module(module)

def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None):
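The attribute lookup added above can be exercised with a hypothetical GQA-style config (FakeConfig is a stand-in, not a real transformers class); for grouped-query models the KV head count is found first, so sharding follows KV heads:

class FakeConfig:
    num_key_value_heads = 8
    num_attention_heads = 64

num_kv_heads = None
for name in ['num_key_value_heads', 'num_attention_heads', 'n_heads']:
    if hasattr(FakeConfig, name):
        num_kv_heads = getattr(FakeConfig, name)
    if num_kv_heads is not None:
        break
print(num_kv_heads)  # 8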
35 changes: 35 additions & 0 deletions deepspeed/utils/tp_shard.py
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed import comm as dist
global num_kv_heads
Review thread on this line:

dc3671 (Contributor):
Would adding a global variable here be a good practice? Maybe you could make it a class attribute by adding a util class like Loading, or combine these functions with AutoTP or ReplaceWithTensorSlicing?

delock (Collaborator, Author):
get_shard_size and get_shard_size_list are called from three different files. Wrapping num_kv_heads in a class attribute would need a place to hold the instance. Any suggestion on how to reference that class instance?

molly-smith (Contributor):
I agree with @dc3671 that it would be nice to have these functions in AutoTP, but that could create a circular dependency with fusedqkv_utils.py. I think the module_inject folder is more appropriate than the utils folder, though. Adding these functions to auto_tp_model_utils.py would also be okay.

delock (Collaborator, Author):
Thanks @molly-smith, I'll try to move this file into the module_inject folder.

delock (Collaborator, Author):
@molly-smith tp_shard.py has been moved into the module_inject folder.


def set_num_kv_heads(num):
global num_kv_heads
num_kv_heads = num


def get_shard_size(total_size, mp_size, rank=None):
global num_kv_heads
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
# In the case that total_size cannot be divided by num_kv_heads, only even sharding is possible
if num_kv_heads != None and (total_size % num_kv_heads) == 0:
if (rank == None):
rank = dist.get_rank()
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return (total_size // num_kv_heads) * my_slices
else:
if total_size % mp_size == 0:
return total_size // mp_size
else:
assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"


def get_shard_size_list(total_size, mp_size):
shard_sizes = []
for i in range(mp_size):
shard_sizes.append(get_shard_size(total_size, mp_size, i))
return shard_sizes
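A minimal usage sketch of these helpers (values are illustrative; passing rank explicitly avoids needing an initialized communication backend, and the import path matches this snapshot of the PR before tp_shard.py was moved under module_inject):

from deepspeed.utils.tp_shard import set_num_kv_heads, get_shard_size, get_shard_size_list

set_num_kv_heads(5)                    # e.g. a model with 5 KV heads
print(get_shard_size_list(320, 3))     # [128, 128, 64]: 2, 2 and 1 heads of width 64
print(get_shard_size(320, 3, rank=2))  # 64
# When num_kv_heads is None or total_size is not a multiple of it, the helpers
# fall back to even division and assert if total_size is not divisible by mp_size.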
33 changes: 33 additions & 0 deletions tests/unit/inference/test_inference.py
@@ -544,6 +544,39 @@ def test(
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)

@pytest.mark.world_size(3)
def test_odd_world_size(
self,
model_w_task,
query,
inf_kwargs,
assert_fn,
dtype,
):
invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False)
if invalid_test_msg:
pytest.skip(invalid_test_msg)

model, task = model_w_task
if model == "Salesforce/codegen-350M-mono":
pytest.skip("codegen does not supported by odd world_size")
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "3"))

# We have to load these large models on CPU with pipeline because not
# enough GPU memory
pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt")
bs_output = pipe(query, **inf_kwargs)

pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype)
# Switch device to GPU so that input tensors are not on CPU
pipe.device = torch.device(get_accelerator().device_name(local_rank))
ds_output = pipe(query, **inf_kwargs)

print(local_rank, "baseline", bs_output)
print(local_rank, "deepspeed", ds_output)
assert assert_fn(bs_output, ds_output)


@pytest.mark.nightly
@pytest.mark.parametrize(