Arctic hao 0504 #11

Draft · wants to merge 24 commits into base: main
2 changes: 1 addition & 1 deletion benchmarks/README.md
@@ -5,4 +5,4 @@
You can download the dataset by running:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
26 changes: 26 additions & 0 deletions examples/offline_inference_arctic.py
@@ -0,0 +1,26 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
arctic_model_path = "/checkpoint/arctic"
llm = LLM(model=arctic_model_path,
quantization="deepspeedfp",
tensor_parallel_size=8)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.

outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
60 changes: 60 additions & 0 deletions examples/save_state_dict.py
@@ -0,0 +1,60 @@
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

"""
Example usage

python save_state_dict.py \
--model /path/to/load \
--quantization deepspeedfp \
--tensor-parallel-size 8 \
--output /path/to/save

Then, the model can be loaded with

llm = LLM(
model="/path/to/save",
load_format="state_dict",
quantization="deepspeedfp",
tensor_parallel_size=8,
)
"""

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
required=True,
type=str,
help="path to output checkpoint")


def main(args):
engine_args = EngineArgs.from_cli_args(args)
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory (max_size caps each saved shard, 5 GiB here)
model_executor = llm.llm_engine.model_executor
model_executor._run_workers("save_model",
path=args.output,
max_size=5 * 1024**3)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if not any(
file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")):
shutil.copy(f"{model_path}/{file}", args.output)


if __name__ == "__main__":
args = parser.parse_args()
main(args)
1 change: 1 addition & 0 deletions vllm/config.py
@@ -455,6 +455,7 @@ class LoadFormat(str, enum.Enum):
NPCACHE = "npcache"
DUMMY = "dummy"
TENSORIZER = "tensorizer"
STATE_DICT = "state_dict"


@dataclass
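The new enum value is what a directory written by examples/save_state_dict.py is loaded back with; a minimal sketch mirroring that script's docstring (path and parallelism values are placeholders):

```python
from vllm import LLM

# Load the checkpoint directory produced by examples/save_state_dict.py.
llm = LLM(
    model="/path/to/save",        # placeholder path
    load_format="state_dict",     # selects LoadFormat.STATE_DICT
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
```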
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,7 +1,9 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
fused_experts, fused_moe, fused_topk, get_config_file_name)

__all__ = [
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
]
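Exporting `fused_topk` on its own makes the routing step reusable without running the experts, e.g. for a quick look at expert load; a small sketch with illustrative sizes (CUDA required for the underlying kernels):

```python
import torch

from vllm.model_executor.layers.fused_moe import fused_topk

# Illustrative sizes only.
num_tokens, num_experts, hidden_size = 16, 8, 4096
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")
router_logits = torch.randn(num_tokens, num_experts, dtype=torch.float16, device="cuda")

# Top-2 routing with renormalized weights.
topk_weights, topk_ids = fused_topk(hidden_states, router_logits, 2, True)

# Tokens assigned to each expert, as a rough load-balance check.
print(torch.bincount(topk_ids.flatten().long(), minlength=num_experts))
```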
137 changes: 90 additions & 47 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -308,60 +308,16 @@ def get_moe_configs(E: int, N: int,
return None


def fused_moe(
def fused_topk(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]

M, _ = hidden_states.shape
E, N, _ = w1.shape

if is_hip():
# The MoE kernels are not yet supported on ROCm.
@@ -393,6 +349,33 @@ def fused_moe(
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids


def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]

M, _ = hidden_states.shape
E, N, _ = w1.shape

if override_config:
config = override_config
@@ -477,3 +460,63 @@ def fused_moe(
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)


def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and a top-k gating mechanism.

Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
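Because the wrapper just composes the two new helpers, the split should be behavior-preserving; a sanity-check sketch with made-up sizes and an assumed gate+up weight layout (CUDA required for the Triton kernels):

```python
import torch

from vllm.model_executor.layers.fused_moe import (fused_experts, fused_moe,
                                                  fused_topk)

# Illustrative sizes only: tokens, experts, hidden size, intermediate size.
M, E, H, I = 4, 8, 4096, 1792
x = torch.randn(M, H, dtype=torch.float16, device="cuda")
gating = torch.randn(M, E, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, 2 * I, H, dtype=torch.float16, device="cuda")  # assumed gate+up, stacked per expert
w2 = torch.randn(E, H, I, dtype=torch.float16, device="cuda")      # down projection, stacked per expert

ref = fused_moe(x, w1, w2, gating, topk=2, renormalize=True)

topk_weights, topk_ids = fused_topk(x, gating, 2, True)
out = fused_experts(x, w1, w2, topk_weights, topk_ids)

torch.testing.assert_close(out, ref)  # the split path should match the one-shot call
```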
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -4,6 +4,8 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
@@ -19,6 +21,7 @@
"squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig,
"deepspeedfp": DeepSpeedFPConfig
}


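With the registry entry in place, the method resolves by its string name; a minimal sketch, assuming the module's existing get_quantization_config helper (not shown in this diff) indexes this mapping:

```python
from vllm.model_executor.layers.quantization import get_quantization_config

# Resolves the new string key to its config class; unknown names raise an error.
config_cls = get_quantization_config("deepspeedfp")
print(config_cls.__name__)  # DeepSpeedFPConfig
```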