
[AutoTP] Make AutoTP work when num_heads not divisible by number of workers #4011

Merged
55 commits merged on Oct 25, 2023
Changes from 7 commits

Commits (55)
0706acd
allow number of heads not divisible by number of ranks
delock Jul 20, 2023
0bf785f
get num_heads from model config, more robust
delock Jul 21, 2023
72b9e1a
simplify logic where num_head itself is sharded
delock Jul 21, 2023
5ed9a56
name tweaks
delock Jul 21, 2023
73f499d
make code more robust where num_attention_heads may not be defined in…
delock Jul 21, 2023
48322c7
Merge branch 'master' into gma/uneven_heads
delock Jul 21, 2023
f14e290
Merge branch 'master' into gma/uneven_heads
delock Jul 24, 2023
b62317c
Merge branch 'master' into gma/uneven_heads
loadams Jul 24, 2023
12c0628
support num_key_value_heads < num_attention_heads which is used by ll…
delock Jul 25, 2023
8f23d9b
add test for 5 ranks
delock Jul 25, 2023
9c53bd7
change odd rank # to 3 to avoid test skip
delock Jul 25, 2023
413224b
Merge branch 'master' into gma/uneven_heads
tjruwase Jul 25, 2023
78d6667
Merge branch 'master' into gma/uneven_heads
delock Aug 9, 2023
27fde30
add get_shard_size function
delock Aug 9, 2023
8e1fd27
modify sharding mechanism according to latest auto TP
delock Aug 10, 2023
9a6bc12
Merge branch 'master' into gma/uneven_heads
delock Aug 16, 2023
2dac94f
fix accuracy issue
delock Aug 17, 2023
885f6a3
Merge branch 'master' into gma/uneven_heads
delock Aug 17, 2023
7ffd811
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 18, 2023
40659ba
Merge branch 'master' into gma/uneven_heads
tjruwase Aug 22, 2023
71f9f40
fix format
delock Aug 21, 2023
db9db6b
skip tests with fusedqkv
delock Aug 23, 2023
72531c0
Merge branch 'master' into gma/uneven_heads
delock Aug 23, 2023
9d5eae3
remove skip of fusedqkv tests
delock Aug 23, 2023
25e656d
skip test fusedqkv with odd number of ranks
delock Aug 23, 2023
7f6d7f6
support model with n_heads in model_config
delock Aug 24, 2023
e3a5b77
Merge branch 'master' into gma/uneven_heads
molly-smith Aug 24, 2023
c9ec881
Merge branch 'master' into gma/uneven_heads
delock Aug 26, 2023
f5be257
fix TestInjectionPolicy::test[fp32-t5]
delock Aug 27, 2023
b671040
fix uneven_heads on some fusedqkv types (#12)
inkcherry Aug 28, 2023
d59ff22
better fix when activation size cannot be divided by number of heads
delock Aug 30, 2023
6c3c841
Merge branch 'master' into gma/uneven_heads_rebase
delock Aug 30, 2023
58e8b24
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 1, 2023
4c6b7fa
move tp_shard.py under module_inject
delock Sep 6, 2023
18e1c5d
Merge branch 'master' into gma/uneven_heads
delock Sep 6, 2023
8ef01e2
Add get_num_kv_heads in tp_shard.py
delock Sep 7, 2023
9a61fc2
Merge branch 'master' into gma/uneven_heads
delock Sep 11, 2023
74870db
Merge branch 'master' into gma/uneven_heads
delock Sep 13, 2023
115cc20
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 13, 2023
0781c41
Refine according to comments
delock Sep 14, 2023
194337f
remove old comment
mrwyattii Sep 14, 2023
47d84ca
Merge branch 'master' into gma/uneven_heads
delock Sep 18, 2023
369eb3e
Merge branch 'master' into gma/uneven_heads
mrwyattii Sep 19, 2023
567fb9a
fix bug in getting num_kv_heads
delock Sep 20, 2023
47c83ca
Merge branch 'master' into gma/uneven_heads
molly-smith Sep 20, 2023
d194ab0
Merge branch 'master' into gma/uneven_heads
tjruwase Sep 27, 2023
6db5ddd
Merge branch 'master' into gma/uneven_heads
delock Oct 7, 2023
698b62a
Merge branch 'up-master' into gma/uneven_heads
delock Oct 10, 2023
d75149f
support uneven sharding of lm_head tensor parallel
delock Oct 10, 2023
248532d
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
a9056fd
Merge branch 'master' into gma/uneven_heads
delock Oct 11, 2023
81bd29f
Merge branch 'master' into gma/uneven_heads
delock Oct 12, 2023
693a9fe
Merge branch 'master' into gma/uneven_heads
delock Oct 18, 2023
4c45a5b
Merge branch 'master' into gma/uneven_heads
delock Oct 19, 2023
a7513e1
Merge branch 'master' into gma/uneven_heads
delock Oct 24, 2023
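
Before the diff below, a minimal standalone sketch (not the PR's code; heads_per_rank, num_heads, and world_size are illustrative names) of the uneven head-distribution policy the commits above converge on: each rank gets num_heads // world_size heads, and the first num_heads % world_size ranks carry one extra head.

# Illustrative sketch of the uneven head distribution targeted by this PR.
# Not the PR's code; names and values here are examples only.
def heads_per_rank(num_heads: int, world_size: int) -> list:
    base, remainder = divmod(num_heads, world_size)
    # The first `remainder` ranks each carry one extra head.
    return [base + (1 if rank < remainder else 0) for rank in range(world_size)]

# Example: 25 heads on 4 workers -> [7, 6, 6, 6]; 5 heads on 3 workers -> [2, 2, 1].
assert heads_per_rank(25, 4) == [7, 6, 6, 6]
assert heads_per_rank(5, 3) == [2, 2, 1]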
35 changes: 22 additions & 13 deletions deepspeed/module_inject/replace_module.py
@@ -364,6 +364,17 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False,
 
         return _container.module
 
+    def get_shard_size(total_size, num_slices):
+        if hasattr(model_config, 'num_attention_heads'):
+            num_heads = model_config.num_attention_heads
+            my_slices = num_heads // num_slices + (1 if dist.get_rank() < (num_heads % num_slices) else 0)
+            return total_size // num_heads * my_slices
+        else:
+            if total_size % num_slices == 0:
+                return total_size // num_slices
+            else:
+                assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({num_slices})"
+
     def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         mp_size = config.tensor_parallel.tp_size
         mp_group = config.tensor_parallel.tp_group
@@ -374,12 +385,11 @@ def _replace(child, name, conv_linear_layer):
             mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
             weight_shape = child.weight.shape
             if name in all_reduce_linears:
-                new_weight = torch.empty((
-                    weight_shape[1] if conv_linear_layer else weight_shape[0],
-                    (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size,
-                ),
-                                         device=child.weight.device,
-                                         dtype=child.weight.dtype)
+                new_weight = torch.empty(
+                    (weight_shape[1] if conv_linear_layer else weight_shape[0],
+                     get_shard_size(weight_shape[0] if conv_linear_layer else weight_shape[1], mp_size)),
+                    device=child.weight.device,
+                    dtype=child.weight.dtype)
                 if conv_linear_layer:
                     child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                 data = mp_replace.copy(new_weight, child.weight.data)
@@ -391,16 +401,16 @@ def _replace(child, name, conv_linear_layer):
                             torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group)
             else:
                 new_weight = torch.empty((
-                    (weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size,
-                    weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1],
+                    get_shard_size(weight_shape[1] if conv_linear_layer else weight_shape[0], mp_size),
+                    get_shard_size(weight_shape[0], mp_size) if conv_linear_layer else weight_shape[1],
                 ),
                                          device=child.weight.device,
                                          dtype=child.weight.dtype)
                 if conv_linear_layer:
                     child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
                 data = mp_replace.copy(new_weight, child.weight.data)
 
-                new_bias = torch.empty((weight_shape[0] // mp_size),
+                new_bias = torch.empty(get_shard_size(weight_shape[0], mp_size),
                                        device=child.weight.device,
                                        dtype=child.weight.dtype)
                 bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to(
@@ -412,13 +422,13 @@ def _slice_embedding(child, name, conv_linear_layer):
             if getattr(child, "replaced", False) == True:
                 return
             mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
-            new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size),
+            new_weight = torch.empty((child.weight.shape[0], get_shard_size(child.weight.shape[1], mp_size)),
                                      device=child.weight.device,
                                      dtype=child.weight.dtype)
             data = mp_replace.copy(new_weight,
                                    child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \
                                    child.weight.data)
-            new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size)
+            new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], mp_size))
             new_embedding.weight.data.copy_(data)
             setattr(child, "replaced", True)
             return new_embedding
@@ -432,8 +442,7 @@ def update_mp_params(child):
             ]:
                 if hasattr(child, param):
                     param_val = getattr(child, param)
-                    assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})"
-                    setattr(child, param, param_val // mp_size)
+                    setattr(child, param, get_shard_size(param_val, mp_size))
             setattr(child, "replaced", True)
 
         conv_linear_layer = False
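
For illustration only (not part of the PR; hidden_size, num_heads, and world_size are made-up values, and torch.split is this sketch's choice rather than the PR's mechanism, which copies through ReplaceWithTensorSlicing): uneven per-rank sizes like those returned by get_shard_size can be used to cut a projection weight along its output dimension.

import torch

# Hypothetical example values; the PR reads these from the model config instead.
hidden_size, num_heads, world_size = 510, 5, 3
head_dim = hidden_size // num_heads  # assumes hidden_size is divisible by num_heads

base, remainder = divmod(num_heads, world_size)
# Per-rank shard sizes in elements: [2*102, 2*102, 1*102] -> [204, 204, 102]
shard_sizes = [(base + (1 if rank < remainder else 0)) * head_dim for rank in range(world_size)]

weight = torch.randn(hidden_size, hidden_size)  # e.g. an attention output projection
shards = torch.split(weight, shard_sizes, dim=0)  # torch.split accepts unequal section sizes
assert [s.shape[0] for s in shards] == shard_sizes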