[Refactor][model] Refactor models. (#608)
sufubao authored Nov 20, 2024
1 parent 07c0a98 commit 8990f47
Showing 20 changed files with 375 additions and 579 deletions.
@@ -16,7 +16,7 @@ def _get_qkv(
self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:

q = layer_weight.q_proj.mm(input)
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_)).view(-1, self.tp_q_head_num_, self.head_dim_)
cache_kv = layer_weight.kv_proj.mm(
input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
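
As an illustrative aside (not part of the commit), the reshaping in the new q-projection line can be shown in isolation. The sketch assumes embed_dim_ = tp_q_head_num_ * head_dim_ on the rank and uses a plain matmul in place of the ROWMMWeight.mm wrapper:

import torch

# Assumed per-rank sizes, for illustration only.
tokens, embed_dim, tp_q_head_num, head_dim = 5, 1024, 8, 128
hidden = torch.randn(tokens, embed_dim)
q_weight = torch.randn(embed_dim, tp_q_head_num * head_dim)

# New code path: flatten to (tokens, embed_dim), project, then split out heads.
q = (hidden.view(-1, embed_dim) @ q_weight).view(-1, tp_q_head_num, head_dim)
print(q.shape)  # torch.Size([5, 8, 128])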
@@ -9,16 +9,24 @@ class BaiChuan7bTransformerLayerWeight(LlamaTransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
return

def _init_config(self):
self.network_config_["num_key_value_heads"] = self.network_config_["num_attention_heads"]
self.n_embed = self.network_config_["hidden_size"]
self.n_head = self.network_config_["num_attention_heads"]
self.n_inter = self.network_config_["intermediate_size"]
self.n_kv_head = self.network_config_["num_key_value_heads"]
self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)

def load_hf_weights(self, weights):
self.network_config_["num_key_value_heads"] = self.network_config_["num_attention_heads"]
if f"model.layers.{self.layer_num_}.self_attn.W_pack.weight" in weights:
qkv_weights = weights[f"model.layers.{self.layer_num_}.self_attn.W_pack.weight"]
qkv_weight_name = f"{self.layer_name}.self_attn.W_pack.weight"
if qkv_weight_name in weights:
qkv_weights = weights[qkv_weight_name]
split_size = qkv_weights.shape[0] // 3
q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0)
weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] = q_weights
weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] = k_weights
weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] = v_weights
del weights[f"model.layers.{self.layer_num_}.self_attn.W_pack.weight"]
weights[self._q_weight_name] = q_weights
weights[self._k_weight_name] = k_weights
weights[self._v_weight_name] = v_weights
del weights[qkv_weight_name]
super().load_hf_weights(weights)
return
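
For context (not from the commit itself): Baichuan-7B checkpoints pack the query, key and value projections into a single W_pack tensor with the three blocks stacked along dim 0, so the loader above splits it into thirds and re-registers the pieces under standard q/k/v names before handing off to the shared Llama loading path. A minimal sketch with an assumed hidden size:

import torch

hidden_size = 4096  # assumed for illustration
w_pack = torch.randn(3 * hidden_size, hidden_size)

split_size = w_pack.shape[0] // 3
q_w, k_w, v_w = torch.split(w_pack, split_size, dim=0)
assert q_w.shape == k_w.shape == v_w.shape == (hidden_size, hidden_size)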
133 changes: 53 additions & 80 deletions lightllm/models/bloom/layer_weights/transformer_layer_weight.py
@@ -1,8 +1,8 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight


def generate_alibi(n_head, dtype=torch.float16):
@@ -47,56 +47,58 @@ def get_slopes_power_of_2(n):
return head_alibi


class BloomTransformerLayerWeight(TransformerLayerWeight):
def __init__(
self, layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg=None, layer_prefix="h"
):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)

self.layer_name = f"{layer_prefix}.{self.layer_num_}"

self._init_name()
self._init_qkv()
self._init_o()
self._init_ffn()
self._init_norm()
self.set_quantization()
class BloomTransformerLayerWeight(LlamaTransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg=None):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg, layer_prefix="h")
self.init_static_params()
return
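
This hunk shows the core pattern of the refactor: BloomTransformerLayerWeight no longer builds its ROWMMWeight/COLMMWeight/NormWeight objects by hand; it inherits them from LlamaTransformerLayerWeight and only supplies model-specific configuration, checkpoint key names and weight preprocessing. Schematically (class and key names below are illustrative, not the real lightllm API):

from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight

class SomeModelTransformerLayerWeight(LlamaTransformerLayerWeight):
    def _init_config(self):
        # map this model's config keys onto the shared fields
        self.n_embed = self.network_config_["hidden_size"]
        self.n_head = self.network_config_["num_attention_heads"]
        self.n_kv_head = self.network_config_.get("num_key_value_heads", self.n_head)
        self.n_inter = self.network_config_["intermediate_size"]
        self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)

    def _init_weight_names(self):
        # point the shared loader at this model's checkpoint key names
        self._q_weight_name = f"{self.layer_name}.attn.q_proj.weight"
        # ... k/v/o, MLP and norm names follow the same pattern

    def load_hf_weights(self, weights):
        # optionally split fused tensors into standard q/k/v entries first
        super().load_hf_weights(weights)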

def _init_name(self):
self._q_name = f"{self.layer_name}.self_attention.q_proj"
self._k_name = f"{self.layer_name}.self_attention.k_proj"
self._v_name = f"{self.layer_name}.self_attention.v_proj"
self.o_name = f"{self.layer_name}.self_attention.dense"
self.up_proj_name = f"{self.layer_name}.mlp.dense_h_to_4h"
self.down_proj_name = f"{self.layer_name}.mlp.dense_4h_to_h"
self.att_norm_name = f"{self.layer_name}.input_layernorm"
self.ffn_norm_name = f"{self.layer_name}.post_attention_layernorm"

def _split_qkv_weight(self, weights):
n_embed = self.network_config_["n_embed"]
head_num = self.network_config_["num_attention_heads"]

if f"{self.layer_name}.self_attention.query_key_value.weight" in weights:
att_qkv_dense_weight = weights[f"{self.layer_name}.self_attention.query_key_value.weight"].reshape(
head_num, 3, -1, n_embed
)
weights[f"{self._q_name}.weight"] = att_qkv_dense_weight[:, 0, :, :].reshape(-1, n_embed)
weights[f"{self._k_name}.weight"] = att_qkv_dense_weight[:, 1, :, :].reshape(-1, n_embed)
weights[f"{self._v_name}.weight"] = att_qkv_dense_weight[:, 2, :, :].reshape(-1, n_embed)
del weights[f"{self.layer_name}.self_attention.query_key_value.weight"]

if f"{self.layer_name}.self_attention.query_key_value.bias" in weights:
att_qkv_dense_bias = weights[f"h.{self.layer_num_}.self_attention.query_key_value.bias"].reshape(
head_num, 3, -1
)
weights[f"{self._q_name}.bias"] = att_qkv_dense_bias[:, 0, :].reshape(-1)
weights[f"{self._k_name}.bias"] = att_qkv_dense_bias[:, 1, :].reshape(-1)
weights[f"{self._v_name}.bias"] = att_qkv_dense_bias[:, 2, :].reshape(-1)
del weights[f"h.{self.layer_num_}.self_attention.query_key_value.bias"]
def _init_config(self):
self.n_embed = self.network_config_["n_embed"]
self.n_head = self.network_config_["num_attention_heads"]
self.n_inter = self.network_config_["n_embed"] * 4
self.n_kv_head = self.network_config_["num_attention_heads"]
self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)

def _init_weight_names(self):
self._q_weight_name = f"{self.layer_name}.self_attention.q_proj.weight"
self._q_bias_name = f"{self.layer_name}.self_attention.q_proj.bias"
self._k_weight_name = f"{self.layer_name}.self_attention.k_proj.weight"
self._k_bias_name = f"{self.layer_name}.self_attention.k_proj.bias"
self._v_weight_name = f"{self.layer_name}.self_attention.v_proj.weight"
self._v_bias_name = f"{self.layer_name}.self_attention.v_proj.bias"
self._o_weight_name = f"{self.layer_name}.self_attention.o_proj.weight"
self._o_bias_name = f"{self.layer_name}.self_attention.o_proj.bias"

self._up_weight_name = f"{self.layer_name}.mlp.dense_h_to_4h.weight"
self._up_bias_name = f"{self.layer_name}.mlp.dense_h_to_4h.bias"
self._down_weight_name = f"{self.layer_name}.mlp.dense_4h_to_h.weight"
self._down_bias_name = f"{self.layer_name}.mlp.dense_4h_to_h.bias"

self.att_norm_weight_name = f"{self.layer_name}.input_layernorm.weight"
self.att_norm_bias_name = f"{self.layer_name}.input_layernorm.bias"
self.ffn_norm_weight_name = f"{self.layer_name}.post_attention_layernorm.weight"
self.ffn_norm_bias_name = f"{self.layer_name}.post_attention_layernorm.bias"

def _preprocess_weight(self, weights):
qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
if qkv_weight_name in weights:
att_qkv_dense_weight = weights[qkv_weight_name].reshape(self.n_head, 3, -1, self.n_embed)
weights[self._q_weight_name] = att_qkv_dense_weight[:, 0, :, :].reshape(-1, self.n_embed)
weights[self._k_weight_name] = att_qkv_dense_weight[:, 1, :, :].reshape(-1, self.n_embed)
weights[self._v_weight_name] = att_qkv_dense_weight[:, 2, :, :].reshape(-1, self.n_embed)
del weights[qkv_weight_name]

qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
if qkv_bias_name in weights:
att_qkv_dense_bias = weights[qkv_bias_name].reshape(self.n_head, 3, -1)
weights[self._q_bias_name] = att_qkv_dense_bias[:, 0, :].reshape(-1)
weights[self._k_bias_name] = att_qkv_dense_bias[:, 1, :].reshape(-1)
weights[self._v_bias_name] = att_qkv_dense_bias[:, 2, :].reshape(-1)
del weights[qkv_bias_name]

def load_hf_weights(self, weights):
self._split_qkv_weight(weights)
self._preprocess_weight(weights)
super().load_hf_weights(weights)
return
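
Note the difference from Baichuan's layout above: Bloom's fused query_key_value tensor interleaves q, k and v per attention head rather than stacking three full blocks, which is why _preprocess_weight reshapes to (head_num, 3, head_dim, n_embed) before slicing. A small sketch with assumed sizes (not part of the commit):

import torch

n_head, head_dim = 8, 64          # assumed for illustration
n_embed = n_head * head_dim       # 512

# Bloom-style layout: for each head, its q, k and v rows sit next to each other.
fused = torch.randn(n_head * 3 * head_dim, n_embed)
per_head = fused.reshape(n_head, 3, -1, n_embed)

q_w = per_head[:, 0, :, :].reshape(-1, n_embed)   # (n_embed, n_embed)
k_w = per_head[:, 1, :, :].reshape(-1, n_embed)
v_w = per_head[:, 2, :, :].reshape(-1, n_embed)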

@@ -109,40 +109,11 @@ def init_static_params(self):
self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num : (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda()
return
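
The body of generate_alibi is collapsed in this view; the sketch below uses the standard ALiBi slope formula for a power-of-two head count (an assumption about what the helper computes) together with the same per-rank slice that init_static_params applies when building tp_alibi:

import math
import torch

def alibi_slopes_power_of_2(n_head):
    # Standard ALiBi slopes: a geometric sequence starting at 2 ** (-8 / n_head).
    start = 2.0 ** (-(2.0 ** -(math.log2(n_head) - 3)))
    return torch.tensor([start * (start ** i) for i in range(n_head)])

n_head, world_size, tp_rank = 16, 2, 0    # assumed sizes
slopes = alibi_slopes_power_of_2(n_head)  # shape (16,)
tp_head_num = n_head // world_size
tp_alibi = slopes[tp_rank * tp_head_num : (tp_rank + 1) * tp_head_num].contiguous()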

def _init_qkv(self):
n_embed = self.network_config_["n_embed"]
split_n_embed = n_embed // self.world_size_
self.q_proj = ROWMMWeight(
f"{self._q_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self._q_name}.bias"
)
self.k_proj = ROWMMWeight(
f"{self._k_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self._k_name}.bias", wait_fuse=True
)
self.v_proj = ROWMMWeight(
f"{self._v_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self._v_name}.bias", wait_fuse=True
)

def _init_o(self):
n_embed = self.network_config_["n_embed"]
split_n_embed = n_embed // self.world_size_
self.o_proj = COLMMWeight(
f"{self.o_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self.o_name}.bias"
)

def _init_ffn(self):
n_embed = self.network_config_["n_embed"] * 4
split_n_embed = n_embed // self.world_size_
split_inter_size = self.n_inter // self.world_size_
self.up_proj = ROWMMWeight(
f"{self.up_proj_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self.up_proj_name}.bias"
self._up_weight_name, self.data_type_, split_inter_size, bias_name=self._up_bias_name, wait_fuse=True
)
self.down_proj = COLMMWeight(
f"{self.down_proj_name}.weight", self.data_type_, split_n_embed, bias_name=f"{self.down_proj_name}.bias"
)

def _init_norm(self):
self.att_norm_weight_ = NormWeight(
f"{self.att_norm_name}.weight", self.data_type_, bias_name=f"{self.att_norm_name}.bias"
)
self.ffn_norm_weight_ = NormWeight(
f"{self.ffn_norm_name}.weight", self.data_type_, bias_name=f"{self.ffn_norm_name}.bias"
self._down_weight_name, self.data_type_, split_inter_size, bias_name=self._down_bias_name
)
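
The rewritten _init_ffn splits the MLP along the intermediate dimension (split_inter_size = n_inter // world_size), with the up projection handled as a row-parallel weight and the down projection as a column-parallel one. Roughly what each rank ends up holding, sketched with assumed shapes (the real slicing lives inside ROWMMWeight/COLMMWeight):

import torch

n_embed, n_inter, world_size, rank = 1024, 4096, 4, 1   # assumed
split_inter = n_inter // world_size

up_w = torch.randn(n_inter, n_embed)    # dense_h_to_4h, stored as (out, in)
down_w = torch.randn(n_embed, n_inter)  # dense_4h_to_h

# Row-parallel: each rank keeps a contiguous block of output rows.
up_shard = up_w[rank * split_inter : (rank + 1) * split_inter, :]
# Column-parallel: each rank keeps the matching block of input columns.
down_shard = down_w[:, rank * split_inter : (rank + 1) * split_inter]

assert up_shard.shape == (split_inter, n_embed)
assert down_shard.shape == (n_embed, split_inter)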
107 changes: 34 additions & 73 deletions lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py
@@ -1,98 +1,59 @@
import torch
import math
from lightllm.common.basemodel import TransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight, ROWMMWeight, COLMMWeight, NormWeight
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight


class ChatGLM2TransformerLayerWeight(TransformerLayerWeight):
class ChatGLM2TransformerLayerWeight(LlamaTransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
assert network_config["num_attention_heads"] % self.world_size_ == 0

self.layer_name = f"transformer.encoder.layers.{self.layer_num_}"

self._init_qkv()
self._init_o()
self._init_ffn()
self._init_norm()
self.set_quantization()
super().__init__(
layer_num,
tp_rank,
world_size,
data_type,
network_config,
mode,
quant_cfg,
layer_prefix="transformer.encoder.layers",
)
return

def _preprocess_weight(self, weights):
n_embed = self.network_config_["hidden_size"]
head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"]
multi_query_group_num = self.network_config_["multi_query_group_num"]
n_kv_embed = self.head_dim * self.n_kv_head

qkv_weight_name = f"{self.layer_name}.self_attention.query_key_value.weight"
if qkv_weight_name in weights:
qkv_weight_ = weights[qkv_weight_name]
weights[f"{self._q_name}.weight"] = qkv_weight_[:, :n_embed]
weights[f"{self._k_name}.weight"] = qkv_weight_[:, n_embed : n_embed + head_dim * multi_query_group_num]
weights[f"{self._v_name}.weight"] = qkv_weight_[
:, n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim
]
weights[self._q_weight_name] = qkv_weight_[:, : self.n_embed]
weights[self._k_weight_name] = qkv_weight_[:, self.n_embed : self.n_embed + n_kv_embed]
weights[self._v_weight_name] = qkv_weight_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
del weights[qkv_weight_name]

qkv_bias_name = f"{self.layer_name}.self_attention.query_key_value.bias"
if qkv_bias_name in weights:
qkv_bias_ = weights[qkv_bias_name]
weights[f"{self._q_name}.bias"] = qkv_bias_[:n_embed]
weights[f"{self._k_name}.bias"] = qkv_bias_[:, n_embed : n_embed + head_dim * multi_query_group_num]
weights[f"{self._v_name}.bias"] = qkv_bias_[
:, n_embed + multi_query_group_num * head_dim : n_embed + 2 * multi_query_group_num * head_dim
]
weights[self._q_bias_name] = qkv_bias_[: self.n_embed]
weights[self._k_bias_name] = qkv_bias_[:, self.n_embed : self.n_embed + n_kv_embed]
weights[self._v_bias_name] = qkv_bias_[:, self.n_embed + n_kv_embed : self.n_embed + 2 * n_kv_embed]
del weights[qkv_bias_name]

def _init_config(self):
self.n_embed = self.network_config_["hidden_size"]
self.n_head = self.network_config_["num_attention_heads"]
self.n_inter = self.network_config_["ffn_hidden_size"]
self.n_kv_head = self.network_config_["multi_query_group_num"]
self.head_dim = self.network_config_.get("head_dim", self.n_embed // self.n_head)

def load_hf_weights(self, weights):
self._preprocess_weight(weights)
super().load_hf_weights(weights)
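
Because ChatGLM2 uses grouped-query attention (multi_query_group_num key/value heads), the fused query_key_value tensor is not three equal blocks: q spans n_embed elements while k and v each span n_kv_embed = head_dim * n_kv_head. The offset arithmetic, sketched with assumed sizes (which axis gets sliced depends on how the weight is laid out when it reaches this point):

import torch

n_embed, n_head, n_kv_head = 4096, 32, 2    # assumed ChatGLM2-like sizes
head_dim = n_embed // n_head                # 128
n_kv_embed = head_dim * n_kv_head           # 256

# Fused GQA projection: [ q | k | v ] with unequal widths.
fused = torch.randn(n_embed + 2 * n_kv_embed, n_embed)
q_w = fused[:n_embed]                                   # (4096, 4096)
k_w = fused[n_embed : n_embed + n_kv_embed]             # (256, 4096)
v_w = fused[n_embed + n_kv_embed : n_embed + 2 * n_kv_embed]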

def _init_qkv(self):
n_embed = self.network_config_["hidden_size"]
head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"]
multi_query_group_num = self.network_config_["multi_query_group_num"]
kv_split_n_embed = multi_query_group_num // self.world_size_ * head_dim
q_split_n_embed = n_embed // self.world_size_

self._q_name = f"{self.layer_name}.self_attention.q_proj"
self._k_name = f"{self.layer_name}.self_attention.k_proj"
self._v_name = f"{self.layer_name}.self_attention.v_proj"

self.q_proj = ROWMMWeight(
f"{self._q_name}.weight", self.data_type_, q_split_n_embed, bias_name=f"{self._q_name}.bias"
)
self.k_proj = ROWMMWeight(
f"{self._k_name}.weight",
self.data_type_,
kv_split_n_embed,
bias_name=f"{self._k_name}.bias",
wait_fuse=True,
)
self.v_proj = ROWMMWeight(
f"{self._v_name}.weight",
self.data_type_,
kv_split_n_embed,
bias_name=f"{self._v_name}.bias",
wait_fuse=True,
)

def _init_o(self):
o_split_n_embed = self.network_config_["hidden_size"] // self.world_size_
self._o_name = f"{self.layer_name}.self_attention.dense.weight"

self.o_proj = COLMMWeight(
f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", self.data_type_, o_split_n_embed
)
return

def _init_ffn(self):
ffn_hidden_size = self.network_config_["ffn_hidden_size"]
split_inter_size = ffn_hidden_size // self.world_size_

self.gate_up_proj = ROWMMWeight(
f"{self.layer_name}.mlp.dense_h_to_4h.weight", self.data_type_, split_inter_size
split_inter_size = self.n_inter // self.world_size_
self.up_proj = ROWMMWeight(
self._up_weight_name, self.data_type_, split_inter_size, bias_name=self._up_bias_name, wait_fuse=True
)
self.down_proj = COLMMWeight(
self._down_weight_name, self.data_type_, split_inter_size, bias_name=self._down_bias_name
)
self.down_proj = COLMMWeight(f"{self.layer_name}.mlp.dense_4h_to_h.weight", self.data_type_, split_inter_size)

def _init_norm(self):
self.att_norm_weight_ = NormWeight(f"{self.layer_name}.input_layernorm.weight", self.data_type_)
self.ffn_norm_weight_ = NormWeight(f"{self.layer_name}.post_attention_layernorm.weight", self.data_type_)
@@ -14,11 +14,11 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
return

def init_norm(self, weights):
def _init_norm(self, weights):
q_split_head = self.network_config_["num_attention_heads"] // self.world_size_
k_split_head = self.network_config_["num_key_value_heads"] // self.world_size_

self.att_norm_weight_ = NormWeight(f"model.layers.{self.layer_num_}.input_layernorm.weight", self.data_type_)
self.att_norm_weight_ = NormWeight(self.att_norm_weight_name, self.data_type_)

if self.use_qk_norm:
self.q_norm_weight_ = TpNormWeight(
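
The final hunk (truncated by the page here) switches _init_norm over to the shared weight names and, when use_qk_norm is set, registers tensor-parallel per-head norms for the query and key projections via TpNormWeight. A rough sketch of applying such a norm to the query tensor; the exact norm type and epsilon are assumptions, not taken from the diff:

import torch

def per_head_rmsnorm(x, weight, eps=1e-6):
    # x: (tokens, n_heads, head_dim); weight: this rank's (n_heads, head_dim) slice.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight.unsqueeze(0)

tokens, n_heads, head_dim = 4, 8, 128       # assumed sizes
q = torch.randn(tokens, n_heads, head_dim)
q_norm_weight = torch.ones(n_heads, head_dim)
q = per_head_rmsnorm(q, q_norm_weight)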