Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Deepseekv3 broke quantization for all other methods #11547

Merged
merged 2 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,20 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
raise NotImplementedError

@abstractmethod
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
router_logits: torch.Tensor, top_k: int, renormalize: bool,
use_grouped_topk: bool) -> torch.Tensor:
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
raise NotImplementedError


Expand Down Expand Up @@ -79,7 +90,7 @@ def apply(
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
Expand Down
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/quantization/awq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,13 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
renormalize: bool,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
Expand All @@ -454,7 +456,9 @@ def apply(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return torch.ops.vllm.fused_marlin_moe(
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,14 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
renormalize: bool,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids = FusedMoE.select_experts(
Expand All @@ -220,7 +221,9 @@ def apply(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(x,
layer.w13_weight,
Expand Down Expand Up @@ -476,12 +479,15 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
renormalize: bool,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
Expand All @@ -490,7 +496,9 @@ def apply(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return torch.ops.vllm.fused_marlin_moe(
x,
Expand Down
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/quantization/experts_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
renormalize: bool,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

Expand All @@ -115,7 +117,9 @@ def apply(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return fused_experts(x,
layer.w13_weight,
Expand Down
3 changes: 1 addition & 2 deletions vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,14 +601,13 @@ def apply(
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe import fused_experts

topk_weights, topk_ids = FusedMoE.select_experts(
Expand Down
10 changes: 7 additions & 3 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,13 @@ def apply(
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
renormalize: bool,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# The input must currently be float16
orig_dtype = x.dtype
Expand All @@ -550,7 +552,9 @@ def apply(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=None)
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)

return torch.ops.vllm.fused_marlin_moe(
x,
Expand Down
Loading