From 0706acd78dadaed07278d4ca6d883939c157098f Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 20 Jul 2023 05:41:07 -0400 Subject: [PATCH 01/25] allow number of heads not divisible by number of ranks --- deepspeed/module_inject/replace_module.py | 38 +++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 4c488e3bf2cd..9c6e6336bc25 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -304,6 +304,9 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m seed = -1 local_rank = -1 + global num_heads + num_heads = -1 + mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group, mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1) @@ -364,6 +367,11 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module + def get_shard_size(total_size, num_slices): + num_units = num_heads + my_slices = num_units // num_slices + (1 if dist.get_rank() < (num_units % num_slices) else 0) + return total_size // num_units * my_slices + def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): mp_size = config.tensor_parallel.tp_size mp_group = config.tensor_parallel.tp_group @@ -374,12 +382,11 @@ def _replace(child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) weight_shape = child.weight.shape if name in all_reduce_linears: - new_weight = torch.empty(( - weight_shape[1] if conv_linear_layer else weight_shape[0], - (weight_shape[0] if conv_linear_layer else weight_shape[1]) // mp_size, - ), - device=child.weight.device, - dtype=child.weight.dtype) + new_weight = torch.empty( + (weight_shape[1] if conv_linear_layer else weight_shape[0], + get_shard_size(weight_shape[0] if conv_linear_layer else weight_shape[1], mp_size)), + device=child.weight.device, + dtype=child.weight.dtype) if conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, child.weight.data) @@ -391,8 +398,8 @@ def _replace(child, name, conv_linear_layer): torch.nn.parameter.Parameter(new_bias.to(get_accelerator().current_device_name())), mp_group) else: new_weight = torch.empty(( - (weight_shape[1] if conv_linear_layer else weight_shape[0]) // mp_size, - weight_shape[0] // mp_size if conv_linear_layer else weight_shape[1], + get_shard_size(weight_shape[1] if conv_linear_layer else weight_shape[0], mp_size), + get_shard_size(weight_shape[0], mp_size) if conv_linear_layer else weight_shape[1], ), device=child.weight.device, dtype=child.weight.dtype) @@ -400,7 +407,7 @@ def _replace(child, name, conv_linear_layer): child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = mp_replace.copy(new_weight, child.weight.data) - new_bias = torch.empty((weight_shape[0] // mp_size), + new_bias = torch.empty(get_shard_size(weight_shape[0], mp_size), device=child.weight.device, dtype=child.weight.dtype) bias_data = None if child.bias is None else mp_replace.copy(new_bias, child.bias.data).to( @@ -412,13 +419,13 @@ def _slice_embedding(child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: return mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group) - new_weight = torch.empty((child.weight.shape[0], child.weight.shape[1] // mp_size), + new_weight = torch.empty((child.weight.shape[0], 
get_shard_size(child.weight.shape[1], mp_size)), device=child.weight.device, dtype=child.weight.dtype) data = mp_replace.copy(new_weight, child.weight.ds_tensor.data if hasattr(child.weight, 'ds_tensor') else \ child.weight.data) - new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // mp_size) + new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], mp_size)) new_embedding.weight.data.copy_(data) setattr(child, "replaced", True) return new_embedding @@ -432,8 +439,13 @@ def update_mp_params(child): ]: if hasattr(child, param): param_val = getattr(child, param) - assert param_val % mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({mp_size})" - setattr(child, param, param_val // mp_size) + if param in ["n_heads", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads"]: + global num_heads + num_heads = param_val + setattr(child, param, param_val // mp_size + (1 if dist.get_rank() < + (param_val % mp_size) else 0)) + else: + setattr(child, param, get_shard_size(param_val, mp_size)) setattr(child, "replaced", True) conv_linear_layer = False From 0bf785f893e1ef9c3acf140553cff93eb861b4d0 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 21 Jul 2023 01:18:32 -0400 Subject: [PATCH 02/25] get num_heads from model config, more robust --- deepspeed/module_inject/replace_module.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 9c6e6336bc25..9843e30fb113 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -304,9 +304,6 @@ def replace_transformer_layer(orig_layer_impl, model, checkpoint_dict, config, m seed = -1 local_rank = -1 - global num_heads - num_heads = -1 - mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group, mp_size=config.tensor_parallel.tp_size) #, out_dim=0, in_dim=1) @@ -368,7 +365,7 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module def get_shard_size(total_size, num_slices): - num_units = num_heads + num_units = model_config.num_attention_heads my_slices = num_units // num_slices + (1 if dist.get_rank() < (num_units % num_slices) else 0) return total_size // num_units * my_slices @@ -440,8 +437,6 @@ def update_mp_params(child): if hasattr(child, param): param_val = getattr(child, param) if param in ["n_heads", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads"]: - global num_heads - num_heads = param_val setattr(child, param, param_val // mp_size + (1 if dist.get_rank() < (param_val % mp_size) else 0)) else: From 72b9e1addb8b932e4862fdcbe948e188b0e4aaa5 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 21 Jul 2023 01:23:07 -0400 Subject: [PATCH 03/25] simplify logic where num_head itself is sharded --- deepspeed/module_inject/replace_module.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 9843e30fb113..cb4365cd77b1 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -436,11 +436,7 @@ def update_mp_params(child): ]: if hasattr(child, param): param_val = getattr(child, param) - if param in ["n_heads", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads"]: - setattr(child, param, param_val // mp_size + (1 if dist.get_rank() < - (param_val % mp_size) else 0)) - 
else: - setattr(child, param, get_shard_size(param_val, mp_size)) + setattr(child, param, get_shard_size(param_val, mp_size)) setattr(child, "replaced", True) conv_linear_layer = False From 5ed9a5662c1464e28933b2267c716dd353c7b8a2 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 21 Jul 2023 01:32:15 -0400 Subject: [PATCH 04/25] name tweaks --- deepspeed/module_inject/replace_module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index cb4365cd77b1..238cffd79711 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -365,9 +365,9 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module def get_shard_size(total_size, num_slices): - num_units = model_config.num_attention_heads - my_slices = num_units // num_slices + (1 if dist.get_rank() < (num_units % num_slices) else 0) - return total_size // num_units * my_slices + num_heads = model_config.num_attention_heads + my_slices = num_heads // num_slices + (1 if dist.get_rank() < (num_heads % num_slices) else 0) + return total_size // num_heads * my_slices def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): mp_size = config.tensor_parallel.tp_size From 73f499d9226a10f6304b9d8cbd425e4d7aac0912 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Fri, 21 Jul 2023 01:39:09 -0400 Subject: [PATCH 05/25] make code more robust where num_attention_heads may not be defined in model_config --- deepspeed/module_inject/replace_module.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 238cffd79711..0b937bfe1928 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -365,9 +365,15 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module def get_shard_size(total_size, num_slices): - num_heads = model_config.num_attention_heads - my_slices = num_heads // num_slices + (1 if dist.get_rank() < (num_heads % num_slices) else 0) - return total_size // num_heads * my_slices + if hasattr(model_config, 'num_attention_heads'): + num_heads = model_config.num_attention_heads + my_slices = num_heads // num_slices + (1 if dist.get_rank() < (num_heads % num_slices) else 0) + return total_size // num_heads * my_slices + else: + if total_size % num_slices == 0: + return total_size // num_slices + else: + assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({num_slices})" def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): mp_size = config.tensor_parallel.tp_size From 12c0628631da848ee61ec75ad61c62431513db03 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Mon, 24 Jul 2023 22:39:28 -0400 Subject: [PATCH 06/25] support num_key_value_heads < num_attention_heads which is used by llama2 --- deepspeed/module_inject/replace_module.py | 26 ++++++++++++++++------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 0b937bfe1928..7bcd059496b3 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -364,16 +364,26 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module 
- def get_shard_size(total_size, num_slices): - if hasattr(model_config, 'num_attention_heads'): - num_heads = model_config.num_attention_heads - my_slices = num_heads // num_slices + (1 if dist.get_rank() < (num_heads % num_slices) else 0) - return total_size // num_heads * my_slices + def get_shard_size(total_size, mp_size): + num_kv_heads = None + + # 1. Try to get num_key_heads from model_config.num_key_value_heads + if hasattr(model_config, 'num_key_value_heads'): + num_kv_heads = model_config.num_key_value_heads + + # 2. Fallback to model_config.num_attention_heads when necessary + if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): + num_kv_heads = model_config.num_attention_heads + + # 3. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + if num_kv_heads != None: + my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) + return total_size // num_kv_heads * my_slices else: - if total_size % num_slices == 0: - return total_size // num_slices + if total_size % mp_size == 0: + return total_size // mp_size else: - assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({num_slices})" + assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): mp_size = config.tensor_parallel.tp_size From 8f23d9bfebb59dd7046edfc8a4a978124f9276fc Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Mon, 24 Jul 2023 23:43:04 -0400 Subject: [PATCH 07/25] add test for 5 ranks --- tests/unit/inference/test_inference.py | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 4209bfa02ab4..50bf8bc661fa 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -519,6 +519,37 @@ def test( print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) + @pytest.mark.world_size(5) + def test_odd_world_size( + self, + model_w_task, + query, + inf_kwargs, + assert_fn, + dtype, + ): + invalid_test_msg = validate_test(model_w_task, dtype, enable_cuda_graph=False, enable_triton=False) + if invalid_test_msg: + pytest.skip(invalid_test_msg) + + model, task = model_w_task + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "5")) + + # We have to load these large models on CPU with pipeline because not + # enough GPU memory + pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt") + bs_output = pipe(query, **inf_kwargs) + + pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) + # Switch device to GPU so that input tensors are not on CPU + pipe.device = torch.device(get_accelerator().device_name(local_rank)) + ds_output = pipe(query, **inf_kwargs) + + print(local_rank, "baseline", bs_output) + print(local_rank, "deepspeed", ds_output) + assert assert_fn(bs_output, ds_output) + @pytest.mark.nightly @pytest.mark.parametrize( From 9c53bd74e314c0361c77a12f82a5a256892a4a8d Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 25 Jul 2023 00:54:04 -0400 Subject: [PATCH 08/25] change odd rank # to 3 to avoid test skip --- tests/unit/inference/test_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 
50bf8bc661fa..28a458422e4f 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -519,7 +519,7 @@ def test( print(local_rank, "deepspeed", ds_output) assert assert_fn(bs_output, ds_output) - @pytest.mark.world_size(5) + @pytest.mark.world_size(3) def test_odd_world_size( self, model_w_task, @@ -534,7 +534,7 @@ def test_odd_world_size( model, task = model_w_task local_rank = int(os.getenv("LOCAL_RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "5")) + world_size = int(os.getenv("WORLD_SIZE", "3")) # We have to load these large models on CPU with pipeline because not # enough GPU memory From 27fde308692e267b1e21211d495ad6b68c04533d Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 9 Aug 2023 10:17:01 -0400 Subject: [PATCH 09/25] add get_shard_size function --- deepspeed/module_inject/auto_tp.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 8c1f78d0cfec..c6768d605c7a 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -402,6 +402,27 @@ def update_linear_policies(self): else: self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding} + def get_shard_size(total_size, mp_size): + num_kv_heads = None + + # 1. Try to get num_key_heads from model_config.num_key_value_heads + if hasattr(model_config, 'num_key_value_heads'): + num_kv_heads = model_config.num_key_value_heads + + # 2. Fallback to model_config.num_attention_heads when necessary + if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): + num_kv_heads = model_config.num_attention_heads + + # 3. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + if num_kv_heads != None: + my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) + return total_size // num_kv_heads * my_slices + else: + if total_size % mp_size == 0: + return total_size // mp_size + else: + assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" + def _replace_module(self, r_module, prev_name='', prev_class_name=''): for name, child in r_module.named_children(): if prev_class_name == "": From 8e1fd277b026e0aa17bbcc9196ffdcd571425130 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 10 Aug 2023 02:56:07 -0400 Subject: [PATCH 10/25] modify sharding mechanism according to latest auto TP --- deepspeed/module_inject/auto_tp.py | 40 ++++++----------------- deepspeed/module_inject/fusedqkv_utils.py | 9 ++--- deepspeed/module_inject/replace_module.py | 38 +++++++++------------ deepspeed/utils/tp_shard.py | 18 ++++++++++ 4 files changed, 48 insertions(+), 57 deletions(-) create mode 100644 deepspeed/utils/tp_shard.py diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index c6768d605c7a..039b9892a6bd 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -14,6 +14,7 @@ from .layers import LinearAllreduce, LinearLayer from deepspeed.accelerator import get_accelerator from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from deepspeed.utils.tp_shard import get_shard_size class ReplaceWithTensorSlicing: @@ -300,7 +301,7 @@ def _replace(self, child, name, conv_linear_layer): # MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size] # instead of 
[num_head,3,head_dim,hidden_size] new_weight = torch.empty(( - weight_shape[0] // self.mp_size, + get_shard_size(weight_shape[0], self.mp_size), weight_shape[1], ), device=child.weight.device, @@ -319,7 +320,7 @@ def _replace(self, child, name, conv_linear_layer): if self.conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = child.weight.data.split( - (weight_shape[0] if self.conv_linear_layer else weight_shape[1]) // self.mp_size, dim=1) + get_shard_size(weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) setattr(child, "replaced", True) @@ -342,13 +343,13 @@ def _replace(self, child, name, conv_linear_layer): module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to( get_accelerator().current_device_name()) else: - data = child.weight.data.split((weight_shape[0]) // self.mp_size, + data = child.weight.data.split(get_shard_size(weight_shape[0], self.mp_size), dim=1 if self.conv_linear_layer else 0) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) if child.bias is not None: bias_data = child.bias.data.split( - (weight_shape[1] if self.conv_linear_layer else weight_shape[0]) // self.mp_size, dim=0) + get_shard_size(weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), dim=0) bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) else: bias_data = None @@ -362,12 +363,12 @@ def _slice_embedding(self, child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) if hasattr(child.weight, 'ds_tensor'): - data = child.weight.ds_tensor.data.split(child.weight.shape[1] // self.mp_size, dim=1) + data = child.weight.ds_tensor.data.split(get_shard_size(child.weight.shape[1], self.mp_size), dim=1) else: - data = child.weight.data.split(child.weight.shape[1] // self.mp_size, dim=1) + data = child.weight.data.split(get_shard_size(child.weight.shape[1], self.mp_size), dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) - new_embedding = nn.Embedding(child.weight.shape[0], child.weight.shape[1] // self.mp_size) + new_embedding = nn.Embedding(child.weight.shape[0], get_shard_size(child.weight.shape[1], self.mp_size)) new_embedding.weight.data.copy_(data) setattr(child, "replaced", True) return new_embedding @@ -381,8 +382,8 @@ def update_mp_params(self, child): ]: if hasattr(child, param): param_val = getattr(child, param) - assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})" - setattr(child, param, param_val // self.mp_size) + #assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})" + setattr(child, param, get_shard_size(param_val, self.mp_size)) setattr(child, "replaced", True) def update_linear_policies(self): @@ -402,27 +403,6 @@ def update_linear_policies(self): else: self.linear_policies = {nn.Linear: self._replace, nn.Embedding: self._slice_embedding} - def get_shard_size(total_size, mp_size): - num_kv_heads = None - - # 1. Try to get num_key_heads from model_config.num_key_value_heads - if hasattr(model_config, 'num_key_value_heads'): - num_kv_heads = model_config.num_key_value_heads - - # 2. 
Fallback to model_config.num_attention_heads when necessary - if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): - num_kv_heads = model_config.num_attention_heads - - # 3. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division - if num_kv_heads != None: - my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) - return total_size // num_kv_heads * my_slices - else: - if total_size % mp_size == 0: - return total_size // mp_size - else: - assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" - def _replace_module(self, r_module, prev_name='', prev_class_name=''): for name, child in r_module.named_children(): if prev_class_name == "": diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 30a5bb75db23..d7850cb85aa3 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once +from deepspeed.utils.tp_shard import get_shard_size import re @@ -41,14 +42,14 @@ def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape - dst_shape = shape[0] // mp_size + dst_shape = get_shard_size(shape[0], mp_size) num_mp_blocks = input.reshape(codegen_mp_num, shape[0] // codegen_mp_num, shape[1]) #num_mp_blocks : [codegen_mp_num, 3*hidden_dim/codegen_mp_num, :] src_split = list(torch.split(num_mp_blocks, num_mp_blocks.shape[1] // 3, dim=1)) src_split = [x.reshape(codegen_mp_num * mp_size, -1, shape[1]) for x in src_split] - split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size, 0, 1) + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size), 0, 1) tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0).reshape(shape[0], -1) return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] @@ -57,10 +58,10 @@ def _glm_type_transpose(input, mp_size): #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape - dst_shape = shape[0] // mp_size + dst_shape = get_shard_size(shape[0], mp_size) src_split = torch.split(input, shape[0] // 3, dim=0) - split_fusedqkv = split_by_qkvlist_and_refuse(src_split, shape[0] // 3 // mp_size) + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size)) tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0) return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index d93146857a66..4010f164420d 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,6 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist +from deepspeed.utils.tp_shard import set_num_kv_heads from torch import nn from .load_checkpoint import load_model_with_checkpoint @@ -263,27 +264,6 @@ def replace_with_policy(child, policy_cls, triangular_masking, inference=False, return _container.module - def get_shard_size(total_size, mp_size): - num_kv_heads = None - - # 1. Try to get num_key_heads from model_config.num_key_value_heads - if hasattr(model_config, 'num_key_value_heads'): - num_kv_heads = model_config.num_key_value_heads - - # 2. 
Fallback to model_config.num_attention_heads when necessary - if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): - num_kv_heads = model_config.num_attention_heads - - # 3. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division - if num_kv_heads != None: - my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) - return total_size // num_kv_heads * my_slices - else: - if total_size % mp_size == 0: - return total_size // mp_size - else: - assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" - def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): #mp_replace = ReplaceWithTensorSlicing(mp_group=config.tensor_parallel.tp_group) @@ -293,10 +273,22 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 2. Set the tensor parallelism config _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) - # 3. Set linear policies + # 3. Try to get num_key_heads from model_config.num_key_value_heads + num_kv_heads = None + if hasattr(model_config, 'num_key_value_heads'): + num_kv_heads = model_config.num_key_value_heads + + # 4. Fallback to model_config.num_attention_heads when necessary + if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): + num_kv_heads = model_config.num_attention_heads + + # 5. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + set_num_kv_heads(num_kv_heads) + + # 6. Set linear policies _autotp.update_linear_policies() - # 4. Replace modules + # 7. Replace modules return _autotp._replace_module(module) def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/utils/tp_shard.py new file mode 100644 index 000000000000..42d6f50f7e31 --- /dev/null +++ b/deepspeed/utils/tp_shard.py @@ -0,0 +1,18 @@ +from deepspeed import comm as dist +global num_kv_heads + +def set_num_kv_heads(num): + global num_kv_heads + num_kv_heads = num + +def get_shard_size(total_size, mp_size): + global num_kv_heads + # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + if num_kv_heads != None: + my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) + return total_size // num_kv_heads * my_slices + else: + if total_size % mp_size == 0: + return total_size // mp_size + else: + assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" From 2dac94f35a8630973d58988a9f94623cdb5f5052 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 17 Aug 2023 10:07:27 +0000 Subject: [PATCH 11/25] fix accuracy issue --- deepspeed/module_inject/auto_tp.py | 12 ++++++------ deepspeed/module_inject/auto_tp_model_utils.py | 9 +++++---- deepspeed/utils/tp_shard.py | 14 +++++++++++--- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index c458fc0ac86f..17a9f7aebec4 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -14,7 +14,7 @@ from .layers import LinearAllreduce, LinearLayer from deepspeed.accelerator import get_accelerator from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw -from deepspeed.utils.tp_shard import get_shard_size +from deepspeed.utils.tp_shard import get_shard_size, 
get_shard_size_list class ReplaceWithTensorSlicing: @@ -320,7 +320,7 @@ def _replace(self, child, name, conv_linear_layer): if self.conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() data = child.weight.data.split( - get_shard_size(weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), dim=1) + get_shard_size_list(weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) setattr(child, "replaced", True) @@ -343,13 +343,13 @@ def _replace(self, child, name, conv_linear_layer): module_str, child.bias.data, self.mp_size, mp_replace.gpu_index).to( get_accelerator().current_device_name()) else: - data = child.weight.data.split(get_shard_size(weight_shape[0], self.mp_size), + data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size), dim=1 if self.conv_linear_layer else 0) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) if child.bias is not None: bias_data = child.bias.data.split( - get_shard_size(weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), dim=0) + get_shard_size_list(weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), dim=0) bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) bias_data = torch.nn.parameter.Parameter(bias_data, requires_grad=False) else: @@ -365,9 +365,9 @@ def _slice_embedding(self, child, name, conv_linear_layer): mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group) if hasattr(child.weight, 'ds_tensor'): - data = child.weight.ds_tensor.data.split(get_shard_size(child.weight.shape[1], self.mp_size), dim=1) + data = child.weight.ds_tensor.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1) else: - data = child.weight.data.split(get_shard_size(child.weight.shape[1], self.mp_size), dim=1) + data = child.weight.data.split(get_shard_size_list(child.weight.shape[1], self.mp_size), dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) data = torch.nn.parameter.Parameter(data, requires_grad=False) diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py index d31dfd17a2a9..445619dcd37a 100644 --- a/deepspeed/module_inject/auto_tp_model_utils.py +++ b/deepspeed/module_inject/auto_tp_model_utils.py @@ -6,6 +6,7 @@ from deepspeed import comm as dist import torch from typing import Optional +from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: @@ -51,8 +52,8 @@ def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] alibi = slopes[..., None] * arange_tensor if dist.is_initialized(): - num_heads_per_rank = int(num_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank + num_heads_per_rank = get_shard_size(num_heads, dist.get_world_size()) + offset = sum(get_shard_size_list(num_heads, dist.get_world_size())[0:dist.get_rank()]) alibi = alibi.view(batch_size, num_heads, 1, seq_length) alibi = alibi[:, offset:num_heads_per_rank + offset, :, :] return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype) @@ -72,7 +73,7 @@ def build_mpt_atten_bias_tensor(self, prefix_mask=prefix_mask, 
sequence_id=sequence_id) if dist.is_initialized(): - num_heads_per_rank = int(self.config.n_heads / dist.get_world_size()) - offset = dist.get_rank() * num_heads_per_rank + num_heads_per_rank = get_shard_size(self.config.n_heads, dist.get_world_size()) + offset = sum(get_shard_size_list(self.config.n_heads, dist.get_world_size())[0:dist.get_rank()]) attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :] return attn_bias, attention_mask diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/utils/tp_shard.py index 42d6f50f7e31..67bb6c1f7872 100644 --- a/deepspeed/utils/tp_shard.py +++ b/deepspeed/utils/tp_shard.py @@ -5,14 +5,22 @@ def set_num_kv_heads(num): global num_kv_heads num_kv_heads = num -def get_shard_size(total_size, mp_size): +def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division if num_kv_heads != None: - my_slices = num_kv_heads // mp_size + (1 if dist.get_rank() < (num_kv_heads % mp_size) else 0) - return total_size // num_kv_heads * my_slices + if (rank == None): + rank = dist.get_rank() + my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) + return (total_size // num_kv_heads) * my_slices else: if total_size % mp_size == 0: return total_size // mp_size else: assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" + +def get_shard_size_list(total_size, mp_size): + shard_sizes = [] + for i in range(mp_size): + shard_sizes.append(get_shard_size(total_size, mp_size, i)) + return shard_sizes From 71f9f401b2302d8ba75c2f90a6f307041955bfee Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Mon, 21 Aug 2023 01:38:06 +0000 Subject: [PATCH 12/25] fix format --- deepspeed/module_inject/auto_tp.py | 10 ++++++---- deepspeed/utils/tp_shard.py | 8 ++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 17a9f7aebec4..4198f73b42dd 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -319,8 +319,9 @@ def _replace(self, child, name, conv_linear_layer): if self.conv_linear_layer: child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - data = child.weight.data.split( - get_shard_size_list(weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), dim=1) + data = child.weight.data.split(get_shard_size_list( + weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size), + dim=1) data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) setattr(child, "replaced", True) @@ -348,8 +349,9 @@ def _replace(self, child, name, conv_linear_layer): data = data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) if child.bias is not None: - bias_data = child.bias.data.split( - get_shard_size_list(weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), dim=0) + bias_data = child.bias.data.split(get_shard_size_list( + weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size), + dim=0) bias_data = bias_data[mp_replace.gpu_index].to(get_accelerator().current_device_name()) bias_data = torch.nn.parameter.Parameter(bias_data, requires_grad=False) else: diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/utils/tp_shard.py index 67bb6c1f7872..61a017184e42 100644 --- a/deepspeed/utils/tp_shard.py +++ b/deepspeed/utils/tp_shard.py @@ -1,10 +1,17 @@ +# Copyright 
(c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + from deepspeed import comm as dist global num_kv_heads + def set_num_kv_heads(num): global num_kv_heads num_kv_heads = num + def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division @@ -19,6 +26,7 @@ def get_shard_size(total_size, mp_size, rank=None): else: assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})" + def get_shard_size_list(total_size, mp_size): shard_sizes = [] for i in range(mp_size): From db9db6bfe29da97daa7c95784558a3cdf4503f2e Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 23 Aug 2023 06:22:12 +0000 Subject: [PATCH 13/25] skip tests with fusedqkv --- tests/unit/inference/test_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index fdd53425a947..9d3876265dd2 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -529,6 +529,8 @@ def test_odd_world_size( pytest.skip(invalid_test_msg) model, task = model_w_task + if model == "Salesforce/codegen-350M-mono": + pytest.skip("fusedqkv does not supported by odd world_size") local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "3")) From 9d5eae34b0361f9ef0acfca4e9823acd5a6f8c4d Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 23 Aug 2023 06:26:43 +0000 Subject: [PATCH 14/25] remove skip of fusedqkv tests --- tests/unit/inference/test_inference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 6639654b9ecc..896f5cc13db7 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -558,8 +558,6 @@ def test_odd_world_size( pytest.skip(invalid_test_msg) model, task = model_w_task - if model == "Salesforce/codegen-350M-mono": - pytest.skip("fusedqkv does not supported by odd world_size") local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "3")) From 25e656dbd9666746be3daa237edb9a044b24b863 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 23 Aug 2023 06:32:15 +0000 Subject: [PATCH 15/25] skip test fusedqkv with odd number of ranks --- tests/unit/inference/test_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 896f5cc13db7..6639654b9ecc 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -558,6 +558,8 @@ def test_odd_world_size( pytest.skip(invalid_test_msg) model, task = model_w_task + if model == "Salesforce/codegen-350M-mono": + pytest.skip("fusedqkv does not supported by odd world_size") local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "3")) From 7f6d7f6dac9d998e1a626dbd990a21b78ce9749e Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 24 Aug 2023 09:19:08 +0000 Subject: [PATCH 16/25] support model with n_heads in model_config --- deepspeed/module_inject/replace_module.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 4010f164420d..b77863e751f4 100644 --- a/deepspeed/module_inject/replace_module.py +++ 
b/deepspeed/module_inject/replace_module.py @@ -275,12 +275,12 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 3. Try to get num_key_heads from model_config.num_key_value_heads num_kv_heads = None - if hasattr(model_config, 'num_key_value_heads'): - num_kv_heads = model_config.num_key_value_heads - - # 4. Fallback to model_config.num_attention_heads when necessary - if num_kv_heads == None and hasattr(model_config, 'num_attention_heads'): - num_kv_heads = model_config.num_attention_heads + kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads'] + for name in kv_head_names: + if hasattr(model_config, name): + num_kv_heads = getattr(model_config, name) + if num_kv_heads != None: + break # 5. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division set_num_kv_heads(num_kv_heads) From f5be25713eeb3190ffb434e147e4c2503a1256f0 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Sun, 27 Aug 2023 10:02:48 +0000 Subject: [PATCH 17/25] fix TestInjectionPolicy::test[fp32-t5] --- deepspeed/utils/tp_shard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/utils/tp_shard.py index 61a017184e42..19e17db97dcc 100644 --- a/deepspeed/utils/tp_shard.py +++ b/deepspeed/utils/tp_shard.py @@ -15,7 +15,8 @@ def set_num_kv_heads(num): def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division - if num_kv_heads != None: + # In the case that total_size cannot be divided by num_kv_heads, only even sharding is possible + if num_kv_heads != None and (total_size % num_kv_heads) == 0: if (rank == None): rank = dist.get_rank() my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) From b6710401e42bb4c8151403d7434007e13c8f0c62 Mon Sep 17 00:00:00 2001 From: mzl Date: Mon, 28 Aug 2023 14:50:10 +0800 Subject: [PATCH 18/25] fix uneven_heads on some fusedqkv types (#12) * odd support fusedqkv * fix format and clear text --- deepspeed/module_inject/fusedqkv_utils.py | 20 +++++++++++--------- tests/unit/inference/test_inference.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 34740e78f54f..c8c8e0ab7a17 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,8 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.utils.tp_shard import get_shard_size +from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list +import deepspeed.utils.tp_shard as tp_shard import re @@ -40,7 +41,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py - #TODO: assert num_heads % (mp_size*codegen_mp_num) == 0 + assert tp_shard.num_kv_heads % ( + mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0" #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape @@ -60,16 +62,16 @@ def _glm_type_transpose(input, mp_size): #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) shape = input.shape - dst_shape = get_shard_size(shape[0], mp_size) 
src_split = torch.split(input, shape[0] // 3, dim=0) - split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size(shape[0] // 3, mp_size)) - tp_fuseqkv_weight = torch.cat(split_fusedqkv, dim=0) - - return tp_fuseqkv_weight[gpu_index * dst_shape:(gpu_index + 1) * dst_shape] + split_fusedqkv = split_by_qkvlist_and_refuse(src_split, get_shard_size_list(shape[0] // 3, mp_size)) + return split_fusedqkv[gpu_index] def _bloom_type_transpose(input, mp_size): - return input + shape = input.shape + + split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0) + return split_fusedqkv[gpu_index] def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None): @@ -92,4 +94,4 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None): return _transpose_fused_qkvw(src, mp_size, fused_type) warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type," f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") - return src + return _bloom_type_transpose(src, mp_size) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 6639654b9ecc..7e2b3465e100 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -559,7 +559,7 @@ def test_odd_world_size( model, task = model_w_task if model == "Salesforce/codegen-350M-mono": - pytest.skip("fusedqkv does not supported by odd world_size") + pytest.skip("codegen does not supported by odd world_size") local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "3")) From d59ff222ea12a97a282f5d22b24b6392b245ff2e Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 30 Aug 2023 04:06:11 +0000 Subject: [PATCH 19/25] better fix when activation size cannot be divided by number of heads --- deepspeed/utils/tp_shard.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/utils/tp_shard.py index 19e17db97dcc..a3352ada3bcb 100644 --- a/deepspeed/utils/tp_shard.py +++ b/deepspeed/utils/tp_shard.py @@ -15,12 +15,11 @@ def set_num_kv_heads(num): def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division - # In the case that total_size cannot be divided by num_kv_heads, only even sharding is possible - if num_kv_heads != None and (total_size % num_kv_heads) == 0: + if num_kv_heads != None: if (rank == None): rank = dist.get_rank() my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0) - return (total_size // num_kv_heads) * my_slices + return total_size * my_slices // num_kv_heads else: if total_size % mp_size == 0: return total_size // mp_size From 4c6b7fa28307484a8da32262cbd04f787f969dce Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 6 Sep 2023 09:06:34 +0000 Subject: [PATCH 20/25] move tp_shard.py under module_inject --- deepspeed/module_inject/auto_tp.py | 2 +- deepspeed/module_inject/auto_tp_model_utils.py | 2 +- deepspeed/module_inject/fusedqkv_utils.py | 5 ++--- deepspeed/module_inject/replace_module.py | 2 +- deepspeed/{utils => module_inject}/tp_shard.py | 0 5 files changed, 5 insertions(+), 6 deletions(-) rename deepspeed/{utils => module_inject}/tp_shard.py (100%) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 2a4a8348258a..b01092f43006 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -14,7 +14,7 @@ 
from .layers import LinearAllreduce, LinearLayer from deepspeed.accelerator import get_accelerator from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw -from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list class ReplaceWithTensorSlicing: diff --git a/deepspeed/module_inject/auto_tp_model_utils.py b/deepspeed/module_inject/auto_tp_model_utils.py index 847494a30cd8..51e52e3258dd 100644 --- a/deepspeed/module_inject/auto_tp_model_utils.py +++ b/deepspeed/module_inject/auto_tp_model_utils.py @@ -6,7 +6,7 @@ from deepspeed import comm as dist import torch from typing import Optional -from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index c8c8e0ab7a17..4001db79204a 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,8 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.utils.tp_shard import get_shard_size, get_shard_size_list -import deepspeed.utils.tp_shard as tp_shard +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, num_kv_heads import re @@ -41,7 +40,7 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py - assert tp_shard.num_kv_heads % ( + assert num_kv_heads % ( mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0" #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 11472432dd53..7a3820f773ed 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,7 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist -from deepspeed.utils.tp_shard import set_num_kv_heads +from deepspeed.module_inject.tp_shard import set_num_kv_heads from .load_checkpoint import load_model_with_checkpoint import time diff --git a/deepspeed/utils/tp_shard.py b/deepspeed/module_inject/tp_shard.py similarity index 100% rename from deepspeed/utils/tp_shard.py rename to deepspeed/module_inject/tp_shard.py From 8ef01e28c112093be5de29af71e5fe17c20dcbb2 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 7 Sep 2023 07:22:15 +0000 Subject: [PATCH 21/25] Add get_num_kv_heads in tp_shard.py --- deepspeed/module_inject/fusedqkv_utils.py | 4 ++-- deepspeed/module_inject/tp_shard.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index 4001db79204a..2e8f6b5917ed 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, num_kv_heads +from 
deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads import re @@ -40,7 +40,7 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index): def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): # codegen_mp_num defined in https://github.com/huggingface/transformers/blob/main/src/transformers/models/codegen/modeling_codegen.py - assert num_kv_heads % ( + assert get_num_kv_heads() % ( mp_size * codegen_mp_num) == 0, "codgen autoTP requires num_kv_heads % (mp_size*codegen_mp_num) == 0" #input : [3*hidden_dim, hidden_dim](weight) or [3*hidden_dim](bias) diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index a3352ada3bcb..8e2fa78d883f 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -12,6 +12,11 @@ def set_num_kv_heads(num): num_kv_heads = num +def get_num_kv_heads(): + global num_kv_heads + return num_kv_heads + + def get_shard_size(total_size, mp_size, rank=None): global num_kv_heads # When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division From 0781c41ea4e0d111f3bfe8f19b0ee49304f2d036 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Thu, 14 Sep 2023 03:29:55 +0000 Subject: [PATCH 22/25] Refine according to comments --- deepspeed/module_inject/auto_tp.py | 10 +++++++++- deepspeed/module_inject/replace_module.py | 14 ++++---------- tests/unit/inference/test_inference.py | 7 ++++--- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index b01092f43006..4da40420c269 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -377,7 +377,6 @@ def update_mp_params(self, child): ]: if hasattr(child, param): param_val = getattr(child, param) - #assert param_val % self.mp_size == 0, f"{param} ({param_val}) must be divisible by mp_size ({self.mp_size})" setattr(child, param, get_shard_size(param_val, self.mp_size)) setattr(child, "replaced", True) @@ -432,3 +431,12 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''): self.update_mp_params(child) self._replace_module(child, name, class_name) return r_module + + def get_model_num_kv_heads(self, config): + num_kv_heads = None + kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads'] + for name in kv_head_names: + if hasattr(config, name): + num_kv_heads = getattr(config, name) + if num_kv_heads != None: + break diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 15e24cbb738c..e5a81607f3c3 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -273,21 +273,15 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): _autotp.set_tensor_parallel_config(config.tensor_parallel.tp_size, config.tensor_parallel.tp_group) # 3. Try to get num_key_heads from model_config.num_key_value_heads - num_kv_heads = None - kv_head_names = ['num_key_value_heads', 'num_attention_heads', 'n_heads'] - for name in kv_head_names: - if hasattr(model_config, name): - num_kv_heads = getattr(model_config, name) - if num_kv_heads != None: - break + num_kv_heads = _autotp.get_model_num_kv_heads(model_config) - # 5. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division + # 4. 
When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division set_num_kv_heads(num_kv_heads) - # 6. Set linear policies + # 5. Set linear policies _autotp.update_linear_policies() - # 7. Replace modules + # 6. Replace modules return _autotp._replace_module(module) def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 7e2b3465e100..6bfcb538a000 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -565,12 +565,13 @@ def test_odd_world_size( # We have to load these large models on CPU with pipeline because not # enough GPU memory - pipe = pipeline(task, model=model, device=torch.device("cpu"), framework="pt") + pipe = pipeline(task, + model=model, + device=torch.device(get_accelerator().device_name(local_rank)), + framework="pt") bs_output = pipe(query, **inf_kwargs) pipe.model = deepspeed.init_inference(pipe.model, mp_size=world_size, dtype=dtype) - # Switch device to GPU so that input tensors are not on CPU - pipe.device = torch.device(get_accelerator().device_name(local_rank)) ds_output = pipe(query, **inf_kwargs) print(local_rank, "baseline", bs_output) From 194337fe5c154368e925778c90528e74a1629fff Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Thu, 14 Sep 2023 10:30:15 -0700 Subject: [PATCH 23/25] remove old comment --- tests/unit/inference/test_inference.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index 6bfcb538a000..fbd25ef0f660 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -563,8 +563,6 @@ def test_odd_world_size( local_rank = int(os.getenv("LOCAL_RANK", "0")) world_size = int(os.getenv("WORLD_SIZE", "3")) - # We have to load these large models on CPU with pipeline because not - # enough GPU memory pipe = pipeline(task, model=model, device=torch.device(get_accelerator().device_name(local_rank)), From 567fb9ae11831e35feab6c47a3d9f1f8f97b5f44 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 20 Sep 2023 10:12:24 +0000 Subject: [PATCH 24/25] fix bug in getting num_kv_heads --- deepspeed/module_inject/auto_tp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 4da40420c269..cc53ea7563a7 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -440,3 +440,4 @@ def get_model_num_kv_heads(self, config): num_kv_heads = getattr(config, name) if num_kv_heads != None: break + return num_kv_heads From d75149fb731a189bfa635aa2539d5e22eb522237 Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Tue, 10 Oct 2023 15:24:25 +0800 Subject: [PATCH 25/25] support uneven sharding of lm_head tensor parallel --- deepspeed/module_inject/layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 7a565560dec9..969826ad0289 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter from deepspeed.accelerator import get_accelerator +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list class LinearAllreduce(nn.Module): @@ -47,10 +48,9 @@ def __init__( self.world_size = world_size def forward(self, input): - assert input.shape[ - -1] % 
self.world_size == 0, 'Please ensure that self.world_size is divisible by input.shape[-1]' - input_shard = input.shape[-1] // self.world_size - output = torch.matmul(input[:, :, self.rank * input_shard:(self.rank + 1) * input_shard], + input_shard_size = get_shard_size(input.shape[-1], self.world_size) + input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size)[0:self.rank]) + output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size], self.weight.transpose(-1, -2)) if self.mp_group is not None: dist.inference_all_reduce(output, group=self.mp_group)
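The sharding arithmetic this series converges on (get_shard_size / get_shard_size_list in deepspeed/module_inject/tp_shard.py) can be illustrated with a small standalone sketch. The sketch below is an illustration only, not part of the patch: the rank and head count are passed in as explicit arguments instead of being read from deepspeed.comm.dist and the module-level num_kv_heads global set by set_num_kv_heads(), so it runs without DeepSpeed installed; the example numbers (32 heads, hidden size 4096, 3 ranks) are arbitrary.

# Standalone sketch of the uneven tensor-parallel sharding introduced above.
# Assumption: rank and num_kv_heads are explicit parameters here; the real
# implementation in deepspeed/module_inject/tp_shard.py obtains them from
# deepspeed.comm.dist and a module-level global.

def get_shard_size(total_size, mp_size, num_kv_heads, rank):
    """Return this rank's share of a dimension of size total_size."""
    if num_kv_heads is not None:
        # Uneven split: the first (num_kv_heads % mp_size) ranks own one extra head.
        my_heads = num_kv_heads // mp_size + (1 if rank < num_kv_heads % mp_size else 0)
        return total_size * my_heads // num_kv_heads
    # Without a head count, only even division is supported.
    assert total_size % mp_size == 0, (
        f"total_size ({total_size}) must be divisible by mp_size ({mp_size})")
    return total_size // mp_size


def get_shard_size_list(total_size, mp_size, num_kv_heads):
    """Per-rank shard sizes; prefix sums give each rank's offset into the full tensor."""
    return [get_shard_size(total_size, mp_size, num_kv_heads, r) for r in range(mp_size)]


if __name__ == "__main__":
    # Example: 32 attention heads, hidden size 4096, 3 tensor-parallel ranks.
    sizes = get_shard_size_list(4096, 3, num_kv_heads=32)
    print(sizes)            # [1408, 1408, 1280] -> 11, 11 and 10 heads of dim 128
    assert sum(sizes) == 4096
    # Head offset of rank 2, as used when slicing the alibi / attention-bias tensors:
    print(sum(sizes[:2]))   # 2816

Weight sharding in auto_tp.py uses the same list form with torch.split, e.g. child.weight.data.split(get_shard_size_list(dim, mp_size), dim=0)[gpu_index], and the bloom alibi / MPT attention-bias helpers compute each rank's head offset as the prefix sum of that list, as in the example above.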