From ab1286842df4b9cf2bd4971fa4fbc4c27b8e7f21 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Fri, 27 Dec 2024 01:23:29 -0500
Subject: [PATCH] [BugFix] Fix quantization for all other methods (#11547)

---
 vllm/model_executor/layers/fused_moe/layer.py | 19 ++++++++++++----
 .../layers/quantization/awq_marlin.py         | 10 ++++++---
 .../compressed_tensors_moe.py                 | 22 +++++++++++++------
 .../layers/quantization/experts_int8.py       | 10 ++++++---
 .../model_executor/layers/quantization/fp8.py |  3 +--
 .../layers/quantization/gptq_marlin.py        | 10 ++++++---
 6 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 01ffac4550f28..b108cbd52c218 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -41,9 +41,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         raise NotImplementedError

     @abstractmethod
-    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
-              router_logits: torch.Tensor, top_k: int, renormalize: bool,
-              use_grouped_topk: bool) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         raise NotImplementedError


@@ -79,7 +90,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 4d1a837d11585..c28fd0c6737e0 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -440,11 +440,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -454,7 +456,9 @@ def apply(
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index dad04017d3212..5fd6b017f444b 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -203,13 +203,14 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -220,7 +221,9 @@
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,
@@ -476,12 +479,15 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -490,7 +496,9 @@
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index 97297970d9317..209f12c6dfec9 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -99,11 +99,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -115,7 +117,9 @@
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return fused_experts(x,
                              layer.w13_weight,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 4362468c1db69..7f779ac8d3b3e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -601,14 +601,13 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
-        use_grouped_topk: bool,
+        use_grouped_topk: bool = False,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from vllm.model_executor.layers.fused_moe import fused_experts

         topk_weights, topk_ids = FusedMoE.select_experts(
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index a3e58bf1b2a4c..a006d729cc627 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -532,11 +532,13 @@ def apply(
         x: torch.Tensor,
         router_logits: torch.Tensor,
         top_k: int,
-        renormalize: bool = True,
+        renormalize: bool,
         use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        e_score_correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # The input must currently be float16
         orig_dtype = x.dtype
@@ -550,7 +552,9 @@
             renormalize=renormalize,
             topk_group=topk_group,
             num_expert_group=num_expert_group,
-            custom_routing_function=None)
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias)

         return torch.ops.vllm.fused_marlin_moe(
             x,
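
For reference, the calling convention that every quantized MoE method converges on here is the one declared abstract in fused_moe/layer.py above. The following is a minimal, self-contained sketch of that keyword set; ToyMoEMethod is a hypothetical stand-in (not vLLM code), shown only so the unified signature, including scoring_func and e_score_correction_bias, can be read in one place.

# Illustrative only: ToyMoEMethod is a hypothetical stand-in, not vLLM code.
# It mirrors the keyword set that this patch makes uniform across the
# quantized fused-MoE apply() implementations.
from typing import Callable, Optional

import torch


class ToyMoEMethod:
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Every routing argument is accepted and would be forwarded to expert
        # selection. Before this patch, gptq_marlin dropped
        # custom_routing_function (it passed None) and no quantized method
        # accepted scoring_func or e_score_correction_bias, so FusedMoE could
        # not pass them through uniformly.
        return x


# One keyword set now works against any method, e.g.:
out = ToyMoEMethod().apply(
    layer=torch.nn.Identity(),
    x=torch.randn(2, 8),
    router_logits=torch.randn(2, 4),
    top_k=2,
    renormalize=True,
    scoring_func="sigmoid",
)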