Commit

Add weight padding for moe (opendatahub-io#119)
* add weight padding for moe

* enable padding by default

* fix linter

* fix linter

* fix linter

* using envs.py

* fix linter
charlifu authored Aug 2, 2024
1 parent 42b1b9a commit 5fac73f
Showing 3 changed files with 20 additions and 3 deletions.
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -40,6 +40,7 @@
 VERBOSE: bool = False
 VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1
 VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
+VLLM_MOE_PADDING: bool = True

 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.
@@ -229,6 +230,10 @@
     # Poll for new requests every this many steps
     "VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS":
     lambda: int(os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1")),
+
+    # Pad the weight for moe kernel or not
+    "VLLM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))),
 }

 # end-env-vars-definition
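
Aside (not part of the commit): a minimal sketch of how the new toggle is interpreted. The lambda parses the variable as an integer, so the default "1" enables padding and setting VLLM_MOE_PADDING=0 disables it.

    import os

    # Same parsing as the lambda above: unset or "1" -> True, "0" -> False.
    moe_padding_enabled = bool(int(os.getenv("VLLM_MOE_PADDING", "1")))
    print(moe_padding_enabled)
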
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -10,9 +10,11 @@

 import vllm._moe_C as moe_kernels
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.logger import init_logger

 logger = init_logger(__name__)
+padding_size = 128 if envs.VLLM_MOE_PADDING else 0


 @triton.jit
@@ -262,7 +264,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -365,7 +367,8 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[
+        1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -381,7 +384,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
+        configs = get_moe_configs(E, w2.shape[2] - padding_size,
                                   "float8" if use_fp8 else None)

         if configs:
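
Aside (not part of the commit): a small sketch of the size bookkeeping these changes rely on. The last dimension of each expert weight is assumed to carry 128 extra zero columns, so the kernel arguments and the config lookup subtract padding_size to recover the logical size; the tensor name and shape below are made up for illustration.

    import torch
    import torch.nn.functional as F

    padding_size = 128
    w = torch.randn(4, 64, 256)              # hypothetical (num_experts, rows, cols) weight
    w_padded = F.pad(w, (0, padding_size))   # append 128 zeros along the last dimension
    logical_cols = w_padded.shape[2] - padding_size  # what the kernel treats as the real size
    assert logical_cols == w.shape[2]
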
9 changes: 9 additions & 0 deletions vllm/model_executor/models/mixtral.py
@@ -24,10 +24,12 @@
 from typing import Iterable, List, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig

 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -181,6 +183,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
         if not self.use_fp8:
+            if envs.VLLM_MOE_PADDING:
+                self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
+                                                     (0, 128), "constant", 0),
+                                               requires_grad=False)
+                self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
+                                                    (0, 128), "constant", 0),
+                                              requires_grad=False)
             return

         # If checkpoint is fp16, quantize here.
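
Aside (not part of the commit): F.pad with a (0, 128) spec pads only the last dimension, adding nothing on the left and 128 zeros on the right, and the result is re-wrapped as a frozen nn.Parameter. A toy reproduction of that pattern, with a dummy weight standing in for the real expert tensors:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    w = nn.Parameter(torch.randn(2, 8, 16), requires_grad=False)  # dummy expert weight
    w = nn.Parameter(F.pad(w.data, (0, 128), "constant", 0),
                     requires_grad=False)
    print(w.shape)  # torch.Size([2, 8, 144])
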
