
Commit f69e2ec
Fix a lot (#610)
sufubao authored Nov 20, 2024
1 parent 18a0f08 commit f69e2ec
Showing 12 changed files with 18 additions and 27 deletions.

@@ -2,7 +2,7 @@

# from lightllm.common.layers.mm import MM
from .base_layer_weight import BaseLayerWeight
-from .meta_weights import MMWeight, FusedMoeWeight
+from .meta_weights import MMWeight, ROWMMWeight, FusedMoeWeight
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)
@@ -20,6 +20,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
self.quant_cfg = quant_cfg
self.init_static_params()
self.fuse_pairs = {"k_proj&v_proj": "kv_proj"}
+        self.kv_proj: ROWMMWeight = None
return

def load_hf_weights(self, weights):
@@ -30,7 +31,7 @@ def fuse_weights(self):
for pair_name, fuse_name in self.fuse_pairs.items():
attr1_name, attr2_name = pair_name.split("&")
with self.lock:
-                if hasattr(self, fuse_name):
+                if getattr(self, fuse_name, None) is not None:
continue
attr1 = getattr(self, attr1_name)
attr2 = getattr(self, attr2_name)
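Note on the check above: once __init__ pre-declares the fused attribute (self.kv_proj: ROWMMWeight = None), hasattr(self, fuse_name) is always true and can no longer distinguish "not fused yet" from "already fused". A minimal standalone sketch of the difference, using a toy class rather than the lightllm weight objects:

    class _Toy:
        def __init__(self):
            self.kv_proj = None  # placeholder declared up front, as in this commit

    w = _Toy()
    print(hasattr(w, "kv_proj"))                    # True even before any fusing happens
    print(getattr(w, "kv_proj", None) is not None)  # False until kv_proj is actually built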

@@ -26,7 +26,7 @@ def _bind_func(self):
return

def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor:
-        q = layer_weight.q_proj.mm(input)
+        q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
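The q projection now receives input.view(-1, self.embed_dim_), so the matmul always sees a 2-D [tokens, hidden] operand even if the activation arrives with an extra leading dimension; the same reshape is applied at the other mm call sites touched below. A rough sketch of the shape handling with plain torch tensors (the real code calls .mm on the layer's weight objects):

    import torch

    embed_dim = 16
    weight = torch.randn(embed_dim, embed_dim)
    hidden = torch.randn(1, 4, embed_dim)     # activation with an extra leading dim

    q = hidden.view(-1, embed_dim) @ weight   # view(-1, embed_dim) flattens to [tokens, hidden]
    print(q.shape)                            # torch.Size([4, 16])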

@@ -13,8 +13,3 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo

def init_static_params(self):
return BloomTransformerLayerWeight.init_static_params(self)
-
-    def verify_load(self):
-        super().verify_load()
-        assert self.tp_alibi is not None, "load error"
-        return

@@ -15,7 +15,6 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
def _get_qkv(
self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
-
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)).view(-1, self.tp_q_head_num_, self.head_dim_)
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)

@@ -12,11 +12,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo

def _init_config(self):
self.network_config_["num_key_value_heads"] = self.network_config_["num_attention_heads"]
-        self.n_embed = self.network_config_["hidden_size"]
-        self.n_head = self.network_config_["num_attention_heads"]
-        self.n_inter = self.network_config_["intermediate_size"]
-        self.n_kv_head = self.network_config_["num_key_value_heads"]
-        self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)
+        super()._init_config()

def load_hf_weights(self, weights):
qkv_weight_name = f"{self.layer_name}.self_attn.W_pack.weight"
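The per-model attribute bookkeeping is dropped here in favour of super()._init_config(); only the num_key_value_heads override is kept, since Baichuan-style configs expose num_attention_heads but no separate KV head count. A toy sketch of the values the removed lines computed, which the base class is now expected to derive from the same config keys (the numbers are illustrative only):

    # Toy config; the keys mirror the removed lines above, the values are made up.
    cfg = {"hidden_size": 512, "num_attention_heads": 8, "intermediate_size": 1376}
    cfg["num_key_value_heads"] = cfg["num_attention_heads"]     # the override that stays

    n_embed = cfg["hidden_size"]
    n_head = cfg["num_attention_heads"]
    n_inter = cfg["intermediate_size"]
    n_kv_head = cfg["num_key_value_heads"]
    head_dim = cfg.get("head_dim", n_embed // n_head)           # 512 // 8 = 64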

8 changes: 3 additions & 5 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -46,7 +46,7 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTrans
def _get_qkv(
self, input, cache_kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
) -> torch.Tensor:
-        q = layer_weight.q_proj.mm(input)
+        q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
@@ -94,13 +94,11 @@ def _token_attention_kernel(
return o_tensor

def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
-        o_tensor = layer_weight.o_proj.mm(input)
+        o_tensor = layer_weight.o_proj.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_))
return o_tensor

def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        input = input.view(-1, self.embed_dim_)
-        ffn1_out = layer_weight.up_proj.mm(input)
+        ffn1_out = layer_weight.up_proj.mm(input.view(-1, self.embed_dim_))
input = None
gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
ffn1_out = None

@@ -50,7 +50,6 @@ def get_slopes_power_of_2(n):
class BloomTransformerLayerWeight(LlamaTransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg=None):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg, layer_prefix="h")
-        self.init_static_params()
return

def _init_config(self):

@@ -24,17 +24,17 @@ def _preprocess_weight(self, weights):
qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
if qkv_weight_name in weights:
qkv_weight_ = weights[qkv_weight_name]
-            weights[self._q_weight_name] = qkv_weight_[:, : self.n_embed]
-            weights[self._k_weight_name] = qkv_weight_[:, self.n_embed : self.n_embed + n_kv_embed]
-            weights[self._v_weight_name] = qkv_weight_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
+            weights[self._q_weight_name] = qkv_weight_[: self.n_embed, :]
+            weights[self._k_weight_name] = qkv_weight_[self.n_embed : self.n_embed + n_kv_embed, :]
+            weights[self._v_weight_name] = qkv_weight_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed, :]
del weights[qkv_weight_name]

qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
if qkv_bias_name in weights:
qkv_bias_ = weights[qkv_bias_name]
weights[self._q_bias_name] = qkv_bias_[: self.n_embed]
-            weights[self._k_bias_name] = qkv_bias_[:, self.n_embed : self.n_embed + n_kv_embed]
-            weights[self._v_bias_name] = qkv_bias_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
+            weights[self._k_bias_name] = qkv_bias_[self.n_embed : self.n_embed + n_kv_embed]
+            weights[self._v_bias_name] = qkv_bias_[self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
del weights[qkv_bias_name]

def _init_config(self):
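The fused query_key_value tensor is now split along dim 0 rather than dim 1, and the fused bias is indexed as a 1-D vector. A self-contained sketch of the new slicing, assuming the fused weight is laid out [out_features, in_features] as in torch.nn.Linear (toy sizes; the real ones come from the model config):

    import torch

    n_embed, n_kv_embed = 8, 8
    qkv_w = torch.randn(n_embed + 2 * n_kv_embed, n_embed)   # [out_features, in_features]
    qkv_b = torch.randn(n_embed + 2 * n_kv_embed)            # 1-D fused bias

    q_w = qkv_w[: n_embed, :]                                # first n_embed output rows
    k_w = qkv_w[n_embed : n_embed + n_kv_embed, :]
    v_w = qkv_w[n_embed + n_kv_embed : n_embed + 2 * n_kv_embed, :]

    q_b = qkv_b[: n_embed]
    k_b = qkv_b[n_embed : n_embed + n_kv_embed]
    v_b = qkv_b[n_embed + n_kv_embed : n_embed + 2 * n_kv_embed]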

@@ -25,7 +25,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
def _ffn(
self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight
) -> torch.Tensor:
-        up_gate_out = layer_weight.gate_up_proj.mm(input)
+        up_gate_out = layer_weight.gate_up_proj.mm(input.view(-1, self.embed_dim_))
ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype)
gelu_and_mul_fwd(up_gate_out, ffn1_out)
input = None

@@ -12,7 +12,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo

def _init_qkv(self):
q_split_n_embed = self.head_dim * self.n_head // self.world_size_
-        kv_split_n_embed = self.head_dim * self.n_kv_head // self.world_size_
+        kv_split_n_embed = self.head_dim * self.n_kv_head
self.q_proj = ROWMMWeight(self._q_weight_name, self.data_type_, q_split_n_embed, bias_name=self._q_bias_name)
self.k_proj = ROWMMWeight(
self._k_weight_name,

@@ -29,6 +29,8 @@ def load_hf_weights(self, weights):
def _init_weight_names(self):
super()._init_weight_names()
self._o_weight_name = f"{self.layer_name}.attention.wo.weight"
+        self._o_bias_name = f"{self.layer_name}.attention.wo.bias"
+
self._gate_weight_name = f"{self.layer_name}.feed_forward.w1.weight"
self._up_weight_name = f"{self.layer_name}.feed_forward.w3.weight"
self._down_weight_name = f"{self.layer_name}.feed_forward.w2.weight"

@@ -86,6 +86,7 @@ def _init_ffn(self):
self._down_weight_name, self.data_type_, split_inter_size, bias_name=self._down_bias_name
)
self.fuse_pairs.update({"gate_proj&up_proj": "gate_up_proj"})
+        self.gate_up_proj: ROWMMWeight = None

def _init_norm(self):
self.att_norm_weight_ = NormWeight(
