From a4f405d003457b039aec94cca5feaed7c90080fa Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 22 Nov 2024 17:23:40 +0800 Subject: [PATCH] delete redundant files --- ...ransformer_layer_infer_template_awquant.py | 35 -- ...transformer_layer_infer_template_wquant.py | 34 -- .../layer_infer/transformer_layer_infer.py | 34 -- .../layer_infer/transformer_layer_infer.py | 44 -- .../layer_infer/transformer_layer_infer.py | 51 -- .../layer_infer/transformer_layer_infer.py | 91 --- .../layer_infer/transformer_layer_infer.py | 566 ------------------ .../layer_infer/transformer_layer_infer.py | 135 ----- .../layer_infer/transformer_layer_infer.py | 201 ------- .../layer_infer/transformer_layer_infer.py | 47 -- .../layer_infer/transformer_layer_infer.py | 43 -- .../layer_infer/transformer_layer_infer.py | 44 -- .../layer_infer/transformer_layer_infer.py | 42 -- .../layer_infer/transformer_layer_infer.py | 91 --- 14 files changed, 1458 deletions(-) delete mode 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py delete mode 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py delete mode 100755 lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/internlm/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/internlm_xcomposer/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/llama_quik/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/minicpm/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/qwen2/layer_infer/transformer_layer_infer.py delete mode 100644 lightllm/models/qwen2_wquant/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py delete mode 100755 lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py deleted file mode 100755 index 9ebdda66..00000000 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_awquant.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from .transformer_layer_infer_template import TransformerLayerInferTpl -from ...infer_struct import InferStateInfo -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv -from typing import Tuple - - -class TransformerLayerInferActivationWeightQuantTpl(TransformerLayerInferTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _awquant_matmul_for_qkv(self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _awquant_matmul_for_o(self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _awquant_matmul_for_ffn_up(self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _awquant_matmul_for_ffn_down(self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _awquant_att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _awquant_ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: - raise Exception("need to impl") - - def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - return infer_state.kv_buffer diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py deleted file mode 100755 index d5b78348..00000000 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template_wquant.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -from .transformer_layer_infer_template import TransformerLayerInferTpl -from ...infer_struct import InferStateInfo -from ...splitfuse_infer_struct import SplitFuseInferStateInfo -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv -from typing import Tuple - - -class TransformerLayerInferWeightQuantTpl(TransformerLayerInferTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _wquant_matmul_for_qkv(self, input, quant_weight_params, infer_state, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _wquant_matmul_for_o(self, input, quant_weight_params, infer_state, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _wquant_matmul_for_ffn_up(self, input, quant_weight_params, infer_state, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _wquant_matmul_for_ffn_down(self, input, quant_weight_params, infer_state, out=None, bias=None, has_act=False): - raise Exception("need to impl") - - def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Release kv buffer to save memory, since we allocate while kv projection. - """ - infer_state.kv_buffer = None - return None diff --git a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py b/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 8246975b..00000000 --- a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import torch.functional as F -import numpy as np -from functools import partial - -from lightllm.models.baichuan13b.layer_weights.transformer_layer_weight import BaiChuan13bTransformerLayerWeight -from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer - - -class Baichuan13bTransformerLayerInfer(LlamaTransformerLayerInfer): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self._bind_func() - return - - def _bind_func(self): - """ - baichuan13b only support normal mode. - """ - self._context_attention_kernel = partial(BloomTransformerLayerInfer._context_attention_kernel, self) - self._token_attention_kernel = partial(BloomTransformerLayerInfer._token_attention_kernel, self) - return - - def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor: - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.mm( - input.view(-1, self.embed_dim_), - layer_weight.kv_weight_, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) - return q, cache_kv diff --git a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 417848de..00000000 --- a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.functional as F -import numpy as np - -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - - -class InternlmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _get_qkv( - self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype) - torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q) - torch.addmm( - layer_weight.kv_bias_, - input, - layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight - ) -> torch.Tensor: - input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - o_tensor = self.alloc_tensor((input.size(0), layer_weight.o_weight_.size(1)), input.dtype) - torch.addmm(layer_weight.o_bias_, input, layer_weight.o_weight_, beta=1.0 / self.world_size_, out=o_tensor) - return o_tensor diff --git a/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 3f36320a..00000000 --- a/lightllm/models/internlm_wquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import torch.functional as F -import numpy as np -import triton - -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.internlm_wquant.layer_weights.transformer_layer_weight import ( - InternlmTransformerLayerWeightQuantized, -) -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama_wquant.layer_infer.transformer_layer_infer import LlamaTransformerLayerInferWquant - - -class InternlmTransformerLayerInferWquant(LlamaTransformerLayerInferWquant): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _get_qkv( - self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized - ): - q = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - infer_state=infer_state, - bias=layer_weight.q_bias_, - ) - cache_kv = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.kv_weight_, - infer_state=infer_state, - bias=layer_weight.kv_bias_, - ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeightQuantized - ) -> torch.Tensor: - o_tensor = self._wquant_matmul_for_o( - input, - quant_weight_params=layer_weight.o_weight_, - infer_state=infer_state, - bias=layer_weight.o_bias_ / self.world_size_, - ) - return o_tensor diff --git a/lightllm/models/internlm_xcomposer/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm_xcomposer/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 4da28843..00000000 --- a/lightllm/models/internlm_xcomposer/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.functional as F -import numpy as np - -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from lightllm.models.internlm_xcomposer.layer_weights.transformer_layer_weight import ( - InternlmComposerTransformerLayerWeight, -) -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.internlm_xcomposer.infer_struct import InternlmComposerInferStateInfo - - -class InternlmComposerTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _get_qkv( - self, - input, - cache_kv, - infer_state: InternlmComposerInferStateInfo, - layer_weight: InternlmComposerTransformerLayerWeight, - ) -> torch.Tensor: - im_mask = infer_state.im_mask - has_img = infer_state.has_img - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.addmm( - layer_weight.kv_bias_, - input.view(-1, self.embed_dim_), - layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) - if has_img: - input_part = input.view(-1, self.embed_dim_)[im_mask] - q[im_mask] += torch.mm(torch.mm(input_part, layer_weight.qkv_loraA_weight_), layer_weight.q_loraB_weight_) - cache_kv[im_mask] += torch.mm( - torch.mm(input_part, layer_weight.qkv_loraA_weight_), layer_weight.kv_loraB_weight_ - ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: InternlmComposerInferStateInfo, layer_weight: InternlmComposerTransformerLayerWeight - ) -> torch.Tensor: - im_mask = infer_state.im_mask - has_img = infer_state.has_img - o_tensor = torch.mm( - input.view(-1, self.tp_o_head_num_ * self.head_dim_), - layer_weight.o_weight_, - ) - if has_img: - input_part = input.view(-1, self.tp_o_head_num_ * self.head_dim_)[im_mask] - o_tensor[im_mask] += torch.mm( - torch.mm(input_part, layer_weight.wo_loraA_weight_), layer_weight.wo_loraB_weight_ - ) - return o_tensor - - def _ffn( - self, input, infer_state: InternlmComposerInferStateInfo, layer_weight: InternlmComposerTransformerLayerWeight - ) -> torch.Tensor: - im_mask = infer_state.im_mask - has_img = infer_state.has_img - up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) - if has_img: - gate_dim = up_gate_out.shape[1] // 2 - input_part = input.view(-1, self.embed_dim_)[im_mask] - up_gate_out[:, :gate_dim][im_mask] += torch.mm( - torch.mm(input_part, layer_weight.gate_loraA_weight_), layer_weight.gate_loraB_weight_ - ) - up_gate_out[:, gate_dim:][im_mask] += torch.mm( - torch.mm(input_part, layer_weight.up_loraA_weight_), layer_weight.up_loraB_weight_ - ) - ffn1_out = silu_and_mul_fwd(up_gate_out) - input = None - up_gate_out = None - ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj) - if has_img: - ffn2_out[im_mask] += torch.mm( - torch.mm(ffn1_out[im_mask], layer_weight.down_loraA_weight_), layer_weight.down_loraB_weight_ - ) - ffn1_out = None - return ffn2_out diff --git a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py deleted file mode 100755 index b4fed46d..00000000 --- a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,566 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -from lightllm.distributed import tensor_model_parallel_all_reduce -import torch.functional as F -import triton -from functools import partial - -from lightllm.models.llama_awquant.layer_weights.transformer_layer_weight import ( - LlamaTransformerLayerActivationWeightQuantPpl, - LlamaTransformerLayerActivationWeightQuantTriton, -) -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl -from lightllm.common.basemodel.cuda_kernel.ppl_awquant import ( - matmul_i8_i32_ppl, - skiprmsnorm_ppl, - channel_token_dequant_i32_fp16_ppl, -) -from lightllm.common.basemodel.cuda_kernel.ppl_awquant import ( - dynamic_channelwise_quant_fp16_i8_ppl, - gatesilu_i32_i8_ppl, - gatesilu_i32_fp16_ppl, -) -from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8 -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from lightllm.utils.infer_utils import mark_cost_time - - -class LlamaTransformerLayerInferActivationWeightQuantPpl(TransformerLayerInferActivationWeightQuantTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config["intermediate_size"] - self._bind_func() - return - - def _bind_func(self): - self._bind_norm() - self._bind_matmul() - self._bind_silu() - LlamaTransformerLayerInfer._bind_attention(self) - return - - def _bind_norm(self): - if "ppl_w8a8" in self.mode or "ppl_w8a8_mixdown" in self.mode: - self._awquant_att_norm = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_att_norm_ppl_int8, self - ) - self._awquant_ffn_norm = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_ffn_norm_ppl_int8, self - ) - else: - raise Exception(f"error mode {self.mode}") - return - - def _bind_matmul(self): - if "ppl_w8a8" in self.mode or "ppl_w8a8_mixdown" in self.mode: - self._awquant_matmul_for_qkv = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_matmul_ppl_int8_quant_dequant, self - ) - self._awquant_matmul_for_o = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_matmul_ppl_int8_quant_dequant, self - ) - self._awquant_matmul_for_ffn_up = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_matmul_ppl_int8_quant, self - ) - self._awquant_matmul_for_ffn_down = partial( - LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_matmul_ppl_int8_quant_dequant, self - ) - if self.tp_rank_ == 0 and self.layer_num_ == 0: - print("model use ppl_w8a8 kernel") - else: - raise Exception(f"error mode {self.mode}") - return - - def _bind_silu(self): - if "ppl_w8a8" in self.mode: - func = partial(LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_silu_ppl_int8, self) - self._awquant_silu = func - else: - raise Exception(f"error mode {self.mode}") - return - - def _get_qkv( - self, - input, - cache_kv, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantPpl, - ) -> torch.Tensor: - q = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ) - - cache_k_ = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.k_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ) - - cache_k_ = cache_k_.view(-1, self.tp_k_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_k_, - infer_state.position_cos, - infer_state.position_sin, - ) - cache_v_ = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.v_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ).view(-1, self.tp_v_head_num_, self.head_dim_) - - infer_state.kv_buffer[:, 0 : self.tp_k_head_num_, :] = cache_k_ - infer_state.kv_buffer[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] = cache_v_ - - return q, infer_state.kv_buffer - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerActivationWeightQuantPpl - ) -> torch.Tensor: - o_tensor = torch.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_) - return o_tensor - - def _ffn( - self, - input, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantPpl, - ) -> torch.Tensor: - gate_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.gate_proj, - is_prefill=infer_state.is_prefill, - ) - up_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.up_proj, - is_prefill=infer_state.is_prefill, - ) - input = None - _, gate_proj_scale = layer_weight.gate_proj - _, up_proj_scale = layer_weight.up_proj - ffn1_out, ffn1_out_scale = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale) - gate_out, up_out = None, None - ffn2_out = self._awquant_matmul_for_ffn_down( - ffn1_out, layer_weight.down_proj, is_prefill=infer_state.is_prefill, token_scale=ffn1_out_scale - ) - ffn1_out = None - - return ffn2_out - - def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight) - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight) - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _splitfuse_attention(self, input_embding, infer_state: LlamaSplitFuseInferStateInfo, layer_weight): - # 因为 LlamaSplitFuseInferStateInfo 对象并没有 is_prefill 成员,但是后续的矩阵乘法算子入口 - # 函数输入中需要使用到, 所以在开始的地方默认添加一个 is_prefill 成员,并设置为True. - infer_state.is_prefill = True - - input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight) - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._splitfuse_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _splitfuse_ffn(self, input_embdings, infer_state: LlamaSplitFuseInferStateInfo, layer_weight): - input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _awquant_matmul_ppl_int8_quant_dequant( - self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False - ): - if input.dtype == torch.float16: - input, token_scale = dynamic_channelwise_quant_fp16_i8_ppl(input.transpose(0, 1)) - assert has_act is False - qweight, qscale = quant_weight_params - out = matmul_i8_i32_ppl(input, qweight) - out = channel_token_dequant_i32_fp16_ppl(out, token_scale, qscale) - if bias is not None: - out.add_(bias) - return out - - def _awquant_matmul_ppl_int8_quant( - self, input, quant_weight_params, is_prefill, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, qscale = quant_weight_params - out = matmul_i8_i32_ppl(input, qweight) - if bias is not None: - out.add_(bias) - return out - - def _awquant_att_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight): - if getattr(infer_state, "skip", None) is None: - infer_state.skip = torch.zeros_like(input) - return skiprmsnorm_ppl(input, layer_weight.att_norm_weight_, skip=infer_state.skip) - - def _awquant_ffn_norm_ppl_int8(self, input, infer_state: LlamaInferStateInfo, layer_weight): - return skiprmsnorm_ppl(input, layer_weight.ffn_norm_weight_, skip=infer_state.skip) - - def _awquant_silu_ppl_int8(self, x, y, x_scale, y_scale, token_scale): - return gatesilu_i32_i8_ppl(x, y, x_scale, y_scale, token_scale) - - -class LlamaTransformerLayerInferActivationWeightQuantPplMixdown(LlamaTransformerLayerInferActivationWeightQuantPpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super(LlamaTransformerLayerInferActivationWeightQuantPpl, self).__init__( - layer_num, tp_rank, world_size, network_config, mode - ) - self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - self.inter_dim_ = network_config["intermediate_size"] - self._init_mixdown() - self._bind_func() - self._bind_ffn() - return - - def _init_mixdown(self): - self.mixdown = self.network_config_.get("mixdown", list(range(self.network_config_["num_hidden_layers"]))) - assert isinstance(self.mixdown, list), "mixdown must be all or a list." - - def _bind_silu(self): - if "ppl_w8a8_mixdown" in self.mode: - if self.layer_num_ in self.mixdown: - func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._awquant_silu_ppl_fp16, self) - else: - func = partial(LlamaTransformerLayerInferActivationWeightQuantPpl._awquant_silu_ppl_int8, self) - self._awquant_silu = func - else: - raise Exception(f"error mode {self.mode}") - return - - def _bind_ffn(self): - if "ppl_w8a8_mixdown" in self.mode: - if self.layer_num_ in self.mixdown: - func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._ffn_down_fp16, self) - else: - func = partial(LlamaTransformerLayerInferActivationWeightQuantPplMixdown._ffn_down_int8, self) - self._ffn = func - else: - raise Exception(f"error mode {self.mode}") - return - - def _ffn_down_int8( - self, - input, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantPpl, - ) -> torch.Tensor: - gate_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.gate_proj, - is_prefill=infer_state.is_prefill, - ) - up_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.up_proj, - is_prefill=infer_state.is_prefill, - ) - input = None - _, gate_proj_scale = layer_weight.gate_proj - _, up_proj_scale = layer_weight.up_proj - ffn1_out, ffn1_out_scale = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale) - gate_out, up_out = None, None - ffn2_out = self._awquant_matmul_for_ffn_down( - ffn1_out, layer_weight.down_proj, is_prefill=infer_state.is_prefill, token_scale=ffn1_out_scale - ) - ffn1_out = None - - return ffn2_out - - def _ffn_down_fp16( - self, - input, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantPpl, - ) -> torch.Tensor: - gate_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.gate_proj, - is_prefill=infer_state.is_prefill, - ) - up_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.up_proj, - is_prefill=infer_state.is_prefill, - ) - input = None - _, gate_proj_scale = layer_weight.gate_proj - _, up_proj_scale = layer_weight.up_proj - ffn1_out = self._awquant_silu(gate_out, up_out, gate_proj_scale, up_proj_scale, token_scale) - gate_out, up_out = None, None - ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj) - ffn1_out = None - - return ffn2_out - - def _awquant_silu_ppl_fp16(self, x, y, x_scale, y_scale, token_scale): - return gatesilu_i32_fp16_ppl(x, y, x_scale, y_scale, token_scale) - - -class LlamaTransformerLayerInferActivationWeightQuantTriton(TransformerLayerInferActivationWeightQuantTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config["intermediate_size"] - self._bind_func() - return - - def _bind_func(self): - self._bind_norm() - self._bind_matmul() - self._bind_silu() - LlamaTransformerLayerInfer._bind_attention(self) - return - - def _bind_norm(self): - if "triton_w8a8" in self.mode: - self._awquant_att_norm = partial(LlamaTransformerLayerInfer._att_norm, self) - self._awquant_ffn_norm = partial(LlamaTransformerLayerInfer._ffn_norm, self) - else: - raise Exception(f"error mode {self.mode}") - return - - def _bind_matmul(self): - if "triton_w8a8" in self.mode: - func = partial(LlamaTransformerLayerInferActivationWeightQuantTriton._awquant_matmul_triton_w8a8, self) - self._awquant_matmul_for_qkv = func - self._awquant_matmul_for_o = func - self._awquant_matmul_for_ffn_up = func - self._awquant_matmul_for_ffn_down = func - else: - raise Exception(f"error mode {self.mode}") - return - - def _bind_silu(self): - if "triton_w8a8" in self.mode: - self._awquant_silu = silu_and_mul_fwd - else: - raise Exception(f"error mode {self.mode}") - return - - def _get_qkv( - self, - input, - cache_kv, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantTriton, - ) -> torch.Tensor: - q = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ) - - cache_k_ = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.k_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ) - - cache_k_ = cache_k_.view(-1, self.tp_k_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_k_, - infer_state.position_cos, - infer_state.position_sin, - ) - cache_v_ = self._awquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.v_weight_, - is_prefill=infer_state.is_prefill, - token_scale=token_scale, - ).view(-1, self.tp_v_head_num_, self.head_dim_) - - infer_state.kv_buffer[:, 0 : self.tp_k_head_num_, :] = cache_k_ - infer_state.kv_buffer[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :] = cache_v_ - - return q, infer_state.kv_buffer - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerActivationWeightQuantTriton - ) -> torch.Tensor: - o_tensor = self._awquant_matmul_for_o( - input.view(-1, self.tp_o_head_num_ * self.head_dim_), - layer_weight.o_weight_, - is_prefill=infer_state.is_prefill, - ) - return o_tensor - - def _ffn( - self, - input, - token_scale, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerActivationWeightQuantTriton, - ) -> torch.Tensor: - up_gate_out = self._awquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), layer_weight.gate_up_proj, is_prefill=infer_state.is_prefill - ) - ffn1_out = self._awquant_silu(up_gate_out) - input = None - up_gate_out = None - ffn2_out = self._awquant_matmul_for_ffn_down( - ffn1_out, layer_weight.down_proj, is_prefill=infer_state.is_prefill - ) - ffn1_out = None - return ffn2_out - - def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._awquant_att_norm(input_embding, infer_state, layer_weight) - token_scale = None - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) - token_scale = None - ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._awquant_att_norm(input_embding, infer_state, layer_weight) - token_scale = None - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight) - token_scale = None - ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _awquant_matmul_triton_w8a8( - self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, qscale = quant_weight_params - out = matmul_quantize_int8(input, qweight, qscale) - if bias is not None: - out.add_(bias) - return out diff --git a/lightllm/models/llama_quik/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_quik/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 4f855573..00000000 --- a/lightllm/models/llama_quik/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,135 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -from lightllm.distributed import tensor_model_parallel_all_reduce -import torch.functional as F -import triton -from functools import partial - -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama_quik.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuik -from lightllm.utils.infer_utils import mark_cost_time - - -class LlamaTransformerLayerInferQuik(LlamaTransformerLayerInfer): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config["intermediate_size"] - self._bind_func() - return - - def _get_qkv( - self, - input: torch.Tensor, - cache_kv: torch.Tensor, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerWeightQuik, - ) -> torch.Tensor: - q = layer_weight.q_proj(input.view(-1, self.embed_dim_)) - if layer_weight.cat_kv_: - cache_kv = layer_weight.kv_proj(input.view(-1, self.embed_dim_)).view( - -1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_ - ) - else: - cache_k = layer_weight.k_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_k_head_num_, self.head_dim_) - cache_v = layer_weight.v_proj(input.view(-1, self.embed_dim_)).view(-1, self.tp_v_head_num_, self.head_dim_) - cache_kv = torch.cat([cache_k, cache_v], dim=1) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuik - ) -> torch.Tensor: - return layer_weight.o_proj(input.view(-1, self.embed_dim_)) - - def _ffn( - self, - input, - infer_state: LlamaInferStateInfo, - layer_weight: LlamaTransformerLayerWeightQuik, - ) -> torch.Tensor: - if not layer_weight.cat_gate_up_: - gate_out = layer_weight.gate_proj(input.view(-1, self.embed_dim_)) - up_out = layer_weight.up_proj(input.view(-1, self.embed_dim_)) - torch.nn.functional.silu(gate_out, inplace=True) - gate_out.mul_(up_out) - input = None - ffn2_out = layer_weight.down_proj(gate_out) - gate_out, up_out = None, None - else: - gate_up_out = layer_weight.gate_up_proj(input.view(-1, self.embed_dim_)).view( - -1, self.inter_dim_ * 2 // self.world_size_ - ) - # gate_out, up_out = torch.split(gate_up_out, split_size_or_sections=1, dim=1) - ffn1_out = silu_and_mul_fwd(gate_up_out) - input = None - gate_up_out = None - ffn2_out = layer_weight.down_proj(ffn1_out) - ffn1_out = None - - return ffn2_out - - def _context_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._att_norm(input_embding, infer_state, layer_weight) - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _context_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return - - def _token_attention(self, input_embding, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._att_norm(input_embding, infer_state, layer_weight) - cache_kv = self._pre_cache_kv(infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, cache_kv, infer_state, layer_weight) - input1 = None - self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._token_attention_kernel(q, infer_state, layer_weight) - q = None - o = self._get_o(o, infer_state, layer_weight) - if self.world_size_ > 1: - o = tensor_model_parallel_all_reduce(o) - input_embding.add_(o.view(-1, self.embed_dim_)) - return - - def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight): - input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) - ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None - if self.world_size_ > 1: - ffn_out = tensor_model_parallel_all_reduce(ffn_out) - input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) - return diff --git a/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 0a910e4a..00000000 --- a/lightllm/models/llama_wquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,201 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -import torch.functional as F -import triton -from functools import partial - -from lightllm.models.llama_wquant.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeightQuantized -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.common.basemodel import TransformerLayerInferWeightQuantTpl -from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import matmul_dequantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import ( - matmul_dequantize_int4_s1, - matmul_dequantize_int4_s2, - matmul_dequantize_int4_gptq, -) -from lightllm.common.basemodel.cuda_kernel.lmdeploy_wquant import matmul_dequantize_int4_lmdeploy -from lightllm.common.basemodel.cuda_kernel.ppl_wquant import matmul_dequantize_int4_ppl -from lightllm.common.basemodel.cuda_kernel.fast_llm_wquant import matmul_dequantize_int6_fast_llm -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class LlamaTransformerLayerInferWquant(TransformerLayerInferWeightQuantTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.eps_ = network_config["rms_norm_eps"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["hidden_size"] - - self.inter_dim_ = network_config["intermediate_size"] - self._bind_func() - return - - def _bind_func(self): - self._bind_matmul() - LlamaTransformerLayerInfer._bind_norm(self) - LlamaTransformerLayerInfer._bind_attention(self) - return - - def _bind_matmul(self): - if "triton_w8a16" in self.mode: - func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_triton_int8weight_only_quant, self) - self._wquant_matmul_for_qkv = func - self._wquant_matmul_for_o = func - self._wquant_matmul_for_ffn_up = func - self._wquant_matmul_for_ffn_down = func - if self.tp_rank_ == 0 and self.layer_num_ == 0: - logger.info("model use triton_w8a16 kernel") - elif "triton_w4a16" in self.mode: - func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_triton_int4weight_only_quant, self) - self._wquant_matmul_for_qkv = func - self._wquant_matmul_for_o = func - self._wquant_matmul_for_ffn_up = func - self._wquant_matmul_for_ffn_down = func - if self.tp_rank_ == 0 and self.layer_num_ == 0: - logger.info("model use triton_w4a16 kernel") - elif "lmdeploy_w4a16" in self.mode: - func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_lmdeploy_int4weight_only_quant, self) - self._wquant_matmul_for_qkv = func - self._wquant_matmul_for_o = func - self._wquant_matmul_for_ffn_up = func - self._wquant_matmul_for_ffn_down = func - if self.tp_rank_ == 0 and self.layer_num_ == 0: - logger.info("model use lmdeploy_w4a16 kernel") - elif "ppl_w4a16" in self.mode: - func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_ppl_int4weight_only_quant, self) - self._wquant_matmul_for_qkv = func - self._wquant_matmul_for_o = func - self._wquant_matmul_for_ffn_up = func - self._wquant_matmul_for_ffn_down = func - if self.tp_rank_ == 0 and self.layer_num_ == 0: - logger.info("model use ppl_w4a16 kernel") - elif "flash_llm_w6a16" in self.mode: - func = partial(LlamaTransformerLayerInferWquant._wquant_matmul_fast_llm_int6weight_only_quant, self) - self._wquant_matmul_for_qkv = func - self._wquant_matmul_for_o = func - self._wquant_matmul_for_ffn_up = func - self._wquant_matmul_for_ffn_down = func - if self.tp_rank_ == 0 and self.layer_num_ == 0: - logger.info("model use flash_llm_w6a16 kernel") - else: - raise Exception(f"error mode {self.mode}") - return - - def _get_qkv( - self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized - ): - q = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.q_weight_, infer_state=infer_state - ) - cache_kv = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.kv_weight_, infer_state=infer_state - ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized - ) -> torch.Tensor: - o_tensor = self._wquant_matmul_for_o(input, quant_weight_params=layer_weight.o_weight_, infer_state=infer_state) - return o_tensor - - def _ffn( - self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeightQuantized - ) -> torch.Tensor: - gate_up_output = self._wquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), quant_weight_params=layer_weight.gate_up_proj, infer_state=infer_state - ) - input = None - tp_inter_dim = self.inter_dim_ // self.world_size_ - gate_up_output = gate_up_output.view(-1, 2, tp_inter_dim) - torch.nn.functional.silu(gate_up_output[:, 0], inplace=True) - ffn1_out = gate_up_output[:, 0] * gate_up_output[:, 1] - gate_up_output = None - ffn2_out = self._wquant_matmul_for_ffn_down( - ffn1_out, quant_weight_params=layer_weight.down_proj, infer_state=infer_state - ) - ffn1_out = None - return ffn2_out - - def _wquant_matmul_triton_int8weight_only_quant( - self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, scale = quant_weight_params - out = matmul_dequantize_int8(input, qweight, scale, out=out) - if bias is None: - return out - else: - out.add_(bias) - return out - - def _wquant_matmul_triton_int4weight_only_quant( - self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False - ): - assert has_act is False - if infer_state.is_splitfuse is False and infer_state.is_prefill: - qweight, scale, zeros, int4_q_group_size = quant_weight_params - out = matmul_dequantize_int4_s1(input, qweight, scale, zeros, int4_q_group_size, out=out) - else: - qweight, scale, zeros, int4_q_group_size = quant_weight_params - out = matmul_dequantize_int4_gptq(input, qweight, scale, zeros, int4_q_group_size, output=out) - if bias is None: - return out - else: - out.add_(bias) - return out - - def _wquant_matmul_lmdeploy_int4weight_only_quant( - self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, scale_zeros, int4_q_group_size = quant_weight_params - out = matmul_dequantize_int4_lmdeploy(input, qweight, scale_zeros, int4_q_group_size) - if bias is None: - return out - else: - out.add_(bias) - return out - - def _wquant_matmul_ppl_int4weight_only_quant( - self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, qscale = quant_weight_params - out = matmul_dequantize_int4_ppl(input, qweight, qscale) - if bias is None: - return out - else: - out.add_(bias) - return out - - def _wquant_matmul_fast_llm_int6weight_only_quant( - self, input, quant_weight_params, infer_state: LlamaInferStateInfo, out=None, bias=None, has_act=False - ): - assert has_act is False - qweight, qscale = quant_weight_params - out = matmul_dequantize_int6_fast_llm(input, qweight, qscale) - if bias is None: - return out - else: - out.add_(bias) - return out diff --git a/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py b/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 9294ac2f..00000000 --- a/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch -import torch.functional as F -import numpy as np - -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.internlm.layer_weights.transformer_layer_weight import InternlmTransformerLayerWeight -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - - -class InternlmTransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _get_qkv( - self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight - ) -> torch.Tensor: - q = torch.addmm( - layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 - ) - torch.addmm( - layer_weight.kv_bias_, - input.view(-1, self.embed_dim_), - layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv - - def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight - ) -> torch.Tensor: - o_tensor = torch.addmm( - layer_weight.o_bias_, - input.view(-1, self.tp_o_head_num_ * self.head_dim_), - layer_weight.o_weight_, - beta=1.0 / self.world_size_, - ) - return o_tensor diff --git a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py deleted file mode 100644 index 06ccde42..00000000 --- a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -import torch.functional as F -import numpy as np -from typing import Tuple -import triton -from lightllm.models.llama.infer_struct import LlamaInferStateInfo - -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer, rotary_emb_fwd - -from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight - - -class Qwen2TransformerLayerInfer(LlamaTransformerLayerInfer): - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - return - - def _get_qkv( - self, - input: torch.Tensor, - cache_kv: torch.Tensor, - infer_state: LlamaInferStateInfo, - layer_weight: Qwen2TransformerLayerWeight, - ) -> torch.Tensor: - input = input.view(-1, self.embed_dim_) - dtype = input.dtype - q = self.alloc_tensor((input.shape[0], layer_weight.q_weight_.shape[1]), dtype=dtype) - torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q) - torch.addmm( - layer_weight.kv_bias_, - input, - layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, - out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv diff --git a/lightllm/models/qwen2_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2_wquant/layer_infer/transformer_layer_infer.py deleted file mode 100644 index e9a803d6..00000000 --- a/lightllm/models/qwen2_wquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,44 +0,0 @@ -import torch -import torch.functional as F -import numpy as np -from functools import partial - -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama_wquant.layer_infer.transformer_layer_infer import LlamaTransformerLayerInferWquant -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.qwen2_wquant.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeightQuantized -from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd -from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd - - -class Qwen2TransformerLayerInferWQuant(LlamaTransformerLayerInferWquant): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.inter_dim_ = network_config["intermediate_size"] - return - - def _get_qkv( - self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Qwen2TransformerLayerWeightQuantized - ): - q = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - infer_state=infer_state, - bias=layer_weight.q_bias_, - ) - cache_kv = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.kv_weight_, - infer_state=infer_state, - bias=layer_weight.kv_bias_, - ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - return q, cache_kv diff --git a/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 612708dc..00000000 --- a/lightllm/models/qwen_wquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import torch.functional as F -import numpy as np - -from lightllm.models.llama_wquant.layer_infer.transformer_layer_infer import LlamaTransformerLayerInferWquant -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.qwen_wquant.layer_weights.transformer_layer_weight import QwenTransformerLayerWeightQuantized -from lightllm.models.qwen.infer_struct import QwenInferStateInfo - - -class QwenTransformerLayerInferWQuant(LlamaTransformerLayerInferWquant): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.inter_dim_ = network_config["intermediate_size"] // 2 # qwen 的 inter_dim 要 // 2 - return - - def _get_qkv( - self, input, cache_kv, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeightQuantized - ): - q = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.q_weight_, - infer_state=infer_state, - bias=layer_weight.q_bias_, - ) - cache_kv = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - quant_weight_params=layer_weight.kv_weight_, - infer_state=infer_state, - bias=layer_weight.kv_bias_, - ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) - rotary_emb_fwd( - q.view(-1, self.tp_q_head_num_, self.head_dim_), - cache_kv[:, 0 : self.tp_k_head_num_, :], - infer_state.position_cos, - infer_state.position_sin, - ) - if infer_state.logn_values is not None: - q.mul_(infer_state.logn_values.view(-1, 1)) - return q, cache_kv diff --git a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py deleted file mode 100755 index 852799b4..00000000 --- a/lightllm/models/starcoder_wquant/layer_infer/transformer_layer_infer.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Tuple - -import torch -import torch.functional as F -import numpy as np -from functools import partial - -from lightllm.common.basemodel.triton_kernel.quantize_gemm_int8 import matmul_quantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int8 import matmul_dequantize_int8 -from lightllm.common.basemodel.triton_kernel.dequantize_gemm_int4 import ( - matmul_dequantize_int4_s1, - matmul_dequantize_int4_s2, - matmul_dequantize_int4_gptq, -) -from lightllm.models.starcoder_wquant.layer_weights.transformer_layer_weight import ( - StarcoderTransformerLayerWeightQuantized, -) -from lightllm.utils.infer_utils import mark_cost_time -from lightllm.models.starcoder.infer_struct import StarcoderInferStateInfo -from lightllm.common.basemodel import TransformerLayerInferWeightQuantTpl -from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer -from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer -from lightllm.models.llama_wquant.layer_infer.transformer_layer_infer import LlamaTransformerLayerInferWquant - - -class StarcoderTransformerLayerInferWQuant(TransformerLayerInferWeightQuantTpl): - """ """ - - def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): - super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.eps_ = network_config["layer_norm_epsilon"] - self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ - self.tp_k_head_num_ = 1 - self.tp_v_head_num_ = 1 - self.tp_o_head_num_ = self.tp_q_head_num_ - self.head_dim_ = network_config["n_embed"] // network_config["num_attention_heads"] - self.embed_dim_ = network_config["n_embed"] - self._bind_func() - return - - def _bind_func(self): - self._att_norm = partial(BloomTransformerLayerInfer._att_norm, self) - self._ffn_norm = partial(BloomTransformerLayerInfer._ffn_norm, self) - - LlamaTransformerLayerInferWquant._bind_matmul(self) - LlamaTransformerLayerInfer._bind_attention(self) - return - - def _get_qkv( - self, - input, - cache_kv, - infer_state: StarcoderInferStateInfo, - layer_weight: StarcoderTransformerLayerWeightQuantized, - ) -> torch.Tensor: - q = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), layer_weight.q_weight_, infer_state=infer_state, bias=layer_weight.q_bias_ - ) - cache_kv = self._wquant_matmul_for_qkv( - input.view(-1, self.embed_dim_), - layer_weight.kv_weight_, - infer_state=infer_state, - bias=layer_weight.kv_bias_, - ).view(-1, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_) - return q, cache_kv - - def _get_o( - self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized - ) -> torch.Tensor: - o_output = self._wquant_matmul_for_o( - input, layer_weight.o_weight_, infer_state=infer_state, bias=layer_weight.o_bias_ - ) - return o_output - - def _ffn( - self, input, infer_state: StarcoderInferStateInfo, layer_weight: StarcoderTransformerLayerWeightQuantized - ) -> torch.Tensor: - ffn1_out = self._wquant_matmul_for_ffn_up( - input.view(-1, self.embed_dim_), - layer_weight.ffn_1_weight_, - infer_state=infer_state, - bias=layer_weight.ffn_1_bias_, - ) - input = None - gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") - ffn1_out = None - ffn2_out = self._wquant_matmul_for_ffn_down( - gelu_out, layer_weight.ffn_2_weight_, infer_state=infer_state, bias=layer_weight.ffn_2_bias_ - ) - gelu_out = None - return ffn2_out