enable phi3_mini autotp (#5501)
This PR aims to enable AutoTP for Phi-3 mini.

Phi-3 mini uses a chunked MLP: the gate and up projections are fused into a single gate_up_proj linear layer. We adjust this linear layer's weight order so that the fused weight can be sharded correctly across tensor-parallel ranks.
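A minimal standalone sketch of why the fused gate_up_proj weight has to be re-ordered before tensor-parallel sharding; the sizes are toy values and the even splits use chunk instead of the PR's get_shard_size_list, purely for illustration:

```python
import torch

intermediate, hidden, world_size = 8, 4, 2

# Phi-3's chunked MLP keeps gate and up projections stacked in one weight:
# rows [0, intermediate) are gate, rows [intermediate, 2*intermediate) are up.
gate = torch.arange(intermediate * hidden, dtype=torch.float32).reshape(intermediate, hidden)
up = gate + 1000  # offset so the two halves are easy to tell apart
gate_up = torch.cat([gate, up], dim=0)  # shape [2*intermediate, hidden]

# Naive row sharding would hand rank 0 only gate rows and rank 1 only up rows:
naive = gate_up.chunk(world_size, dim=0)
assert torch.equal(naive[0], gate)  # rank 0 gets no up-projection rows at all

# What shard_chunk_mlp does instead: split each half separately, then give
# every rank its slice of gate followed by its slice of up.
gate_half, up_half = gate_up.chunk(2, dim=0)
shards = [
    torch.cat([gate_half.chunk(world_size, dim=0)[r],
               up_half.chunk(world_size, dim=0)[r]], dim=0)
    for r in range(world_size)
]
assert shards[0].shape == (intermediate, hidden)  # half of gate + half of up per rank
```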

Please kindly review~ Thanks!

---------

Co-authored-by: Lev Kurilenko <[email protected]>
Yejing-Lai and lekurile authored May 8, 2024
1 parent 0b224ed commit 3dd7ccf
Showing 4 changed files with 61 additions and 4 deletions.
9 changes: 7 additions & 2 deletions deepspeed/module_inject/auto_tp.py
@@ -13,7 +13,7 @@
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list


@@ -133,7 +133,8 @@ def is_load_module(module):
    load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm]
    load_layer_names = [
        "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear",
        "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm"
        "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding",
        "Phi3RMSNorm"
    ]
    return module.__class__ in load_layers or module._get_name() in load_layer_names

@@ -328,6 +329,10 @@ def _replace(self, child, name, conv_linear_layer):
    # For mixtral-7x8b, need to skip MoE gate linear replace.
    if name == "block_sparse_moe.gate":
        return child
    # for phi3.
    if 'gate_up_proj' in name:
        weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
        return LinearLayer(weight=weight, bias=bias)
    if name in self.all_reduce_linears:
        # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
        # else [weight_shape[0], weight_shape[1] // mp_size]
40 changes: 39 additions & 1 deletion deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
# DeepSpeed Team
import torch
from deepspeed.utils.logging import warning_once
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads


def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
@@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index):
"FalconDecoderLayer": 'bloomtype',
"GPTBigCodeBlock": 'bigcodetype',
"DecoderLayer": 'glmtype',
"Phi3DecoderLayer": "phi3type"
}

def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size):
        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], kv), dim=0)

    def _phi3_type_transpose(input, mp_size):
        num_kv_heads = get_num_kv_heads()
        num_heads = get_num_attention_heads()
        hidden_size = input.shape[1]
        head_dim = hidden_size // num_heads
        q_pos = input.shape[0] - 2 * num_kv_heads * head_dim
        q = input[:q_pos]
        k = input[q_pos:q_pos + num_kv_heads * head_dim]
        v = input[q_pos + num_kv_heads * head_dim:]
        split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0)
        split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0)
        split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0)
        return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0)

    def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):

        # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
            return _qwen_type_transpose(src, mp_size, module)
        elif fused_qkv_type == 'bigcodetype':
            return _bigcode_type_transpose(src, mp_size)
        elif fused_qkv_type == 'phi3type':
            return _phi3_type_transpose(src, mp_size)

        raise ValueError("unknown fused_qkv_type")

@@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None):
warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type,"
f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors")
return _bloom_type_transpose(src, mp_size)


# For phi3 with chunk mlp, adjust the weight order.
def shard_chunk_mlp(
    weight,
    bias,
    rank,
    world_size,
):
    weight_gate, weight_states = weight.chunk(2, dim=0)
    total_size = weight_gate.shape[0]
    split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
    shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0)
    if bias is not None:
        bias_gate, bias_states = bias.chunk(2, dim=0)
        split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0)
        return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0)

    return shard_weight, None
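For reference, a small self-contained shape check of how the _phi3_type_transpose helper added above locates q, k and v inside the fused qkv_proj weight before sharding each slice; the head counts and dimensions are toy values, not Phi-3 mini's actual config:

```python
import torch

num_heads, num_kv_heads, head_dim = 4, 2, 8              # toy values only
hidden_size = num_heads * head_dim                       # 32
fused_rows = hidden_size + 2 * num_kv_heads * head_dim   # q rows + k rows + v rows
qkv_w = torch.randn(fused_rows, hidden_size)

# Same arithmetic as _phi3_type_transpose: everything before q_pos is q,
# the next num_kv_heads*head_dim rows are k, and the remaining rows are v.
q_pos = qkv_w.shape[0] - 2 * num_kv_heads * head_dim
q = qkv_w[:q_pos]
k = qkv_w[q_pos:q_pos + num_kv_heads * head_dim]
v = qkv_w[q_pos + num_kv_heads * head_dim:]

assert q.shape[0] == num_heads * head_dim
assert k.shape[0] == v.shape[0] == num_kv_heads * head_dim
```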
6 changes: 5 additions & 1 deletion deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading

from deepspeed import comm as dist
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads

from .load_checkpoint import load_model_with_checkpoint
import time
@@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
    # 4.2 set n_embd
    set_n_embd(n_embd)

    # 4.3 set attention_heads
    if hasattr(model_config, 'num_attention_heads'):
        set_num_attention_heads(getattr(model_config, 'num_attention_heads'))

    # 5. Set linear policies
    _autotp.update_linear_policies()

10 changes: 10 additions & 0 deletions deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
    num_kv_heads = num


def set_num_attention_heads(num):
    global num_attention_heads
    num_attention_heads = num


def set_n_embd(num):
    global n_embd
    n_embd = num
@@ -22,6 +27,11 @@ def get_num_kv_heads():
    return num_kv_heads


def get_num_attention_heads():
    global num_attention_heads
    return num_attention_heads


def get_shard_size(total_size, mp_size, name=None, rank=None):
    global num_kv_heads
    last_linear = ["lm_head", "embed_out"]
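A tiny usage sketch of the new setter/getter pair, assuming a DeepSpeed build that includes this change; the value 32 is illustrative. replace_wo_policy stores the head count from the model config once, and _phi3_type_transpose later reads it back:

```python
from deepspeed.module_inject.tp_shard import set_num_attention_heads, get_num_attention_heads

set_num_attention_heads(32)  # e.g. model_config.num_attention_heads
assert get_num_attention_heads() == 32
```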
