diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 379a67c4c8cf8..ea8b3d46f1b3f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -70,7 +70,7 @@ steps: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - + - label: Core Test # 10min mirror_hardwares: [amd] fast_check: true @@ -90,8 +90,11 @@ steps: commands: - pip install -e ./plugins/vllm_add_dummy_model - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api] - - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - pytest -v -s entrypoints/openai - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -207,6 +210,21 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 +- label: "PyTorch Fullgraph Smoke Test" + fast_check: true + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph_smoke.py + +- label: "PyTorch Fullgraph Test" + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py + - label: Kernels Test %N # 30min each mirror_hardwares: [amd] source_file_dependencies: @@ -352,7 +370,7 @@ steps: - tests/distributed/ - vllm/compilation commands: - - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_full_graph_multi_gpu.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus diff --git a/.gitignore b/.gitignore index bc7236ea18698..abeaf0a82e303 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -# vllm commit id, generated by setup.py -vllm/commit_id.py +# version file generated by setuptools-scm +/vllm/_version.py # vllm-flash-attn built from source vllm/vllm_flash_attn/ diff --git a/CMakeLists.txt b/CMakeLists.txt index dcf8bd3f58031..9144219249c88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -236,6 +236,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" + "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") @@ -333,6 +334,11 @@ set(VLLM_MOE_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/Dockerfile b/Dockerfile index 30e27620574a0..6bb4bd032c39c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -79,15 +79,13 @@ ENV MAX_JOBS=${max_jobs} ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -ARG buildkite_commit -ENV BUILDKITE_COMMIT=${buildkite_commit} - ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ @@ -107,6 +105,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" != "1" ]; then \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi @@ -203,7 +202,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10 ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4d7289366296b..a9d97a3e0bde4 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -62,8 +62,10 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ - pip install dist/*.whl + pip install dist/*.whl && \ + rm -rf dist WORKDIR /workspace/ diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 647ed99a41e70..adae6db87ba87 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -6,9 +6,12 @@ FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update \ - && apt-get install python3 python3-pip -y \ - && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app @@ -22,17 +25,17 @@ RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -COPY ./vllm /app/vllm/vllm -COPY ./setup.py /app/vllm/setup.py -COPY ./requirements-common.txt /app/vllm/requirements-common.txt -COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt +COPY . /app/vllm RUN cd /app/vllm \ - && python3 -m pip install -U -r requirements-neuron.txt + && python3 -m pip install -U \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron -RUN cd /app/vllm \ - && pip install -e . \ +RUN --mount=type=bind,source=.git,target=.git \ + cd /app/vllm \ + && pip install --no-build-isolation -v -e . \ && cd .. CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 96b9593a2bfa8..95714a3d17188 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,8 +4,9 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git && \ - apt-get install -y ffmpeg libsm6 libxext6 libgl1 + apt-get install -y \ + git python3-pip \ + ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace # copy requirements diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 3313162bf28e1..1f374b01b9bc0 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -16,9 +16,15 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing - -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + torch==2.3.1 \ + -r requirements-cpu.txt \ + xformers uvloop==0.20.0 + +RUN --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 04cd4d79f4045..d8f1a42c45177 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -5,16 +5,25 @@ FROM $BASE_IMAGE WORKDIR /workspace # Install some basic utilities -RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update && apt-get install -y \ + git \ + ffmpeg libsm6 libxext6 libgl1 # Install the TPU and Pallas dependencies. -RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Build vLLM. COPY . /workspace/vllm ENV VLLM_TARGET_DEVICE="tpu" -RUN cd /workspace/vllm && python3 -m pip install -r requirements-tpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + cd /workspace/vllm && \ + python3 -m pip install \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-tpu.txt RUN cd /workspace/vllm && python3 setup.py develop CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 8f61e4c55260e..8471edd16e4bb 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -7,15 +7,20 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update -y && \ + apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 COPY ./ /workspace/vllm WORKDIR /workspace/vllm -RUN pip install -v -r requirements-xpu.txt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-xpu.txt -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=xpu python3 setup.py install CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 0dfba9b862b53..244222f8d0966 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -62,7 +62,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -75,13 +75,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 0000000000000..0ba29fabca59b --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,295 @@ +"""Benchmark offline prioritization.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + #Select a equi-probable random priority + priority = 0 if random.random() < 0.5 else 1 + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + disable_log_stats=False, + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e1a5d4ee28ea1..68b401d5bbbb7 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -90,6 +90,7 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, + use_new_beam_search_impl: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -132,9 +133,23 @@ def run_vllm( max_tokens=output_len, )) - start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() + if not use_new_beam_search_impl: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + else: + assert use_beam_search + prompts = [prompt for prompt, _, _ in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for prompt, input_len, _output_len in requests: + assert _output_len == output_len + start = time.perf_counter() + llm.beam_search(prompts, + beam_width=n, + max_tokens=output_len, + ignore_eos=True) + end = time.perf_counter() return end - start @@ -336,7 +351,7 @@ def main(args: argparse.Namespace): run_args.append(args.disable_frontend_multiprocessing) elapsed_time = uvloop.run(run_vllm_async(*run_args)) else: - elapsed_time = run_vllm(*run_args) + elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -396,6 +411,7 @@ def main(args: argparse.Namespace): default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--use-new-beam-search-impl", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index ca45cba6f8165..b70c4b94c97a1 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -4,8 +4,10 @@ import math import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from itertools import product +from typing import Callable, Iterable, List, Optional, Tuple +import pandas as pd import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement @@ -84,6 +86,10 @@ def loop_over_weights( fn(a, w_ref, w_q, w_s) +_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None +_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None + + def bench(atype: torch.dtype, wtype: ScalarType, group_size: int, @@ -94,6 +100,8 @@ def bench(atype: torch.dtype, sub_label: str, benchmark_marlinv1: bool = True, sweep_schedules: bool = True) -> Iterable[TMeasurement]: + global _SWEEP_SCHEDULES_RESULTS + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) sub_label += f", L={len(weights)}" @@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: best_schedule = None schedules = ops.machete_supported_schedules(wtype) for schedule in reversed(schedules): + schedule_M = int(schedule.split("_")[0].split("x")[1]) + + # Prune known bad schedules + if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: + continue def run(a, _, w_q, w_s, schedule=schedule): ops.machete_gemm(a, @@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule): res = bench_fn(label, sub_label, "machete_best", lambda: loop_over_weights(a, weights_machete, run)) + results_row = { + "M": m, + "K": k, + "N": n, + "group_size": group_size, + "schedule": schedule, + "median": res.median, + } + if _SWEEP_SCHEDULES_RESULTS is None: + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( + columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.\ + loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: best = res @@ -235,18 +262,22 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) + m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")] + m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")] + m_increment, k_increment, n_increment = \ + [int(x) for x in args.dim_increment.split(",")] + Ms = list(range(m_start, m_end + 1, m_increment)) + Ks = list(range(k_start, k_end + 1, k_increment)) + Ns = list(range(n_start, n_end + 1, n_increment)) + MKNs = list(product(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"range_bench-{args.dtype}") @@ -333,6 +364,9 @@ def to_torch_dtype(dt): action="store_true", help="Run a sweep over all supported schedules", ) + parser.add_argument("--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv") subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -342,12 +376,21 @@ def to_torch_dtype(dt): square_parser.set_defaults(func=run_square_bench) range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.add_argument( + "--dim-start", + type=str, + required=True, + help="Start value for M,K,N as common separated list") + range_parser.add_argument( + "--dim-end", + type=str, + required=True, + help="End value (inclusive) for M,K,N as common separated list") + range_parser.add_argument( + "--dim-increment", + type=str, + required=True, + help="Increment value for M,K,N as common separated list") range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -369,4 +412,9 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() + + _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out args.func(args) + + if _SWEEP_SCHEDULES_RESULTS is not None: + _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt new file mode 100644 index 0000000000000..1411a4a0b5ab8 --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 1618a340ce10e..2c78572521eec 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, name, ".stride(", idx, ") to be ", StrideEle::value); return StrideEle{}; } else { - return tensor.stride(idx); + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } } } else { // Extra strides are assumed to be 0 or 1 diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index df968dda92adc..d7829f5d583d4 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -586,7 +586,7 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); }); - std::vector result = {out, x.value()}; + std::vector result = {out}; if (has_z) { result.push_back(out_z); } return result; } diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h new file mode 100644 index 0000000000000..0bd3017226c94 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -0,0 +1,1425 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks + +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 0000000000000..cbafd9ffe7474 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku4b8.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h new file mode 100644 index 0000000000000..9eacb42c115f0 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 0000000000000..c46712474f715 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku8b128.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h new file mode 100644 index 0000000000000..7cd9acafb3b80 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 293a6fad72c2f..dfe0437414013 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,6 +26,8 @@ #include #include "core/scalar_type.hpp" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" template inline std::string str(T x) { @@ -34,230 +36,8 @@ inline std::string str(T x) { namespace marlin_moe { -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -335,1106 +115,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - // __syncthreads(); - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - #else __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -1454,81 +134,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks - -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - #endif -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ - } - typedef struct { int thread_k; int thread_n; @@ -1703,25 +310,27 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ + has_act_order, group_blocks, num_threads, blocks, \ + max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ + sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ + expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ + locks, replicate_input, apply_weights, m_block, \ + max_par, exec_cfg.max_m_blocks)) { \ + } + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par, bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1845,26 +454,16 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int tot_m_blocks = ceildiv(tot_m, 16); for (int m_block = 0; m_block < tot_m_blocks; m_block += 4 * exec_cfg.max_m_blocks) { - // make it max possible value - int thread_m_blocks = exec_cfg.max_m_blocks; - if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + ", thread_n_blocks = " + str(thread_n_blocks) + ", thread_k_blocks = " + str(thread_k_blocks)); } @@ -1943,7 +542,7 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), diff --git a/csrc/ops.h b/csrc/ops.h index 93621d2c78d4e..4f1525799a24e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -124,6 +124,8 @@ torch::Tensor prepack_B(torch::Tensor const& B, }; // namespace machete +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); + torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu new file mode 100644 index 0000000000000..f51fa73298cc1 --- /dev/null +++ b/csrc/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 09a98a5dd1fd6..8ed81ea727aa3 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -157,7 +157,7 @@ TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative -@dataclass +@dataclass(frozen=True) class ScheduleConfig: tile_shape_mn: Tuple[int, int] cluster_shape_mnk: Tuple[int, int, int] @@ -328,56 +328,137 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedules = [ - ScheduleConfig( - tile_shape_mn=tile_shape_mn, - cluster_shape_mnk=cluster_shape_mnk, - kernel_schedule=kernel_schedule, - epilogue_schedule=epilogue_schedule, - tile_scheduler=tile_scheduler, - ) for tile_shape_mn, cluster_shape_mnk in ( - ((128, 16), (1, 1, 1)), - ((128, 32), (1, 1, 1)), - ((128, 64), (1, 1, 1)), - ((128, 128), (1, 1, 1)), - ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, ) - for tile_scheduler in (TileSchedulerType.StreamK, ) - ] + schedule_common_params = dict( + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + ) # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s default_heuristic = [ - ("M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 16", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - (None, - ScheduleConfig(tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK)) + #### M = 257+ + ( + "M > 256 && K <= 16384 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 256", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 129-256 + ( + "M > 128 && K <= 4096 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128 && K <= 8192 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 65-128 + ( + "M > 64 && K <= 4069 && N <= 4069", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K <= 4069 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K >= 8192 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 33-64 + ( + "M > 32 && K <= 6144 && N <= 6144", + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32 && K >= 16384 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 17-32 + ( + "M > 16 && K <= 12288 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 16", + ScheduleConfig( + tile_shape_mn=(256, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 1-16 + ( + "N >= 26624", + ScheduleConfig( + tile_shape_mn=(256, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + None, + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), ] + schedules = list(set([x[1] for x in default_heuristic])) + impl_configs = [] GPTQ_kernel_type_configs = list( diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 046e6e5a53652..4d41b8d291484 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -152,7 +152,8 @@ struct MacheteKernelTemplate { int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); - int const group_size = maybe_group_size.value_or(K); + int const group_size = + maybe_group_size == -1 ? K : maybe_group_size.value_or(K); int const scale_k = (K + group_size - 1) / group_size; TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index e2604d4bed3e2..60a4ed60535b7 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) { auto arguments = MacheteKernel::create_arguments( stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size.value_or(K)); + args.group_size); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index 686dd68bd52bb..df78312997fb0 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // clang-format on // Allocate output - torch::Tensor D = torch::empty_like(B); + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); prepack_B(stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a48fc0df74ff4..f876041a40976 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -212,6 +212,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "-> Tensor"); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); + ops.impl("permute_cols", torch::kCUDA, &permute_cols); + // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " @@ -292,7 +295,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor! A, Tensor! B, Tensor! C," "Tensor? D_, Tensor? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]"); + "Tensor? index_, Tensor!? x) -> Tensor[]"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -309,7 +312,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "Tensor? seq_idx_," "Tensor? initial_states_," - "Tensor? final_states_out_," + "Tensor!? final_states_out_," "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5e..241b2ccd0991e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 0d47281db485e..9adf82d43f3e0 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptType +.. autodata:: vllm.inputs.PromptInputs .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 7083a78a64584..f70ee81afd927 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,15 +3,17 @@ Installation with ROCm ====================== -vLLM supports AMD GPUs with ROCm 6.1. +vLLM supports AMD GPUs with ROCm 6.2. Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 +* Python: 3.9 -- 3.12 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* ROCm 6.1 +* ROCm 6.2 + +Note: PyTorch 2.5+/ROCm6.2 dropped the support for python 3.8. Installation options: @@ -40,8 +42,18 @@ You can build and install vLLM from source. First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. +It is important that the user kicks off the docker build using buildkit. Either the user put DOCKER_BUILDKIT=1 as environment variable when calling docker build command, or the user needs to setup buildkit in the docker daemon configuration /etc/docker/daemon.json as follows and restart the daemon: + +.. code-block:: console + + { + "features": { + "buildkit": true + } + } + -`Dockerfile.rocm `_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches. +`Dockerfile.rocm `_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. @@ -53,13 +65,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: .. code-block:: console $ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: +To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: .. code-block:: console @@ -93,9 +105,8 @@ Option 2: Build from source - `ROCm `_ - `PyTorch `_ -- `hipBLAS `_ -For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`. +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started `_ @@ -104,26 +115,45 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ + .. code-block:: console + + $ python3 -m pip install ninja cmake wheel pybind11 + $ pip uninstall -y triton + $ git clone https://github.com/OpenAI/triton.git + $ cd triton + $ git checkout e192dba + $ cd python + $ pip3 install . + $ cd ../.. + +.. note:: + - If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. + + 2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ + Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention `_ Alternatively, wheels intended for vLLM use can be accessed under the releases. -.. note:: - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) - -3. Build vLLM. +For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. +Note to get your gfx architecture, run `rocminfo |grep gfx`. -.. code-block:: console + .. code-block:: console - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation + $ git clone https://github.com/ROCm/flash-attention.git + $ cd flash-attention + $ git checkout 3cea2fb + $ git submodule update --init + $ GPU_ARCHS="gfx90a" python3 setup.py install + $ cd .. +.. note:: + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -.. tip:: +3. Build vLLM. - For example, vLLM v0.5.3 on ROCM 6.1 can be built with the following steps: + For example, vLLM on ROCM 6.2 can be built with the following steps: .. code-block:: console @@ -131,7 +161,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ # Install PyTorch $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch==2.5.0.dev20240726 --index-url https://download.pytorch.org/whl/nightly/rocm6.1 + $ pip install --no-cache-dir --pre torch==2.6.0.dev20240918 --index-url https://download.pytorch.org/whl/nightly/rocm6.2 $ # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi @@ -141,15 +171,14 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ pip install "numpy<2" $ pip install -r requirements-rocm.txt - $ # Apply the patch to ROCM 6.1 (requires root permission) - $ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib - $ rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* - $ # Build vLLM for MI210/MI250/MI300. $ export PYTORCH_ROCM_ARCH="gfx90a;gfx942" $ python3 setup.py develop + This may take 5-10 minutes. Currently, :code:`pip install .` does not work for ROCm installation. + + .. tip:: - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 816e0a29ef28b..c8947beb34942 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -56,7 +56,7 @@ Build from source .. code-block:: console $ pip install --upgrade pip - $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy + $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu - Third, build and install oneDNN library from source: diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d86d0860f7f29..c41903f84910d 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -242,18 +242,23 @@ Multimodal Language Models * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - Video - - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. (see note) + - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - Image\ :sup:`+` / Video - - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. (see note) + - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`MllamaForConditionalGeneration` + - Llama 3.2 + - Image + - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` @@ -275,7 +280,7 @@ Multimodal Language Models - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`Qwen2VLForConditionalGeneration` - - Qwen2-VL (see note) + - Qwen2-VL - Image\ :sup:`+` / Video\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - @@ -292,15 +297,6 @@ Multimodal Language Models For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 -.. note:: - For :code:`LLaVA-NeXT-Video`, :code:`LLaVA-Onevision` and :code:`Qwen2-VL`, the latest release of :code:`huggingface/transformers` doesn't work yet, so we need to use a developer version (:code:`21fac7abba2a37fae86106f87fcf9974fd1e3830`) for now. - This can be installed by running the following command: - - .. code-block:: bash - - pip install git+https://github.com/huggingface/transformers.git@21fac7abba2a37fae86106f87fcf9974fd1e3830 - ----- If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index ca5b125369c85..08db891665044 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/docs/source/quantization/bnb.rst b/docs/source/quantization/bnb.rst index aefb54a8acb65..682938cc63d48 100644 --- a/docs/source/quantization/bnb.rst +++ b/docs/source/quantization/bnb.rst @@ -11,7 +11,7 @@ Below are the steps to utilize BitsAndBytes with vLLM. .. code-block:: console - $ pip install bitsandbytes>=0.42.0 + $ pip install bitsandbytes>=0.44.0 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py index 3b2347c1115e1..0c454ea50f665 100644 --- a/examples/lora_with_quantization_inference.py +++ b/examples/lora_with_quantization_inference.py @@ -79,23 +79,17 @@ def initialize_engine(model: str, quantization: str, # It quantizes the model when loading, with some config info from the # LoRA adapter repo. So need to set the parameter of load_format and # qlora_adapter_name_or_path as below. - engine_args = EngineArgs( - model=model, - quantization=quantization, - qlora_adapter_name_or_path=lora_repo, - load_format="bitsandbytes", - enable_lora=True, - max_lora_rank=64, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + qlora_adapter_name_or_path=lora_repo, + load_format="bitsandbytes", + enable_lora=True, + max_lora_rank=64) else: - engine_args = EngineArgs( - model=model, - quantization=quantization, - enable_lora=True, - max_loras=4, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + enable_lora=True, + max_loras=4) return LLMEngine.from_engine_args(engine_args) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index c2020724c72fe..8814f4d7bef0d 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -39,6 +39,33 @@ def print_outputs(outputs): use_tqdm=False) print_outputs(outputs) +# You can run batch inference with llm.chat API +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +conversations = [conversation for _ in range(10)] + +# We turn on tqdm progress bar to verify it's indeed running batch inference +outputs = llm.chat(messages=conversations, + sampling_params=sampling_params, + use_tqdm=True) +print_outputs(outputs) + # A chat template can be optionally supplied. # If not, the model will use its default chat template. diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c1129316a6e30..6d34621a8a9bc 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -83,10 +83,24 @@ def run_phi3v(question, modality): # In this example, we override max_num_seqs to 5 while # keeping the original context length of 128k. + + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, + mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids @@ -228,6 +242,29 @@ def run_qwen2_vl(question, modality): return llm, prompt, stop_token_ids +# LLama +def run_mllama(question, modality): + assert modality == "image" + + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + # Note: The default setting of max_num_seqs (256) and + # max_model_len (131072) for this model may cause OOM. + # You may lower either to run this example on lower-end GPUs. + + # The configuration below has been confirmed to launch on a + # single H100 GPU. + llm = LLM( + model=model_name, + max_num_seqs=16, + enforce_eager=True, + ) + + prompt = f"<|image|><|begin_of_text|>{question}" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -242,6 +279,7 @@ def run_qwen2_vl(question, modality): "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "mllama": run_mllama, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 92ab4f42baa80..8c5f1a7b7af08 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"num_crops": 4}, ) placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 1ba702ef019e4..71ae03e4d148b 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -38,7 +38,7 @@ "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", @@ -75,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str: "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", diff --git a/pyproject.toml b/pyproject.toml index 8782a60eac514..be049e6d76616 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ requires = [ "cmake>=3.26", "ninja", "packaging", - "setuptools >= 49.4.0", + "setuptools>=61", + "setuptools-scm>=8.0", "torch == 2.4.0", "wheel", "jinja2", @@ -19,6 +20,10 @@ exclude = [ "examples/fp8/quantizer/quantize.py" ] +[tool.ruff.lint.per-file-ignores] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + [tool.ruff.lint] select = [ # pycodestyle diff --git a/requirements-build.txt b/requirements-build.txt index 3f08f5d67b6da..6144a56da8c47 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -2,7 +2,8 @@ cmake>=3.26 ninja packaging -setuptools>=49.4.0 +setuptools>=61 +setuptools-scm>=8 torch==2.4.0 wheel jinja2 diff --git a/requirements-common.txt b/requirements-common.txt index c113ff3630425..2fc89c026901b 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -4,7 +4,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. +transformers >= 4.45.0 # Required for Llama 3.2. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi < 0.113.0; python_version < '3.9' diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 1cbcfe6a0840f..9e3c4a86cd81d 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -9,3 +9,4 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 +setuptools-scm>=8 \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index 10d463de27be5..9c6fadb88865a 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -30,5 +30,5 @@ datamodel_code_generator # required for minicpm3 test aiohttp # quantization -bitsandbytes==0.42.0 +bitsandbytes>=0.44.0 buildkite-test-collector==0.1.8 diff --git a/setup.py b/setup.py index 60e31af0a8d39..8ef759f5245fc 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import re import subprocess import sys -import warnings from pathlib import Path from shutil import which from typing import Dict, List @@ -14,6 +13,7 @@ from packaging.version import Version, parse from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext +from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME @@ -28,34 +28,6 @@ def load_module_from_path(module_name, path): ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) - -def embed_commit_hash(): - try: - if "BUILDKITE_COMMIT" in os.environ: - # ci build - commit_id = os.environ["BUILDKITE_COMMIT"] - else: - commit_id = subprocess.check_output(["git", "rev-parse", "HEAD"], - encoding="utf-8").strip() - - commit_contents = f'__commit__ = "{commit_id}"\n' - - version_file = os.path.join(ROOT_DIR, "vllm", "commit_id.py") - with open(version_file, "w", encoding="utf-8") as f: - f.write(commit_contents) - - except subprocess.CalledProcessError as e: - warnings.warn(f"Failed to get commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - except Exception as e: - warnings.warn(f"Failed to embed commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - - -embed_commit_hash() - # cannot import envs directly because it depends on vllm, # which is not installed yet envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) @@ -381,52 +353,43 @@ def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) -def find_version(filepath: str) -> str: - """Extract version information from the given filepath. - - Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py - """ - with open(filepath) as fp: - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - fp.read(), re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") - - def get_vllm_version() -> str: - version = find_version(get_path("vllm", "version.py")) + version = get_version( + write_to="vllm/_version.py", # TODO: move this to pyproject.toml + ) + + sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): if envs.VLLM_TARGET_DEVICE == "empty": - version += "+empty" + version += f"{sep}empty" elif _is_cuda(): cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] # skip this for source tarball, required for pypi if "sdist" not in sys.argv: - version += f"+cu{cuda_version_str}" + version += f"{sep}cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] - version += f"+rocm{rocm_version_str}" + version += f"{sep}rocm{rocm_version_str}" elif _is_neuron(): # Get the Neuron version neuron_version = str(get_neuronxcc_version()) if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"+neuron{neuron_version_str}" + version += f"{sep}neuron{neuron_version_str}" elif _is_openvino(): - version += "+openvino" + version += f"{sep}openvino" elif _is_tpu(): - version += "+tpu" + version += f"{sep}tpu" elif _is_cpu(): - version += "+cpu" + version += f"{sep}cpu" elif _is_xpu(): - version += "+xpu" + version += f"{sep}xpu" else: raise RuntimeError("Unknown runtime environment") diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 2e309aaa58d48..5dd65ad7236f9 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -1,42 +1,13 @@ -import os - import pytest -from vllm.utils import cuda_device_count_stateless - -from ..utils import fork_new_process_for_each_test - - -@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) -@pytest.mark.parametrize("tp_size", [1, 2]) -@fork_new_process_for_each_test -def test_full_graph(model, tp_size): - - # Skip the test if there are not enough CUDA devices. - if cuda_device_count_stateless() < tp_size: - pytest.skip("Not enough CUDA devices for the test.") - - # make sure these models can be captured in full graph mode - if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: - os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" +from vllm.compilation.backends import vllm_backend - from vllm import LLM, SamplingParams - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(temperature=0) - llm = LLM(model=model, - enforce_eager=True, - tensor_parallel_size=tp_size, - disable_custom_all_reduce=True) +from .utils import TEST_MODELS, check_full_graph_support - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +@pytest.mark.parametrize("model_info", TEST_MODELS) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/test_full_graph_multi_gpu.py b/tests/compile/test_full_graph_multi_gpu.py new file mode 100644 index 0000000000000..e9883d5254e72 --- /dev/null +++ b/tests/compile/test_full_graph_multi_gpu.py @@ -0,0 +1,22 @@ +import pytest + +from vllm.compilation.backends import vllm_backend +from vllm.utils import cuda_device_count_stateless + +from ..utils import fork_new_process_for_each_test +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +@fork_new_process_for_each_test +def test_full_graph_multi_gpu(model_info, tp_size, backend): + model = model_info[0] + model_kwargs = model_info[1] + + # Skip the test if there are not enough CUDA devices. + if cuda_device_count_stateless() < tp_size: + pytest.skip("Not enough CUDA devices for the test.") + + check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size) diff --git a/tests/compile/test_full_graph_smoke.py b/tests/compile/test_full_graph_smoke.py new file mode 100644 index 0000000000000..0c5a95b4ead4c --- /dev/null +++ b/tests/compile/test_full_graph_smoke.py @@ -0,0 +1,13 @@ +import pytest + +from vllm.compilation.backends import vllm_backend + +from .utils import TEST_MODELS_SMOKE, check_full_graph_support + + +@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE) +@pytest.mark.parametrize("backend", ["eager", vllm_backend]) +def test_full_graph(model_info, backend): + model = model_info[0] + model_kwargs = model_info[1] + check_full_graph_support(model, model_kwargs, backend, tp_size=1) diff --git a/tests/compile/utils.py b/tests/compile/utils.py new file mode 100644 index 0000000000000..2d06a0946d911 --- /dev/null +++ b/tests/compile/utils.py @@ -0,0 +1,104 @@ +import os + +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams +from vllm.plugins import set_torch_compile_backend +from vllm.utils import is_hip + +TEST_MODELS_SMOKE = [ + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +TEST_MODELS = [ + ("facebook/opt-125m", {}), + ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", { + "dtype": torch.float16, + "quantization": "compressed-tensors" + }), + ("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", { + "dtype": torch.float16, + "quantization": "fp8" + }), + ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", { + "quantization": "compressed-tensors" + }), + ("meta-llama/Meta-Llama-3-8B", {}), +] + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("aqlm"): # noqa: SIM223 + TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", { + "quantization": "aqlm" + })) + +# TODO: enable in pytorch 2.5 +if False and is_quant_method_supported("gguf"): # noqa: SIM223 + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", { + "quantization": "gguf" + })) + +if is_quant_method_supported("gptq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", { + "quantization": "gptq" + })) + +if is_quant_method_supported("gptq_marlin"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", { + "quantization": "gptq_marlin" + })) + +if is_quant_method_supported("gptq_marlin_24"): + TEST_MODELS.append(("alexm-nm/tinyllama-24-marlin24-4bit-g128", { + "quantization": "gptq_marlin_24" + })) + +if is_quant_method_supported("marlin"): + TEST_MODELS.append(("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin", { + "quantization": "marlin" + })) + +if not is_hip() and is_quant_method_supported("awq"): + TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", { + "quantization": "AWQ" + })) + + +def check_full_graph_support(model, model_kwargs, backend, tp_size=1): + # make sure these models can be captured in full graph mode + if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ: + os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1" + os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" + + # Inductor doesn't support fp8/gptq_marlin_24 yet. + quantization = model_kwargs.get("quantization") + if (quantization == "fp8" or quantization == "gptq_marlin" + or quantization == "gptq_marlin_24") and backend != "eager": + return + + set_torch_compile_backend(backend) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0) + llm = LLM(model=model, + enforce_eager=True, + tensor_parallel_size=tp_size, + disable_custom_all_reduce=True, + **model_kwargs) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/conftest.py b/tests/conftest.py index c2616bcf7091c..354862e3579ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -169,6 +169,12 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): cleanup() +@pytest.fixture(autouse=True) +def dynamo_reset(): + yield + torch._dynamo.reset() + + @pytest.fixture def example_prompts() -> List[str]: prompts = [] @@ -675,8 +681,6 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - assert sampling_params.logprobs is not None - if images is not None: assert len(prompts) == len(images) @@ -754,7 +758,7 @@ def generate_greedy_logprobs( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids) return self.generate_w_logprobs(prompts, @@ -798,6 +802,20 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def generate_beam_search_new( + self, + prompts: Union[List[str], List[List[int]]], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.model.beam_search(prompts, beam_width, max_tokens) + returned_outputs = [] + for output in outputs: + token_ids = [x.tokens for x in output.sequences] + texts = [x.text for x in output.sequences] + returned_outputs.append((token_ids, texts)) + return returned_outputs + def encode(self, prompts: List[str]) -> List[List[float]]: req_outputs = self.model.encode(prompts) outputs = [] diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 2f6ea632a5d9b..9dddd751c7858 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -27,16 +27,19 @@ def schedule_and_update_computed_tokens(scheduler): return metas, out -def test_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_simple(use_v2_block_manager: bool): """Verify basic scheduling works.""" block_size = 4 num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -45,7 +48,9 @@ def test_simple(): # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -69,30 +74,36 @@ def test_simple(): assert len(seq_group_meta) == num_seq_group -def test_chunk(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunk(use_v2_block_manager: bool): """Verify prefills are chunked properly.""" block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # Verify the second request is chunked. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + print() assert set(get_sequence_groups(out)) == set(running) assert seq_group_meta[0].token_chunk_size == 60 # Verify it is chunked. @@ -113,24 +124,29 @@ def test_chunk(): assert out.num_batched_tokens == 57 -def test_complex(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_complex(use_v2_block_manager: bool): block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 64 + cache_config.num_gpu_blocks = 64 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -151,7 +167,9 @@ def test_complex(): # Add 2 more requests. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -176,16 +194,19 @@ def test_complex(): assert running[2].is_prefill() -def test_maximal_decoding(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_maximal_decoding(use_v2_block_manager: bool): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 max_model_len = 8 max_num_batched_tokens = 2 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -194,7 +215,9 @@ def test_maximal_decoding(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -211,7 +234,9 @@ def test_maximal_decoding(): append_new_token(running[0], 1) # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", prompt_length=2) + _, seq_group = create_dummy_prompt("3", + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -263,23 +288,28 @@ def test_maximal_decoding(): assert out.num_batched_tokens == 2 -def test_prompt_limit(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit(use_v2_block_manager: bool): """Verify max_num_batched_tokens < max_model_len is possible.""" block_size = 4 max_seqs = 32 max_model_len = 64 max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=48) + _, seq_group = create_dummy_prompt("1", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -293,7 +323,8 @@ def test_prompt_limit(): assert out.num_batched_tokens == 32 -def test_prompt_limit_exceed(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit_exceed(use_v2_block_manager: bool): block_size = 4 max_seqs = 64 max_model_len = 32 @@ -303,12 +334,13 @@ def test_prompt_limit_exceed(): max_model_len, enable_chunked_prefill=True) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("2", prompt_length=48) + _, seq_group = create_dummy_prompt("2", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -317,22 +349,28 @@ def test_prompt_limit_exceed(): assert out.ignored_seq_groups[0] == seq_group -def test_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swap(use_v2_block_manager: bool): """Verify swapping works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -369,21 +407,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_running_prefill_prioritized_over_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -413,7 +457,9 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - _, seq_group2 = create_dummy_prompt("2", prompt_length=60) + _, seq_group2 = create_dummy_prompt("2", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group2) _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 @@ -455,22 +501,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_chunked_prefill_preempt(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_preempt(use_v2_block_manager: bool): """Verify preempt works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -517,22 +568,27 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -def test_chunked_prefill_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): block_size = 4 max_seqs = 2 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 128 + cache_config.num_gpu_blocks = 128 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=65) + _, seq_group = create_dummy_prompt("1", + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # The first prefill is chunked. @@ -542,7 +598,9 @@ def test_chunked_prefill_max_seqs(): # Add new requests. for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=65) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -564,16 +622,19 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_perfix_caching(use_v2_block_manager: bool): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 11168d2423b0e..88c6c3bb28e43 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,7 +3,8 @@ from typing import List, Set, Tuple from unittest.mock import MagicMock -import pytest # noqa +import pytest +from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus @@ -16,9 +17,11 @@ schedule_and_update_computed_tokens) -def test_scheduler_add_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_add_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -27,14 +30,18 @@ def test_scheduler_add_seq_group(): # Add seq group to scheduler. num_seq_group = 4 for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) + _, seq_group = create_dummy_prompt(str(i), + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 -def test_scheduler_abort_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_abort_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -54,11 +61,16 @@ def test_scheduler_abort_seq_group(): assert scheduler.get_num_unfinished_seq_groups() == 0 -def test_scheduler_schedule_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_simple(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -67,7 +79,9 @@ def test_scheduler_schedule_simple(): # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -91,20 +105,24 @@ def test_scheduler_schedule_simple(): append_new_token(out, 1) -def test_scheduler_prefill_prioritized(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): """Verify running batched tokens are not applied to prefill requests.""" block_size = 4 max_model_len = 30 max_batched_num_tokens = 30 - scheduler_config = SchedulerConfig(max_batched_num_tokens, 2, - max_model_len) + scheduler_config = SchedulerConfig( + max_batched_num_tokens, + 2, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 2 - cache_config.num_gpu_blocks = 2 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - _, seq_group_a = create_dummy_prompt("1", 1) + _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size) scheduler.add_seq_group(seq_group_a) # Schedule seq groups prompts. @@ -112,7 +130,7 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_a] # Add a new prefill request B. - _, seq_group_b = create_dummy_prompt("2", 30) + _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size) scheduler.add_seq_group(seq_group_b) # Verify prefill requests are prioritized. Since max_batched_num_tokens @@ -121,18 +139,24 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_b] -def test_scheduler_schedule_preempt_abort(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len) + scheduler_config = SchedulerConfig( + 64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - seq_a, seq_group_a = create_dummy_prompt("1", block_size) - seq_b, seq_group_b = create_dummy_prompt("2", block_size) + seq_a, seq_group_a = create_dummy_prompt("1", + block_size, + block_size=block_size) + seq_b, seq_group_b = create_dummy_prompt("2", + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group_a) scheduler.add_seq_group(seq_group_b) @@ -170,12 +194,17 @@ def test_scheduler_schedule_preempt_abort(): assert scheduler.get_num_unfinished_seq_groups() == 1 -def test_scheduler_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_max_seqs(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + max_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -184,7 +213,9 @@ def test_scheduler_max_seqs(): all_seq_groups: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) all_seq_groups.append(seq_group) # Append 1 seq group @@ -211,9 +242,15 @@ def test_scheduler_max_seqs(): assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) -def test_scheduler_delay_factor(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_delay_factor(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5) + scheduler_config = SchedulerConfig( + 100, + 64, + 16, + delay_factor=0.5, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -221,7 +258,8 @@ def test_scheduler_delay_factor(): # schedule first prompt seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups > 0 @@ -231,7 +269,8 @@ def test_scheduler_delay_factor(): # wait for a second before scheduling next prompt time.sleep(1) seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) # second prompt should *not* be scheduled @@ -248,11 +287,20 @@ def test_scheduler_delay_factor(): append_new_token(out, 1) -def test_swapped_out_prioritized(): - scheduler = initialize_scheduler(max_num_seqs=6) +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swapped_out_prioritized(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(max_num_seqs=6, + block_size=block_size, + use_v2_block_manager=use_v2_block_manager, + num_cpu_blocks=64, + num_gpu_blocks=64) # best_of=2 * 3 == 6 sequences. for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) # prefill scheduled now. @@ -276,7 +324,10 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) append_new_token(out, 1) @@ -287,17 +338,26 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def initialize_scheduler(*, - max_num_seqs=1000, - max_token_budget=1000, - max_model_len=1000, - lora_config=None): - block_size = 4 - scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, - max_model_len) +def initialize_scheduler( + *, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None, + use_v2_block_manager=False, + block_size=4, + num_cpu_blocks=8, + num_gpu_blocks=8, +): + block_size = block_size + scheduler_config = SchedulerConfig( + max_token_budget, + max_num_seqs, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = num_cpu_blocks + cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) return scheduler @@ -319,12 +379,18 @@ def add_token_budget(budget: SchedulingBudget, budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) -def test_prefill_schedule_max_prompt_len(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): """ Test prompt longer than max_prompt_len is aborted. """ - scheduler = initialize_scheduler(max_model_len=30) - _, seq_group = create_dummy_prompt("0", prompt_length=60) + block_size = 4 + scheduler = initialize_scheduler(max_model_len=30, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size) + _, seq_group = create_dummy_prompt("0", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) budget = create_token_budget() output = scheduler._schedule_prefills(budget, None) @@ -336,14 +402,21 @@ def test_prefill_schedule_max_prompt_len(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_token_budget(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_token_budget(use_v2_block_manager: bool): """ Test token budget respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=0) for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. @@ -366,10 +439,15 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. - scheduler = initialize_scheduler() + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16) budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) # Cannot schedule a prompt that doesn't fit the budget. scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) @@ -389,14 +467,21 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): """ Test max seq respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(max_num_seqs=2) for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -410,7 +495,9 @@ def test_prefill_schedule_max_seqs(): scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -421,17 +508,24 @@ def test_prefill_schedule_max_seqs(): assert len(remaining_waiting) == 1 -def test_prefill_schedule_max_lora(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_lora(use_v2_block_manager: bool): """ Test max lora is respected and prioritized. """ + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=120) curr_loras: Set[int] = set() for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -443,7 +537,9 @@ def test_prefill_schedule_max_lora(): # If a request is not scheduled because it hits max lora, it is # prioritized. Verify that. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) output = scheduler._schedule_prefills(budget, curr_loras) @@ -467,14 +563,21 @@ def test_prefill_schedule_max_lora(): assert budget.num_batched_tokens == 60 -def test_prefill_schedule_no_block_manager_capacity(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): """ Test sequence cannot be scheduled due to block manager has no capacity. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=128, + num_cpu_blocks=128) budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER @@ -489,7 +592,9 @@ def test_prefill_schedule_no_block_manager_capacity(): scheduler = initialize_scheduler() budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER @@ -502,14 +607,21 @@ def test_prefill_schedule_no_block_manager_capacity(): assert len(remaining_waiting) == 0 -def test_decode_schedule_preempted(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_schedule_preempted(use_v2_block_manager: bool): """ Test decodes cannot be scheduled and preempted. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._add_seq_group_to_running(seq_group) @@ -541,15 +653,23 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_decode_swap_beam_search(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_swap_beam_search(use_v2_block_manager: bool): """ Test best_of > 1 swap out blocks """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=64, + num_cpu_blocks=64) curr_loras = None budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) scheduler._add_seq_group_to_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -589,12 +709,20 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_schedule_decode_blocks_to_copy_update(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): """ Verify blocks_to_copy is updated. """ - scheduler = initialize_scheduler() - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=4, + num_cpu_blocks=16, + num_gpu_blocks=16) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -619,13 +747,19 @@ def test_schedule_decode_blocks_to_copy_update(): assert output.blocks_to_copy == [(2, 3)] -def test_schedule_swapped_simple(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_simple(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=4, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) + append_new_token_seq_group(4, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._add_seq_group_to_swapped(seq_group) @@ -644,12 +778,17 @@ def test_schedule_swapped_simple(): assert blocks_to_swap_out == blocks_to_swap_in_reverse -def test_schedule_swapped_max_token_budget(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -676,12 +815,19 @@ def test_schedule_swapped_max_token_budget(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_seqs(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=4) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -706,14 +852,21 @@ def test_schedule_swapped_max_seqs(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_loras(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_loras(use_v2_block_manager: bool): + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras: Set[int] = set() blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -734,12 +887,20 @@ def test_schedule_swapped_max_loras(): assert len(curr_loras) == 1 -def test_schedule_swapped_cannot_swap_in(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -759,12 +920,20 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 -def test_infeasible_swap(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_infeasible_swap(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -785,10 +954,18 @@ def test_infeasible_swap(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_blocks_to_copy(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out: List[Tuple[int, int]] = [] diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index ef34bebbb0f8c..cd989225e2483 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -162,6 +162,41 @@ def test_chat(): assert len(outputs) == 1 +def test_multi_chat(): + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + @pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 2ad8460023c25..63beaaba29a80 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,7 +19,11 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +MORE_ARGS_LIST = [ + ["--enable-chunked-prefill"], # Chunked + ["--num-scheduler-steps", "8"], # MS + ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream +] @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/tests/kernels/test_aqlm.py b/tests/kernels/test_aqlm.py new file mode 100644 index 0000000000000..860fb66b17354 --- /dev/null +++ b/tests/kernels/test_aqlm.py @@ -0,0 +1,37 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_aqlm_dequant_opcheck(): + codes = torch.randint(-32768, + 32767, (22016, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((2, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + codebook_partition_sizes = [11008, 11008] + + opcheck(torch.ops._C.aqlm_dequant, + (codes, codebooks, codebook_partition_sizes)) + + +def test_aqlm_gemm_opcheck(): + input = torch.rand((4, 4096), device='cuda', dtype=torch.float16) + codes = torch.randint(-32768, + 32767, (12288, 512, 1), + device='cuda', + dtype=torch.int16) + codebooks = torch.rand((3, 65536, 1, 8), + device='cuda', + dtype=torch.float16) + scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16) + codebook_partition_sizes = [4096, 4096, 4096] + bias = None + + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, None)) + opcheck(torch.ops._C.aqlm_gemm, + (input, codes, codebooks, scales, codebook_partition_sizes, bias)) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 63d4d0d9ada81..57981e7572c94 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -205,7 +205,8 @@ def test_paged_attention( (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): if is_hip(): @@ -249,7 +250,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -277,7 +279,8 @@ def test_paged_attention( key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale), - cond=(head_size == HEAD_SIZES[0])) + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: raise AssertionError(f"Unknown version: {version}") diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py new file mode 100644 index 0000000000000..e421aca48af2c --- /dev/null +++ b/tests/kernels/test_awq.py @@ -0,0 +1,38 @@ +import os + +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_awq_dequantize_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16) + zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32) + split_k_iters = 0 + thx = 0 + thy = 0 + opcheck(torch.ops._C.awq_dequantize, + (qweight, scales, zeros, split_k_iters, thx, thy)) + + +def test_awq_gemm_opcheck(): + os.environ["VLLM_USE_TRITON_AWQ"] = "0" + input = torch.rand((2, 8192), device='cuda', dtype=torch.float16) + qweight = torch.randint(-2000000000, + 2000000000, (8192, 256), + device='cuda', + dtype=torch.int32) + scales = torch.randint(-2000000000, + 2000000000, (64, 256), + device='cuda', + dtype=torch.int32) + qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16) + split_k_iters = 8 + opcheck(torch.ops._C.awq_gemm, + (input, qweight, qzeros, scales, split_k_iters)) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 043c4923bd660..744e445fe6673 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -5,6 +5,8 @@ import torch.nn.functional as F from einops import rearrange +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -84,6 +86,64 @@ def causal_conv1d_update_ref(x: torch.Tensor, return (out if activation is None else F.silu(out)).to(dtype=dtype_in) +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, seq_idx, initial_states, final_states_out, + activation in ["silu", "swish"])) + + @pytest.mark.parametrize("return_final_states", [False, True]) @pytest.mark.parametrize("has_initial_states", [False, True]) @pytest.mark.parametrize("channel_last", [False, True]) @@ -149,6 +209,14 @@ def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) + + causal_conv1d_opcheck_fn(x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: assert final_states is not None and final_states_ref is not None assert torch.allclose(final_states, @@ -205,6 +273,10 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + opcheck( + torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation in ["silu", "swish"], None)) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @@ -258,7 +330,5 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, bias, activation=activation) - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_cutlass.py b/tests/kernels/test_cutlass.py index cc4ca2e91e76f..993e67e827ea0 100644 --- a/tests/kernels/test_cutlass.py +++ b/tests/kernels/test_cutlass.py @@ -15,6 +15,9 @@ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +capability = current_platform.get_device_capability() +capability = capability[0] * 10 + capability[1] + def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) @@ -74,6 +77,9 @@ def cutlass_fp8_gemm_helper(m: int, torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + opcheck(torch.ops._C.cutlass_scaled_mm, + (out, a, b, scale_a, scale_b, bias)) + def cutlass_int8_gemm_helper(m: int, n: int, @@ -425,3 +431,7 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): baseline = torch.mm(scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0) + + +def test_cutlass_support_opcheck(): + opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, )) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 8e960d098c408..71f61c19dd951 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -4,6 +4,7 @@ import torch import vllm.attention.backends.flash_attn # noqa: F401 +from tests.kernels.utils import opcheck from vllm.utils import seed_everything NUM_HEADS = [(4, 4), (8, 2), (16, 2)] @@ -127,19 +128,19 @@ def test_flash_attn_with_paged_kv( else: test_utils = ["test_faketensor"] - torch.library.opcheck(torch.ops.vllm.flash_attn_with_kvcache, - args=tuple(), - kwargs=dict( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_with_kvcache, + args=tuple(), + kwargs=dict( + decode_query=query.unsqueeze(1), + key_cache=key_cache, + value_cache=value_cache, + softmax_scale=scale, + causal=True, + block_table=block_tables, + cache_seqlens=kv_lens_tensor, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, @@ -232,23 +233,23 @@ def test_varlen_with_paged_kv( else: test_utils = ["test_faketensor"] - torch.library.opcheck(torch.ops.vllm.flash_attn_varlen_func, - args=tuple(), - kwargs=dict( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) + opcheck(torch.ops.vllm.flash_attn_varlen_func, + args=tuple(), + kwargs=dict( + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=cu_query_lens, + cu_seqlens_k=cu_kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + ), + test_utils=test_utils) ref_output = ref_paged_attn( query=query, diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py index 49f5ce53aab54..c18f5f468dc5a 100644 --- a/tests/kernels/test_fp8_quant.py +++ b/tests/kernels/test_fp8_quant.py @@ -5,6 +5,7 @@ from tests.kernels.quant_utils import (FP8_DTYPE, ref_dynamic_per_tensor_fp8_quant, ref_dynamic_per_token_quant) +from tests.kernels.utils import opcheck from vllm.utils import seed_everything DTYPES = [torch.half, torch.bfloat16, torch.float] @@ -16,6 +17,26 @@ SEEDS = [0] +def opcheck_fp8_quant(output, + input, + scale=None, + scale_ub=None, + use_per_token_if_dynamic=False): + if scale is not None: + opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) + elif use_per_token_if_dynamic: + scale = torch.empty((input.shape[0], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant, + (output, input, scale, scale_ub)) + else: + scale = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale)) + + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @@ -41,6 +62,12 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, + x, + None, + scale_ub, + use_per_token_if_dynamic=True) + @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @@ -60,6 +87,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, torch.testing.assert_close(ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)) + opcheck_fp8_quant(ops_out, x) + # Regression test for a case with large activations where an int32 index cannot # represent the number of elements. diff --git a/tests/kernels/test_ggml.py b/tests/kernels/test_ggml.py new file mode 100644 index 0000000000000..dddb285bf26ec --- /dev/null +++ b/tests/kernels/test_ggml.py @@ -0,0 +1,22 @@ +import gguf +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +@pytest.mark.parametrize("quant_type", [12]) +def test_ggml_opcheck(quant_type): + block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type] + shape = [256, 1152] + qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) + m = qweight.shape[0] + n = qweight.shape[1] // type_size * block_size + opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n)) + + x = torch.rand((m, 512), device='cuda', dtype=torch.float16) + opcheck(torch.ops._C.ggml_mul_mat_a8, + (qweight, x, quant_type, qweight.shape[0])) + opcheck(torch.ops._C.ggml_mul_mat_vec_a8, + (qweight, x, quant_type, qweight.shape[0])) diff --git a/tests/kernels/test_gptq.py b/tests/kernels/test_gptq.py new file mode 100644 index 0000000000000..c1ca6f1f5191b --- /dev/null +++ b/tests/kernels/test_gptq.py @@ -0,0 +1,29 @@ +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 + + +def test_gptq_shuffle_opcheck(): + weight = torch.randint(-2000000, + 2000000, (1792, 4096), + device='cuda', + dtype=torch.int32) + perm = torch.empty((0, ), device='cuda', dtype=torch.int32) + bit = 4 + opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit)) + + +def test_gptq_gemm_opcheck(): + a = torch.rand((240, 4096), device='cuda', dtype=torch.float16) + weight = torch.randint(-2000000, + 2000000, (512, 6144), + device='cuda', + dtype=torch.int32) + zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32) + scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16) + idx = torch.empty((0, ), device='cuda', dtype=torch.int32) + use_exllama = True + bit = 4 + opcheck(torch.ops._C.gptq_gemm, + (a, weight, zeros, scales, idx, use_exllama, bit)) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index 0a90882223077..0dfa79e9af8ec 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -31,6 +31,8 @@ (257, 4224, 4160), (257, 4096, 4096), (64, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), ] ACT_TYPES = [torch.float16, torch.bfloat16] @@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype, output_ref = torch.matmul(a, w_ref) for schedule in ops.machete_supported_schedules(wtype): + print(f"Testing schedule {schedule}") output = ops.machete_gemm( a, b_q=w_q_machete, diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 366475222a68e..5a6149562e886 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -3,6 +3,8 @@ import torch.nn.functional as F from einops import rearrange, repeat +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops # noqa: F401 from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -161,6 +163,59 @@ def selective_scan_ref(u, return out if not return_last_state else (out, last_state) +def selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + + # Disable test_autograd_registration for now as it seems to trigger + # a bogus error. + opcheck(torch.ops._C.selective_scan_fwd, + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, + position_indices, x), + test_utils=["test_schema", "test_faketensor"]) + + @pytest.mark.parametrize('wtype', [torch.float32]) @pytest.mark.parametrize('itype', [torch.float32]) @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) @@ -274,6 +329,17 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, assert state is not None and state_ref is not None assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + selective_scan_opcheck_fn(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 721d3a6a819ac..a9bb72156c39e 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -501,3 +501,18 @@ def test_marlin_qqq_gemm( max_diff = compute_max_diff(output, output_ref) assert max_diff < 0.04 + + +def test_marlin_gemm_opcheck(): + size_m = 2048 + size_n = 4096 + size_k = 4096 + a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16) + w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32) + s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16) + wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL).scratch + x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k) + torch.testing.assert_close(x, y) + opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k)) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 88605fab92ed3..1dd23bd5a54db 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -12,11 +12,14 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.envs as envs +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, single_marlin_moe) -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE @@ -270,6 +273,35 @@ def test_fused_marlin_moe( assert compute_max_diff(marlin_output, triton_output) < 4e-2 + if ops.supports_moe_ops: + token_expert_indicies = torch.empty(m, + topk, + dtype=torch.int32, + device=a.device) + + opcheck(torch.ops._moe_C.topk_softmax, ( + topk_weights, + topk_ids, + token_expert_indicies, + score.float(), + )) + + block_size_m = 4 + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, + e) + + max_workspace_size = ((m + 255) // 256) * (max(2 * n, k) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + opcheck(torch.ops._moe_C.marlin_gemm_moe, + (a, qweight1, sorted_token_ids, topk_weights, topk_ids, + scales1, g_idx1, sort_indices1, workspace, quant_type, m, + 2 * n, k, True, e, topk, block_size_m, True, False)) + @pytest.mark.skip("This test is here for the sake of debugging, " "don't run it in automated tests.") @@ -342,3 +374,29 @@ def test_single_marlin_moe_multiply( torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 + + +def test_moe_align_block_size_opcheck(): + num_experts = 4 + block_size = 4 + topk_ids = torch.randint(0, + num_experts, (3, 4), + dtype=torch.int32, + device='cuda') + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + + opcheck(torch.ops._C.moe_align_block_size, + (topk_ids, num_experts, block_size, sorted_ids, expert_ids, + num_tokens_post_pad)) diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/test_permute_cols.py new file mode 100644 index 0000000000000..14ad7a22cf7cf --- /dev/null +++ b/tests/kernels/test_permute_cols.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm._custom_ops import permute_cols + + +@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +def test_permute_cols(shape, dtype): + x = torch.randn(shape, dtype=dtype).cuda() + perm = torch.randperm(x.shape[1]).to(torch.int).cuda() + opcheck(torch.ops._C.permute_cols, (x, perm)) + y = permute_cols(x, perm) + torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file diff --git a/tests/kernels/test_rotary_embedding.py b/tests/kernels/test_rotary_embedding.py new file mode 100644 index 0000000000000..da879406b3936 --- /dev/null +++ b/tests/kernels/test_rotary_embedding.py @@ -0,0 +1,62 @@ +""" +Tests for miscellaneous utilities +""" + +from typing import Optional + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding + + +def rotary_embedding_opcheck(rot, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None): + cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) + + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + opcheck(torch.ops._C.batched_rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style, rot.rotary_dim, offsets)) + else: + opcheck(torch.ops._C.rotary_embedding, + (positions, query, key, rot.head_size, cos_sin_cache, + rot.is_neox_style)) + + +@pytest.mark.parametrize("device", ["cuda"]) +@pytest.mark.parametrize("max_position", [11, 4096, 32768]) +@pytest.mark.parametrize("is_neox_style", [True, False]) +@pytest.mark.parametrize("rotary_dim", [32]) +@pytest.mark.parametrize("head_size", [32, 108]) +@pytest.mark.parametrize("seq_len", [11, 1024]) +def test_rotary_embedding_opcheck(dist_init, device, max_position, + is_neox_style, rotary_dim, head_size, + seq_len): + batch_size = 1 + base = 0 + num_heads = 7 + rot = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, torch.float32) + + positions = torch.randint(0, + max_position, (batch_size, seq_len), + device=device) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=torch.float32, + device=device) + key = torch.randn_like(query) + + rotary_embedding_opcheck(rot, positions, query, key) + offsets = torch.zeros(batch_size * seq_len, + device=device, + dtype=torch.long) + rotary_embedding_opcheck(rot, positions, query, key, offsets) diff --git a/tests/kernels/test_utils.py b/tests/kernels/test_utils.py new file mode 100644 index 0000000000000..7e5126a76f88b --- /dev/null +++ b/tests/kernels/test_utils.py @@ -0,0 +1,24 @@ +""" +Tests for miscellaneous utilities +""" + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm.platforms import current_platform + + +def test_convert_fp8_opcheck(): + data = torch.randn((256, 256), dtype=torch.float32, device="cuda") + result = torch.empty_like(data, dtype=torch.float8_e4m3fn) + opcheck(torch.ops._C_cache_ops.convert_fp8, (result, data, 1.0, "fp8")) + + +@pytest.mark.skipif(not current_platform.is_cuda(), + reason="Only supported for CUDA") +def test_cuda_utils_opcheck(): + opcheck(torch.ops._C_cuda_utils.get_device_attribute, (0, 0)) + opcheck( + torch.ops._C_cuda_utils. + get_max_shared_memory_per_block_device_attribute, (0, )) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5746932c30a45..08004efe9e2f8 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -2,12 +2,14 @@ import itertools import random +import unittest from numbers import Number from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union) import pytest import torch +from torch._prims_common import TensorLikeType from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, @@ -946,6 +948,34 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, output_under_test.view_as(ideal_output)) +# Copied/modified from torch._refs.__init__.py +def fp8_allclose( + a: TensorLikeType, + b: TensorLikeType, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, +) -> bool: + """ + Reference implementation of torch.allclose + """ + torch._refs._check_close_args(name="torch.allclose", + a=a, + b=b, + rtol=rtol, + atol=atol) + + return bool( + torch.all( + torch.isclose(a.double(), + b.double(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan)).item()) + + +# A special version of op check that has a restricted default set of test_utils +# and a patched version of allclose that supports fp8 types. def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, torch._library.custom_ops.CustomOpDef], args: Tuple[Any, ...], @@ -954,9 +984,10 @@ def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS, raise_exception: bool = True, cond: bool = True) -> Dict[str, str]: - return torch.library.opcheck( - op, - args, - kwargs, - test_utils=test_utils, - raise_exception=raise_exception) if cond else {} + with unittest.mock.patch('torch.allclose', new=fp8_allclose): + return torch.library.opcheck( + op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception) if cond else {} diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 314d6215cbd9c..41c37a4813c68 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -169,6 +169,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -183,6 +184,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -195,6 +197,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -347,6 +350,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -364,6 +368,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 28a395af19e6d..185da6399a06a 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -84,6 +84,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -98,6 +99,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -110,6 +112,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -262,6 +265,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -279,6 +283,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index e248151c40a60..eba0a1a1bce42 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,16 +1,21 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +import torch +from transformers import AutoImageProcessor, AutoTokenizer +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import build_model_context, check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +76,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +### Fast tests for correctness in processor_kwarg override handling + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_phi3v(): + from vllm.model_executor.models.phi3v import input_processor_for_phi3v + return input_processor_for_phi3v + + +@pytest.fixture() +def dummy_data_for_phi3v(): + from vllm.model_executor.models.phi3v import dummy_data_for_phi3v + return dummy_data_for_phi3v + + +@pytest.fixture() +def get_max_phi3v_image_tokens(): + from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens + return get_max_phi3v_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops", [4, 16, None]) +def test_input_mapper_override(model: str, image_assets: _ImageAssets, + num_crops: Optional[int]): + """Ensure that the [default] input mapper handles num_crops properly.""" + # We pass the processor kwargs here since for this model, we fall back to + # the default mapper; this will fall back to the HF mapper and forward + # mm_processor_kwargs to it. + mm_processor_kwargs = { + "num_crops": num_crops + } if num_crops is not None else {} + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + hf_processor = AutoImageProcessor.from_pretrained(model, + trust_remote_code=True, + **mm_processor_kwargs) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ) + + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + + assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) + assert torch.all( + hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) + + # For pixel values, the second axis should be the num_crops + 1 + # for the rescaled original image. The default value in VLLM falls + # back to the HF config, which is why we compare to the processor num_crops + assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) + assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_max_tokens", [ + (4, 781), + (16, 2653), +]) +def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, + num_crops: int, expected_max_tokens: int): + """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" + # NOTE: mm_processor_kwargs on the context in this test is unused, since + # this is testing the mapper directly. In practice, the processor kwargs + # are wrapped in a closure when calling the max tokens func. We explicitly + # do NOT use the mm_processor_kwargs in the model context here to ensure + # that the max image tokens implementation is referencing a mix of the + # kwargs to the function and the original mm_processor_kwargs in case + # values are somehow updated and end up in a bad state. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + actual_max_tokens = get_max_phi3v_image_tokens( + InputContext(ctx.model_config), + num_crops=num_crops, + ) + + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ + (4, 781, 1), + (4, 781, 2), + (16, 2653, 1), + (16, 2653, 2), +]) +def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, + num_crops: int, toks_per_img: int, num_imgs: int): + """Ensure dummy_data_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + sequence_data, _, = dummy_data_for_phi3v( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + num_crops=num_crops, + ) + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + assert img_tok_count == toks_per_img * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), +]) +def test_input_processor_override(input_processor_for_phi3v: Callable, + image_assets: _ImageAssets, model: str, + num_crops: int, expected_toks_per_img: int, + num_imgs: int): + """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model) + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + images = [image_assets[0].pil_image] * num_imgs + + llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + proc_llm_inputs = input_processor_for_phi3v( + ctx=ctx, + llm_inputs=llm_inputs, + num_crops=num_crops, + ) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/encoder_decoder/vision_language/__init__.py b/tests/models/encoder_decoder/vision_language/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py new file mode 100644 index 0000000000000..cda0926d0baf9 --- /dev/null +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -0,0 +1,283 @@ +from typing import List, Optional, Tuple, Type, overload + +import pytest +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs + +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ....utils import multi_gpu_test +from ...utils import check_logprobs_close + +_LIMIT_IMAGE_PER_PROMPT = 1 + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|image|><|begin_of_text|>The meaning of the image is", + "cherry_blossom": + "<|image|><|begin_of_text|>The city is", +}) + +text_only_prompts = [ + "The color of the sky is blue but sometimes it can also be", +] + +models = [ + "meta-llama/Llama-3.2-11B-Vision-Instruct", +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = [ + token_id for idx, token_id in enumerate(output_ids) + if token_id != image_token_id or output_ids[idx - 1] != image_token_id + ] + + assert output_str[0] == " " + hf_output_str = output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [ + prompt if size is not None else text_only_prompts[0] + for size in sizes + ], + [ + image.resize(size) if size is not None else None + for size in sizes + ], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if len(sizes) == 0: + inputs_per_image.append( + (text_only_prompts, [None] * len(text_only_prompts))) + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + dtype=dtype, + max_num_seqs=16, + max_model_len=4096, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + def process(hf_inputs: BatchEncoding): + return hf_inputs + + from transformers import AutoConfig + from transformers.models.mllama import MllamaConfig as MllamaConfigHf + + # use transformer's MllamaConfig for hf_runner + # and vllm's MllamaConfig for vllm_runner + AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModelForVision2Seq) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + from vllm.transformers_utils.configs.mllama import MllamaConfig + AutoConfig.register("mllama", MllamaConfig, exist_ok=True) + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + # Text only + [], + # Single-size + [(512, 512)], + # Single-size, batched + [(512, 512), (512, 512), (512, 512)], + # Multi-size, batched + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028)], + # Multi-size, batched, including text only + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + # mllama has 8 possible aspect ratios, carefully set the sizes + # to cover all of them + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, + max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + ) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 7c466c92d5293..76b2f494d5b25 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket): await client.check_health() # Trigger an abort on the client side. - async def bad_abort_after_2s(): - await asyncio.sleep(2.0) - await client.abort(request_id="foo") + # This request ID does not exist, and will cause the engine to error + await client.abort(request_id="foo") - # Trigger an abort in 2s from now. - abort_task = asyncio.create_task(bad_abort_after_2s()) - - # Exception in abort() will happen during this generation. - # This will kill the engine and should return ENGINE_DEAD_ERROR + # Future generation requests will now fail # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=2000), + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass assert "KeyError" in repr(execinfo.value) assert client.errored - await abort_task - # This should raise the original error. with pytest.raises(RAISED_ERROR): await client.check_health() @@ -190,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 3ffa126070ca0..e27fd77923412 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - prompt="Hello my name is Robert and", + inputs="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 36167cf95f589..ac2ebc622ba6f 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -107,7 +107,7 @@ def validate_generated_texts(hf_runner, quantization='bitsandbytes', load_format='bitsandbytes', tensor_parallel_size=vllm_tp_size, - enforce_eager=True, + enforce_eager=False, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 98a02dec895d2..a9bedc2956fdd 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -9,7 +9,7 @@ # 1. Increase max_tokens to 256. # 2. Increase beam_width to 8. # 3. Use the model "huggyllama/llama-7b". -MAX_TOKENS = [128] +MAX_TOKENS = [64] BEAM_WIDTHS = [4] MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @@ -33,8 +33,8 @@ def test_beam_search_single_input( max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search_new( + example_prompts, beam_width, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 1eba98cefd04a..4ddad66dce1fb 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str): # Next only keep the first 2 draft tokens same as the zero temperature # tokens. For the remaining 3 choose some other tokens. In the # response we will expect the first 2 tokens to be the same as the - # draft tokens and the rest as -1 + # draft tokens and the recovered token and rest as -1 draft_token_ids_to_replace = get_draft_token_ids( batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( @@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str): assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all( + output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2]) assert torch.all(output_token_ids[:, -3:] == -1) @@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids(seed: int, device: str): +def test_get_recovered_token_ids(seed: int, device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. - This test verifies that the `_replacement_token_ids` method of the + This test verifies that the `_get_recovered_token_ids` method of the TypicalAcceptanceSampler correctly identifies the token IDs to be used - as replacements based on the target probability distribution. + as recovered token IDs based on the target probability distribution. Specifically, it ensures that the method correctly identifies the tokens with the highest probability for each sequence in the batch. """ @@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str): typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - expected_replacement_tokens = -torch.ones( - (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], - dim=1) + expected_replacement_tokens = torch.argmax(target_probs, dim=-1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._get_recovered_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 3d93f4a23b68a..b450ef97c89d4 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,13 +1,16 @@ from itertools import cycle -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import pytest from vllm import LLM, SamplingParams from vllm.model_executor.utils import set_random_seed +from vllm.sequence import PromptLogprobs, SampleLogprobs from ...conftest import cleanup -from ...models.utils import check_logprobs_close, check_outputs_equal +from ...models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + check_logprobs_close, check_outputs_equal) from ...utils import RemoteOpenAIServer PROMPTS = [ @@ -81,45 +84,77 @@ def get_output_from_llm_generator( return tokens, token_ids, acceptance_rate -def run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - logprobs: int = 1): - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - logprobs=logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - check_logprobs_close(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, - name_0="org", - name_1="sd") +def check_logprobs_correctness( + spec_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + baseline_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + disable_logprobs: bool = False, +): + """Compare sampled and prompt logprobs between baseline and spec decoding + """ + if not disable_logprobs: + return check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=spec_outputs, + name_0="org", + name_1="sd", + ) + + # Check correctness when disable_logprobs == True + for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): + # Check generated token logprobs. + spec_logprobs = spec_output[2] + baseline_logprobs = baseline_output[2] + _check_logprobs_when_output_disabled(spec_logprobs, + baseline_logprobs, + is_prompt_logprobs=False) + + # Check prompt logprobs too, if they exist + if len(baseline_output) == 4: + assert len(spec_output) == 4 + spec_prompt_logprobs = spec_output[3] + baseline_prompt_logprobs = baseline_output[3] + _check_logprobs_when_output_disabled(spec_prompt_logprobs, + baseline_prompt_logprobs, + is_prompt_logprobs=True) + + +def _check_logprobs_when_output_disabled( + spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + is_prompt_logprobs: bool = False, +): + # Prompt logprobs are optional + if is_prompt_logprobs and baseline_logprobs is None: + assert spec_logprobs is None + return + + assert spec_logprobs is not None + assert baseline_logprobs is not None + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # First prompt logprob is expected to be None + if is_prompt_logprobs and baseline_pos_logprobs is None: + assert spec_pos_logprobs is None + assert pos == 0 + continue + + assert spec_pos_logprobs is not None + assert baseline_pos_logprobs is not None + + # When disabled, the 1 logprob is returned with dummy values for the + # score and rank, but the token id should match the baseline model + assert len(spec_pos_logprobs) == 1 + (spec_pos_logprob_token_id, + spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) + assert spec_pos_logprob.rank == -1 + assert spec_pos_logprob.logprob == 0.0 + assert spec_pos_logprob_token_id in baseline_pos_logprobs def run_equality_correctness_test( @@ -135,7 +170,10 @@ def run_equality_correctness_test( disable_seed: bool = False, ignore_eos: bool = True, ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None): + expected_acceptance_rate: Optional[float] = None, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + disable_logprobs: bool = False): org_args = { **common_llm_kwargs, @@ -157,10 +195,12 @@ def run_equality_correctness_test( sampling_params = SamplingParams(temperature=temperature, max_tokens=max_output_len, seed=seed, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate(prompts, sampling_params) + org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: if ensure_all_accepted or expected_acceptance_rate is not None: @@ -169,7 +209,7 @@ def run_equality_correctness_test( 'prometheus'] stat_logger.local_interval = -100 - sd_outputs = vllm_model.generate(prompts, sampling_params) + sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) if ensure_all_accepted or expected_acceptance_rate is not None: acceptance_rate = (stat_logger.metrics. @@ -185,11 +225,16 @@ def run_equality_correctness_test( if expected_acceptance_rate is not None: assert acceptance_rate >= expected_acceptance_rate - 1e-2 - check_outputs_equal(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, + # Only pass token entries, not the logprobs + check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], + outputs_1_lst=[out[0:2] for out in sd_outputs], name_0="org", name_1="sd") + # Check logprobs if requested + if logprobs is not None or prompt_logprobs is not None: + check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) + def run_equality_correctness_test_tp(model, common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f2af2c2bedb12..d7ca8815ec259 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 03c1733f104ff..b7d54991e0535 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -from .conftest import run_logprob_correctness_test +from .conftest import run_equality_correctness_test @pytest.mark.parametrize( @@ -25,6 +25,10 @@ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": True, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify output logprobs are equal with and without speculative decoding. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7cefe99d026c6..8c90e147df23a 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int, logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad1..7f3180befaffc 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -16,7 +16,7 @@ * Test greedy equality under various number of speculative tokens. With those tests, we can say at least, MLPSpeculator would not break the -correctess for the target model outputs. +correctness for the target model outputs. """ from unittest.mock import patch @@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [8]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e1159..850114eb7f5a8 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model_name": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/test_embedded_commit.py b/tests/test_embedded_commit.py index 17b01651e39af..ffeacf34b7baf 100644 --- a/tests/test_embedded_commit.py +++ b/tests/test_embedded_commit.py @@ -2,6 +2,7 @@ def test_embedded_commit_defined(): - assert vllm.__commit__ != "COMMIT_HASH_PLACEHOLDER" - # 7 characters is the length of a short commit hash - assert len(vllm.__commit__) >= 7 + assert hasattr(vllm, "__version__") + assert hasattr(vllm, "__version_tuple__") + assert vllm.__version__ != "dev" + assert vllm.__version_tuple__ != (0, 0, "dev") diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt index 2f5c6c5a117f3..3e6eba04f1a87 100644 --- a/tests/weight_loading/models-large.txt +++ b/tests/weight_loading/models-large.txt @@ -1,4 +1,5 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main +compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main diff --git a/vllm/__init__.py b/vllm/__init__.py index 4be2942deacad..595db7afdda14 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -6,22 +6,22 @@ from vllm.entrypoints.fast_sync_llm import FastSyncLLM from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from .version import __commit__, __version__ +from .version import __version__, __version_tuple__ __all__ = [ - "__commit__", "__version__", + "__version_tuple__", "LLM", "FastSyncLLM", "ModelRegistry", - "PromptType", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index cd46be92155ff..944130d424394 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -21,8 +21,10 @@ if current_platform.is_rocm(): import vllm._rocm_C # noqa: F401 +supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 + supports_moe_ops = True with contextlib.suppress(ImportError): import vllm._custom_C # noqa: F401 @@ -275,9 +277,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_g_idx, use_exllama, bit) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_gemm"): @torch.library.register_fake("_C::gptq_gemm") def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -287,8 +287,6 @@ def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, return torch.empty((a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device) -except Exception: - pass def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, @@ -314,9 +312,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, size_n, size_k) -# TODO: has to be a better way to do this -try: - torch.ops._C.gptq_marlin_24_gemm # noqa B018 +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): @torch.library.register_fake("_C::gptq_marlin_24_gemm") def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -442,8 +438,8 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @torch.library.register_fake("_C::machete_gemm") def machete_gemm_fake( a: torch.Tensor, - b_q: torch. - Tensor, # Should be the tensor returned by machete_prepack_B + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, b_type: ScalarType, b_scales: Optional[torch.Tensor] = None, b_zeros: Optional[torch.Tensor] = None, @@ -460,7 +456,8 @@ def machete_gemm_fake( @torch.library.register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: - return torch.empty_like(b_q_weight) + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) @torch.library.register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, @@ -472,10 +469,10 @@ def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, return torch.empty_like(x) @torch.library.register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, - bias_: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: + def causal_conv1d_update_fake( + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: return torch.empty_like(x) @torch.library.register_fake("_C::selective_scan_fwd") @@ -486,20 +483,11 @@ def selective_scan_fwd_fake( delta_softplus: bool, index_: Optional[torch.Tensor], x: Optional[torch.Tensor]) -> List[torch.Tensor]: a = torch.empty_like(u) - if x is not None: - b = x - else: - b = torch.empty((u.size(0), u.size(1), A.size(1)), - dtype=u.dtype, - device=u.device) if z_ is not None: c = torch.empty_like(z_) - return [a, b, c] + return [a, c] else: - return [a, b] - -except Exception: - pass + return [a] # cutlass @@ -647,6 +635,18 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) +if hasattr(torch.ops._C, "permute_cols"): + + @torch.library.register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, @@ -833,6 +833,24 @@ def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, token_expert_indicies, gating_output) +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @torch.library.register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + g_idx: torch.Tensor, perm: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 1acb92f9bd0e7..8203f8d927b0a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -222,6 +222,7 @@ def __init__(self, self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() + self._verify_bnb_config() def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] @@ -343,6 +344,28 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.44.0) with 8-bit models does not + yet support CUDA graph. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitAndBytes 8bit yet, " + "fallback to the eager mode.") + self.enforce_eager = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -407,13 +430,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - # Remove the constraint after the bitsandbytes issue is fixed: - # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 - if self.quantization == "bitsandbytes" and self.enforce_eager is False: - logger.warning("CUDA graph is not supported on BitAndBytes yet, " - "fallback to the eager mode.") - self.enforce_eager = True - if pipeline_parallel_size > 1 and self.use_async_output_proc: logger.warning("Async output processor is not supported with " "pipeline parallelism currently. Disabling it.") @@ -566,7 +582,9 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) + return getattr(self.hf_config, "is_encoder_decoder", False) or ( + (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False))) @property def is_embedding_model(self) -> bool: @@ -952,7 +970,7 @@ class SchedulerConfig: workers instead of an entire data. It should be enabled only when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 - + policy: The scheduling policy to use. "fcfs" (default) or "priority". """ def __init__(self, @@ -967,7 +985,9 @@ def __init__(self, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, - send_delta_data: bool = False) -> None: + multi_step_stream_outputs: bool = False, + send_delta_data: bool = False, + policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL @@ -1007,7 +1027,9 @@ def __init__(self, self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data + self.policy = policy self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b737..873decff37c1e 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -766,6 +766,79 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: else: return prompt_limit + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + def _schedule_prefills( self, budget: SchedulingBudget, @@ -917,6 +990,10 @@ def _schedule_default(self) -> SchedulerOutputs: curr_loras, enable_chunking=False) + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. @@ -1477,14 +1554,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # the number of new tokens that is dividable by the block size # to avoid partial block matching. block_size = self.cache_config.block_size - reminder = budget.token_budget % block_size - if reminder != 0: + remainder = budget.token_budget % block_size + if remainder != 0: raise ValueError("When enabling chunked prefill and " "prefix caching, max_num_batched_tokens " "(chunk size) must be dividable by " "block size, but got chunk_size " f"({budget.token_budget}) % block_size " - f"({block_size}) = {reminder}") + f"({block_size}) = {remainder}") if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index b507cd2e1cddb..7d526b25ed193 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -9,11 +9,12 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_ip, get_open_port +from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -214,6 +215,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://*:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) @@ -274,6 +277,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2f3160019e693..eef544b4d4a9f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -145,6 +145,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = False ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -595,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Maximum number of forward steps per ' 'scheduler call.')) + parser.add_argument( + '--multi-step-stream-outputs', + action='store_true', + help='If True, then multi-step will stream outputs for every step') parser.add_argument( '--scheduler-delay-factor', type=float, @@ -999,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig: is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f108751056ab5..34e7e05341f02 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - prompt=prompt, + inputs=inputs, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,7 +822,8 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -880,7 +881,7 @@ async def generate( """ async for output in await self.add_request( request_id, - prompt, + inputs, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -890,7 +891,7 @@ async def generate( async def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -903,7 +904,8 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -957,7 +959,7 @@ async def encode( """ async for output in await self.add_request( request_id, - prompt, + inputs, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 62731e9268469..939f0630c4120 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) + InputRegistry, LLMInputs, PromptInputs) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -95,7 +95,7 @@ class OutputData(NamedTuple): class SchedulerContext: - def __init__(self): + def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -103,6 +103,8 @@ def __init__(self): List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, @@ -219,6 +221,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -234,8 +237,9 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " + "enable_prefix_caching=%s, use_async_output_proc=%s, " + "use_cached_outputs=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -266,8 +270,10 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + use_cached_outputs, model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. @@ -287,6 +293,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ] self.scheduler_contexts = [ - SchedulerContext() + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size) ] @@ -623,6 +631,7 @@ def _add_processed_request( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, ) -> None: self._validate_model_inputs(processed_inputs) # Create the sequences. @@ -653,7 +662,8 @@ def _add_processed_request( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -662,7 +672,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -681,12 +692,13 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: """Add a request to the engine's request pool. @@ -696,7 +708,8 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -704,6 +717,8 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. Details: - Set arrival_time to the current time if it is None. @@ -732,11 +747,16 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + + if priority > 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -751,6 +771,7 @@ def add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, + priority=priority, ) def _create_sequence_group_with_sampling( @@ -763,6 +784,7 @@ def _create_sequence_group_with_sampling( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -789,7 +811,8 @@ def _create_sequence_group_with_sampling( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group @@ -802,6 +825,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -814,7 +838,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -998,7 +1023,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1019,8 +1045,8 @@ def _process_model_outputs(self, for scheduler in self.scheduler: scheduler.free_finished_seq_groups() - # For multi-step, do not create outputs each iteration - if not is_last_step: + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: # Immediately process request outputs here (if callback is given) if (finished_now and self.process_request_outputs_callback is not None): @@ -1037,17 +1063,27 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if params is not None and params.output_kind == ( RequestOutputKind.DELTA) and not seq_group.is_finished(): continue - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1698,7 +1734,11 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): + if self.model_config.is_multimodal_model: + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_ids = inputs.get("prompt_token_ids") + elif self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 09aa279f1e22c..1603189979a2c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - prompt: PromptType + inputs: PromptInputs params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None @@ -43,10 +43,6 @@ class RPCAbortRequest: request_id: str -class RPCHealthRequest: - pass - - class RPCStartupRequest(Enum): IS_SERVER_READY = 1 @@ -56,8 +52,13 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, - RPCStartupRequest] +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 71099115ea125..0ee56f7bf8407 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -20,12 +20,12 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -38,10 +38,10 @@ class MQClientClosedError(Exception): """Exception class raised when the client is used post-close. - + The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which causes a ZMQError and creates a huge stack trace. So, we throw this error such that we can suppress it. """ @@ -95,9 +95,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - # IPC path for ack of check_health requests. - self.health_socket: Socket = self.context.socket(zmq.constants.PULL) - self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -124,34 +124,28 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_check_health_loop(self, timeout: int): - """Background loop that continually probes the RPCServer for health. - - The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which - the MQLLMEngine server is blocking on. - - The Server replies on the HEALTH_SOCKET (rather than on the - OUTPUT_SOCKET such that the messages are not intermingled with - output streaming). + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually listens to the RPCServer for + heartbeats. """ - try: while True: - if await self.health_socket.poll(timeout=timeout) == 0: - # Wakeup every N seconds and do a health probe. - await self._send_one_way_rpc_request( - RPCHealthRequest(), self.input_socket) - - # Wait for ack from the health socket. - await self._await_ack(error_message="Health check failed.", - socket=self.health_socket) + if await self.heartbeat_socket.poll(timeout=timeout) == 0: + # No heartbeat was received. Set error and exit the loop + self._set_errored( + TimeoutError("No heartbeat received " + "from MQLLMEngine")) + logger.debug("Shutting down MQLLMEngineClient check " + "health loop due to timeout") + break + else: - # Server sent a health status message unprompted. + # Heartbeat received- check the message await self._check_success( - error_message="Health check failed.", - socket=self.health_socket) + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) - logger.debug("Health probe successful.") + logger.debug("Heartbeat successful.") except asyncio.CancelledError: logger.debug("Shutting down MQLLMEngineClient check health loop.") @@ -234,7 +228,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( - self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) def close(self): """Destroy the ZeroMQ Context.""" @@ -351,7 +345,7 @@ async def do_log_stats(self): async def check_health(self): """ The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with + Engine's health every N seconds and sets _errored_with if the engine is unhealthy. """ if self._errored_with is not None: @@ -375,7 +369,7 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -389,7 +383,8 @@ def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -398,13 +393,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(prompt, sampling_params, request_id, + return self._process_request(inputs, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -417,7 +412,8 @@ def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -428,12 +424,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. """ - return self._process_request(prompt, pooling_params, request_id, + return self._process_request(inputs, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -466,7 +462,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - prompt=prompt, + inputs=inputs, params=params, request_id=request_id, lora_request=lora_request, @@ -497,3 +493,15 @@ async def _process_request( await self.abort(request_id) finally: self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 788c1573ae255..1b2e7ccf8664f 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,5 +1,7 @@ import pickle import signal +import threading +import time from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -15,10 +17,12 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -66,7 +70,14 @@ def __init__(self, *args, log_requests: bool = True, **kwargs) -> None: - self.engine = LLMEngine(*args, **kwargs) + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + use_cached_outputs = True + + self.engine = LLMEngine(*args, + **kwargs, + use_cached_outputs=use_cached_outputs) self.log_requests = log_requests self.use_async_sockets = use_async_sockets @@ -84,9 +95,9 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send health status back to client. - self.health_socket = self.ctx.socket(zmq.constants.PUSH) - self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -94,6 +105,20 @@ def __init__(self, # Error state. self._errored_with: Optional[BaseException] = None + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, + daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -124,6 +149,8 @@ def start(self): try: logger.debug("Starting Startup Loop.") self.run_startup_loop() + logger.debug("Starting heartbeat thread") + self.heartbeat_thread.start() logger.debug("Starting Engine Loop.") self.run_engine_loop() except Exception as e: @@ -137,6 +164,7 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" # Closes all sockets and destroys context. + self._heartbeat_stop_event.set() self.ctx.destroy(linger=0) del self.engine @@ -175,9 +203,11 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: + self._alive() if not self.engine.has_unfinished_requests(): # Poll until there is work to do. while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self._alive() self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") @@ -193,7 +223,6 @@ def run_engine_loop(self): def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" - try: return self.engine.step() except SystemExit: @@ -222,10 +251,14 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) - elif isinstance(request, RPCHealthRequest): - self._handle_health_request() + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() else: - raise ValueError("Unknown RPCRequest Type: {request}") + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") except Exception as e: self._set_errored(e) @@ -245,7 +278,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - prompt=request.prompt, + inputs=request.inputs, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, @@ -272,13 +305,32 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _handle_health_request(self): + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.wait( + timeout=self.heartbeat_interval_seconds): + # Loops until the stop event is set + self._heartbeat() + + logger.debug("Exiting MQLLMEngine heartbeat thread") + + def _heartbeat(self): + # Send unhealthy if engine has already errored if self._errored_with is not None: self._send_unhealthy(self._errored_with) - # Raises error if unhealthy. - self.engine.check_health() - self._send_healthy() + # Check for life of the main loop + elif time.time() - self._last_alive_time > self.last_alive_threshold: + self._send_unhealthy(RuntimeError("Engine loop has died")) + + else: + # Otherwise- check health of the engine + # self.engine.check_health() raises on unhealthy + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -288,12 +340,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" - self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - error_bytes = pickle.dumps(error) - self.health_socket.send_multipart((error_bytes, ), copy=False) + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -306,6 +360,21 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e + def _alive(self): + self._last_alive_time = time.time() + + def start_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b5..31c2bbc8e7127 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -110,10 +110,11 @@ def process_outputs(self, # we can take the first sample. samples = [output.samples[0] for output in outputs] - # -1 means the output token is not valid (eg. due to spec decode + # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). valid_samples = [ - sample for sample in samples if sample.output_token != -1 + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID ] assert valid_samples diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d0bbeb357b506..70444faa670a2 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request.""" + """Generates outputs for a request""" ... def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index f1ce2c36fcceb..4a575ae8f8537 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -159,6 +159,8 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" + if model_type == "mllama": + return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" @@ -358,6 +360,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} def _parse_chat_message_content_parts( @@ -368,7 +371,11 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + has_image = False for part in parts: part_type = part["type"] if part_type == "text": @@ -383,6 +390,7 @@ def _parse_chat_message_content_parts( "will be ignored.") mm_parser.parse_image(image_url["url"]) + has_image = True elif part_type == "audio_url": audio_url = _AudioParser(part)["audio_url"] @@ -394,12 +402,20 @@ def _parse_chat_message_content_parts( raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, - text_prompt) - - return [ConversationMessage(role=role, content=text_prompt)] + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] # No need to validate using Pydantic again diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py index eb36d124a89fa..c948fc97feeb9 100644 --- a/vllm/entrypoints/fast_sync_llm.py +++ b/vllm/entrypoints/fast_sync_llm.py @@ -8,7 +8,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor from vllm.executor.ray_gpu_executor import RayGPUExecutor -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptInputs, TokensPrompt from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -40,7 +40,7 @@ def __init__( def _add_request( self, - inputs: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, ) -> None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a86c51d23b34d..77ae7b088398a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,8 @@ +import itertools from contextlib import contextmanager -from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, - overload) +from dataclasses import dataclass +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, + Union, cast, overload) from tqdm import tqdm @@ -10,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -30,6 +32,37 @@ logger = init_logger(__name__) +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -260,8 +293,8 @@ def generate( @overload def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -278,7 +311,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -287,7 +320,8 @@ def generate( lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, - GuidedDecodingRequest]] = None + GuidedDecodingRequest]] = None, + priority: Optional[List[int]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -296,9 +330,7 @@ def generate( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -308,6 +340,8 @@ def generate( lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. Returns: A list of ``RequestOutput`` objects containing the @@ -324,13 +358,12 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -345,18 +378,119 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options_request) + guided_options=guided_options_request, + priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + def beam_search( + self, + prompts: List[Union[str, List[int]]], + beam_width: int, + max_tokens: int, + ignore_eos: bool = False, + ) -> List[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + beam_width: The number of beams to keep at each step. + max_tokens: The max number of tokens to generate for each prompt. + + TODO: how does beam search work together with length penalty, frequency + penalty, and stopping criteria, etc.? + """ + + tokenizer = self.get_tokenizer() + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=0.0) + instances: List[BeamSearchInstance] = [] + + for prompt in prompts: + prompt_tokens = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + instances.append(BeamSearchInstance(prompt_tokens)) + + for _ in range(max_tokens): + all_beams: List[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: List[Tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=lambda x: x.cum_logprob, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=lambda x: x.cum_logprob, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + def chat( self, - messages: List[ChatCompletionMessageParam], + messages: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -376,8 +510,9 @@ def chat( to the OpenAI API. Args: - messages: A single conversation represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -394,42 +529,56 @@ def chat( A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. """ + list_of_messages: List[List[ChatCompletionMessageParam]] - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - - conversation, mm_data = parse_chat_messages(messages, model_config, - tokenizer) - - prompt_data: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): - prompt_data = apply_mistral_chat_template( - tokenizer, - messages=messages, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is List[List[...]] + list_of_messages = messages else: - prompt_data = apply_hf_chat_template( - tokenizer, - conversation=conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # messages is List[...] + list_of_messages = [messages] + + prompts: List[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + conversation, mm_data = parse_chat_messages( + msgs, model_config, tokenizer) + + prompt_data: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt_data = apply_mistral_chat_template( + tokenizer, + messages=msgs, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + else: + prompt_data = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + prompt: Union[TokensPrompt, TextPrompt] + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) + else: + prompt = TextPrompt(prompt=prompt_data) - prompt: PromptType - if is_list_of(prompt_data, int): - prompt = TokensPrompt(prompt_token_ids=prompt_data) - else: - prompt = TextPrompt(prompt=prompt_data) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data - if mm_data is not None: - prompt["multi_modal_data"] = mm_data + prompts.append(prompt) return self.generate( - prompt, + prompts, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -499,8 +648,8 @@ def encode( @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -517,7 +666,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -533,9 +682,9 @@ def encode( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. @@ -558,20 +707,19 @@ def encode( ) if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -615,9 +763,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - parsed_prompts: List[PromptType] = [] + inputs: List[PromptInputs] = [] for i in range(num_requests): - item: PromptType + item: PromptInputs if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -626,24 +774,25 @@ def _convert_v1_inputs( else: raise AssertionError - parsed_prompts.append(item) + inputs.append(item) - return parsed_prompts + return inputs def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, ) -> None: - if isinstance(prompts, (str, dict)): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + inputs = [inputs] - num_requests = len(prompts) + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -660,29 +809,32 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. - for i, prompt in enumerate(prompts): + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, ) def _add_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - prompt, + inputs, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority, ) def _add_guided_processor( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7e9f53b1816d1..40d27f984fbaa 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel): completion_tokens: Optional[int] = 0 +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + class JsonSchemaResponseFormat(OpenAIBaseModel): name: str description: Optional[str] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ee4b3ce17cfa..94076ea3a51db 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -22,7 +22,8 @@ ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, - DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) + DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, + ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, OpenAIServing, @@ -175,6 +176,11 @@ async def create_chat_completion( "--enable-auto-tool-choice and --tool-call-parser to be set") request_id = f"chat-{random_uuid()}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + try: guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) @@ -241,11 +247,13 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) try: return await self.chat_completion_full_generator( - request, result_generator, request_id, conversation, tokenizer) + request, result_generator, request_id, conversation, tokenizer, + request_metadata) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -262,6 +270,7 @@ async def chat_completion_stream_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: model_name = self.base_model_paths[0].name created_time = int(time.time()) @@ -300,6 +309,8 @@ async def chat_completion_stream_generator( async for res in result_generator: if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST @@ -580,6 +591,13 @@ async def chat_completion_stream_generator( exclude_unset=True, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error logger.error("error in chat completion stream generator: %s", e) @@ -595,6 +613,7 @@ async def chat_completion_full_generator( request_id: str, conversation: List[ConversationMessage], tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.base_model_paths[0].name @@ -714,6 +733,9 @@ async def chat_completion_full_generator( completion_tokens=num_generated_tokens, total_tokens=num_prompt_tokens + num_generated_tokens, ) + + request_metadata.final_usage_info = usage + response = ChatCompletionResponse( id=request_id, created=created_time, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 9abd74d0561d0..0e8609002e39e 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -18,7 +18,9 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - ErrorResponse, UsageInfo) + ErrorResponse, + RequestResponseMetadata, + UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (BaseModelPath, LoRAModulePath, @@ -94,6 +96,10 @@ async def create_completion( request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + # Schedule the request and get the result generator. generators: List[AsyncGenerator[RequestOutput, None]] = [] try: @@ -165,13 +171,15 @@ async def create_completion( # Streaming response if stream: - return self.completion_stream_generator(request, - result_generator, - request_id, - created_time, - model_name, - num_prompts=len(prompts), - tokenizer=tokenizer) + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts), + tokenizer=tokenizer, + request_metadata=request_metadata) # Non-streaming response final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) @@ -198,6 +206,7 @@ async def create_completion( created_time, model_name, tokenizer, + request_metadata, ) except asyncio.CancelledError: return self.create_error_response("Client disconnected") @@ -227,6 +236,7 @@ async def completion_stream_generator( model_name: str, num_prompts: int, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_text_lens = [0] * num_choices * num_prompts @@ -346,6 +356,14 @@ async def completion_stream_generator( exclude_unset=False, exclude_none=True)) yield f"data: {final_usage_data}\n\n" + # report to FastAPI middleware aggregate usage across all choices + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) @@ -360,6 +378,7 @@ def request_output_to_completion_response( created_time: int, model_name: str, tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -433,6 +452,8 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) + request_metadata.final_usage_info = usage + return CompletionResponse( id=request_id, created=created_time, diff --git a/vllm/envs.py b/vllm/envs.py index 90c6612f23bd7..ee4711dbec842 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -69,6 +69,7 @@ VLLM_RPD_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False + VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1 VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = False @@ -221,6 +222,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # If set, allowing the use of deprecated beam search implementation + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": + lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", + # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 59e9854393b6b..7e46acefc5b0e 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -18,9 +18,14 @@ try: import ray - from ray._private.state import available_resources_per_node from ray.util import placement_group_table from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node class RayWorkerWrapper(WorkerWrapperBase): """Ray wrapper for vllm.worker.Worker, allowing Worker to be diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ba1bef1ab3ecc..0b08e9691f915 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptType", - "SingletonPrompt", + "PromptInputs", + "SingletonPromptInputs", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index e072bb65714b9..a71e9a7b5db66 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptType` may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,12 +55,12 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) @@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptType` schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptType` instances. + :class:`SingletonPromptInputs` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -139,9 +139,19 @@ class EncoderDecoderLLMInputs(LLMInputs): available. """ + encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the encoder model, + if the model supports it. + """ + -_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) -_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e4184277..ac9d355c64c80 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(prompt, str): - return ParsedStrPrompt(type="str", content=prompt) - elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore - elif "prompt" in prompt: - return ParsedTextPrompt(type="text", content=prompt) + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(prompt, dict) and "encoder_prompt" in prompt + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 1f1b048d37e9b..bee3d1ed75cbb 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, - SingletonPrompt) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -128,6 +128,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], + force_bos: bool = True, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -157,8 +158,8 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + if force_bos and (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -206,7 +207,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -216,7 +217,7 @@ def _extract_prompt_components( Arguments: * request_id - * prompt: single encoder or decoder input prompt + * inputs: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -226,24 +227,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -251,33 +252,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -285,7 +286,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -295,23 +296,30 @@ def _build_enc_dec_llm_inputs( encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " - "not supported yet") + if decoder_mm_data is not None: + raise ValueError( + "Multi-modality decoder inputs of encoder-decoder models are " + "not supported yet") - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + # For Multi-Modal models (e.g., mllama), the text input can be + # <|image|><|begin_of_text|>hello world. And we should not add + # another <|begin_of_text|> to the beginning. + decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( + decoder_prompt_ids, + force_bos=(encoder_mm_data is None and decoder_mm_data is None))) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, + multi_modal_data=decoder_mm_data, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, + encoder_multi_modal_data=encoder_mm_data, ) def _process_encoder_decoder_prompt( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -339,7 +347,7 @@ def _process_encoder_decoder_prompt( Arguments: - * prompt: an input prompt + * inputs: an input prompt * request_id Returns: @@ -350,13 +358,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_comps = self._extract_prompt_components( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -365,7 +373,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, ) @@ -375,20 +383,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -401,7 +409,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, ) @@ -425,7 +433,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -436,7 +444,7 @@ def _process_decoder_only_prompt( Arguments: - * prompt: input prompt + * inputs: input prompt * request_id * lora_request * prompt_adapter_request @@ -447,7 +455,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -459,14 +467,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -478,7 +486,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -488,17 +496,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -506,7 +514,7 @@ def preprocess( async def preprocess_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -516,17 +524,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 6ab23d1c4b769..159d958ebf671 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -112,6 +112,8 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type: Dict[Type[nn.Module], DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[ + Type[nn.Module], DummyDataFactory] = {} self._input_processors_by_model_type: Dict[Type[nn.Module], InputProcessor] = {} @@ -162,11 +164,44 @@ def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): return self._dummy_factories_by_model_type \ .get(model_cls, self._default_dummy_data_factory) + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_encoder_factories_by_model_type: + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): + if model_cls in self._dummy_encoder_factories_by_model_type: + dummy_factory = self._dummy_encoder_factories_by_model_type[ + model_cls] + else: + logger.warning( + "No dummy encoder data factory registered to %s. " + "Using the dummy data factory for the model instead.", + model_cls) + dummy_factory = self._get_dummy_data_factory(model_cls) + return dummy_factory + def dummy_data_for_profiling( self, model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. @@ -184,8 +219,10 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._get_dummy_data_factory(model_cls) - + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) mm_processor_kwargs = get_allowed_kwarg_only_overrides( dummy_factory, overrides=model_config.mm_processor_kwargs) @@ -196,10 +233,15 @@ def dummy_data_for_profiling( # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids - assert len(num_tokens) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - + if len(num_tokens) < seq_len: + if is_encoder_data: + logger.warning( + "Expected at least %d dummy encoder tokens for profiling, " + "but found %d tokens instead.", seq_len, len(num_tokens)) + else: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") if mm_data is not None: for k, v in mm_data.items(): num_items = len(v) if isinstance(v, list) else 1 diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 619408b9315cf..6a32387a6f36c 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -100,7 +100,7 @@ def _bgmv_expand( corresponding to each batch, An index of -1 means no lora should be applied. batches (int): batch size - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index c16db233891a5..73628fd20d327 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -104,7 +104,7 @@ def _bgmv_expand_slice( lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch, An index of -1 means no lora should be applied. - slice_offst (int): output_tensor's offst + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size batches (int): batch size add_inputs (bool, optional): Defaults to False. diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index c71332d8bdfb2..adb3ab5b46b87 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -106,6 +106,7 @@ def _sgmv_expand( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, add_inputs: bool = False, ) -> None: """ @@ -115,17 +116,19 @@ def _sgmv_expand( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - add_inputs (bool, optional): Defaults to False. adds the final lora + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ @@ -134,6 +137,7 @@ def _sgmv_expand( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index b4ae9a2acbb5c..efa234520ab87 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -112,6 +112,7 @@ def _sgmv_expand_slice( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False, @@ -124,20 +125,22 @@ def _sgmv_expand_slice( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences + max_seq_length (int): The max sequence lengths of the sequences in the batch - slice_offst (int): output_tensor's offst + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False. adds the final lora - results to the output.. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] @@ -145,6 +148,7 @@ def _sgmv_expand_slice( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c0791c260e915..c003f3dc0ce9e 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -110,6 +110,7 @@ def _sgmv_shrink( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, scaling: float, ) -> None: """ @@ -120,17 +121,19 @@ def _sgmv_shrink( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - scaling (float): Scaling factor. + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. """ assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] @@ -138,6 +141,7 @@ def _sgmv_shrink( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_a_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 6d5c834299961..5033ce4126929 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -27,7 +27,7 @@ def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,7 +43,7 @@ def compute_meta( b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) max_length = seq_length_tensor.max().item() - + token_nums = seq_length_tensor.sum().item() batch_size = lora_indices_tensor.size(0) no_lora = False # -1 means no lora should be applied. Use `no_lora` to determine whether @@ -52,7 +52,7 @@ def compute_meta( if batch_size == 1 and lora_indices_tensor == -1: no_lora = True return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) + batch_size, max_length, token_nums, no_lora) # TODO see if this can be vectorized @@ -178,7 +178,7 @@ def convert_mapping( class PunicaWrapper: """ PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica kernel. """ @@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, dtype=torch.long, device=device) self.max_length: int = 0 + self.token_nums: int = 0 self.batch_size: int = -1 self.is_prefill = False self.no_lora = False @@ -276,13 +277,13 @@ def _update_base_metadata( long_lora_offsets_tensor) else: self._long_lora_indices.zero_() - self.indices_len[:] = indices_len def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( b_seq_start_tensor) @@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length + self.token_nums = token_nums self.no_lora = no_lora @property def prefill_metadata( - self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions - 2. seq_lengths: Tensor of sequence lengths + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. - 4. batch_size: batch size after clustering identical lora indices - 5. max_length: The maximum sequence length in the batch + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. """ return (self._seq_start_locs[:self.batch_size], self._seq_lengths[:self.batch_size], self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length) + self.batch_size, self.max_length, self.token_nums) @property def token_lora_indices(self) -> torch.Tensor: @@ -324,7 +328,7 @@ def token_lora_indices(self) -> torch.Tensor: def sampler_indices(self) -> torch.Tensor: """ This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA + LogitsProcessorWithLoRA. """ sampler_indices_len = self.indices_len[1] return self._sampler_indices[:sampler_indices_len] @@ -332,7 +336,7 @@ def sampler_indices(self) -> torch.Tensor: @property def sampler_indices_padded(self) -> torch.Tensor: """ - This property provides access to padded sampler indices + This property provides access to padded sampler indices. """ indices_padded_len = self.indices_len[2] return self._sampler_indices_padded[:indices_padded_len] @@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor: def embeddings_indices(self) -> torch.Tensor: """ This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA + specifically for VocabParallelEmbeddingWithLoRA. """ embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] @@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor: """ This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora + lora, specifically for LinearScalingRotaryEmbeddingWithLora. """ long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] @@ -524,7 +528,7 @@ def add_lora(self, scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. - y_slice_size (Optional[int], optional): Size of the y column slice.. + y_slice_size (Optional[int], optional): Size of the y column slice. buffer (Optional[torch.Tensor], optional): Defaults to None. """ y_org = y diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4d0ff111a9397..2d96fd42e6227 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -338,10 +338,12 @@ def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: - # compressed-tensors represents weights on disk which are flipped + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality loaded_weight = loaded_weight.t().contiguous() if ( self.quant_method.__class__.__name__ - == "CompressedTensorsMoEMethod") else loaded_weight + == "CompressedTensorsWNA16MoEMethod") else loaded_weight if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " @@ -368,6 +370,9 @@ def weight_loader(self, param: torch.nn.Parameter, # Case input scale: input_scale loading is only supported for fp8 if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + if param.data[expert_id] != 1 and (param.data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index a0bed07ac6193..5fe451b2f1318 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -361,8 +361,8 @@ def selective_scan_fn(u, x[:, :, 0, 0::2] = 1 if prev_state is not None: x[:, :, 0, 1::2].copy_(prev_state) - out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, position_indices, x) + out, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) if z is None: return out if not return_last_state else (out, last_state) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eed01953fb4af..fe33b7341fd38 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -7,10 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -231,7 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -239,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. marlin_zp = awq_to_marlin_zero_points( @@ -247,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 66bc5395dbd7a..38495d5a5a863 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err self.quant_config = quant_config diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e536fae45c845..362feeef2e33c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -73,7 +73,7 @@ def get_quant_method( if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod(self) + return CompressedTensorsMoEMethod.get_moe_method(self) return None @classmethod diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 7dee2fca81153..6666a4bf1f26a 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -5,12 +5,16 @@ import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - CompressionFormat) + CompressionFormat, QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_hip, print_warning_once class GPTQMarlinState(Enum): @@ -18,11 +22,219 @@ class GPTQMarlinState(Enum): READY = enum.auto() -__all__ = ["CompressedTensorsMoEMethod"] +__all__ = [ + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsWNA16MoEMethod" +] class CompressedTensorsMoEMethod(FusedMoEMethodBase): + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales" + "for weights and activations are supported. Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. ") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # If rocm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + def __init__( self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 3cade3d3fbcd0..cb65557be8f90 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,17 +1,16 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch -from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ActivationOrdering) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_repeat_scales_on_all_ranks) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -19,6 +18,8 @@ RowvLLMParameter) from vllm.scalar_type import scalar_types +logger = init_logger(__name__) + __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, @@ -28,6 +29,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, @@ -52,35 +54,43 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # If group_size is -1, we are in channelwise case. group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) partition_scales = not marlin_repeat_scales_on_all_ranks( self.has_g_idx, self.group_size, row_parallel) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -137,69 +147,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, weight_loader=weight_loader) layer.register_parameter("weight_g_idx", weight_g_idx) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "weight_g_idx", g_idx) - else: - layer.weight_g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - # scale is required on all partitions if activation reordering - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=(layer.input_size - if self.has_g_idx else layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.weight_g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index c067a76405df6..1cfadb4f42ca8 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -217,6 +217,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 5a1b2d701ab0d..3d3ce711e58b0 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,7 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch -from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -11,12 +10,12 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): quant_config: The GPTQ Marlin quantization config. """ + _kernel_backends_being_used: Set[str] = set() + def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config @@ -176,25 +177,34 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size, - ) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -275,57 +285,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits, - ) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. - marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -333,20 +301,7 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) class GPTQMarlinMoEMethod(FusedMoEMethodBase): @@ -506,12 +461,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_g_idx_sort_indices[e]] w2_sorted_g_idx[e] = layer.w2_g_idx[e][ w2_g_idx_sort_indices[e]] - replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) - replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) - replace_tensor(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] @@ -544,7 +499,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, @@ -552,7 +507,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, @@ -560,14 +515,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w13_scales", marlin_w13_scales) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) + replace_parameter(layer, "w2_scales", marlin_w2_scales) def apply( self, diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000000000..fe50c4930d043 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000000000..47591c2aa644e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py new file mode 100644 index 0000000000000..fa39cb511528e --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_weights_into_int32, unpack_weights_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_weights_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_weights_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py new file mode 100644 index 0000000000000..5b4bba76ee0ca --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/marlin.py @@ -0,0 +1,132 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index e69de29bb2d1d..e60f0c79ac1f7 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000000000..edce6d19b6c49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,37 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000000000..18e1332050cdd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index fea94cf7322ad..53762965732ce 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index bdfda31de852b..833d00073564e 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,6 +20,49 @@ } +def pack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2176602ce894e..66b1e4348df27 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -15,7 +15,8 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.utils import rpd_mark @@ -800,10 +801,10 @@ def _sample_with_torch( # Create output tensor for sampled token ids. if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 8c03e46927752..584cf971d9c05 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -80,7 +80,7 @@ def forward( target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) - recovered_token_ids = self._replacement_token_ids(target_probs) + recovered_token_ids = self._get_recovered_token_ids(target_probs) output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) @@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): accepted_mask = candidates_prob > threshold return accepted_mask - def _replacement_token_ids(self, target_probs): + def _get_recovered_token_ids(self, target_probs): """ - Generate one replacement token ID for each sequence based on target - probabilities. The replacement token is used as the fallback option - if typical acceptance sampling does not accept any draft tokens for - that particular sequence. - - This method computes the token IDs to be replaced by selecting the - token with the highest probability for each sequence in the first - position. The rest of the output is filled with -1. + The recovered token ids will fill the first unmatched token + by the target token. Parameters ---------- @@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs): Returns ------- torch.Tensor - A tensor of shape (batch_size, k) with the replacement - token IDs. Only the first column is set, and the rest of the - columns are filled with -1. + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. """ - max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype, - device=target_probs.device) - output[:, 0] = max_indices - return output + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index ce95c98279aef..8fed5267a9eb5 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,6 +1,7 @@ # ruff: noqa: SIM117 import collections import copy +import dataclasses import fnmatch import glob import json @@ -8,7 +9,8 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple, Type +from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, + Type, cast) import gguf import huggingface_hub @@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig, class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str, return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], - fall_back_to_pt: bool + self, source: "Source" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision, fall_back_to_pt) + source.model_or_path, source.revision, source.fall_back_to_pt) if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False weights_iterator = np_cache_weights_iterator( - model_name_or_path, self.load_config.download_dir, hf_folder, + source.model_or_path, self.load_config.download_dir, hf_folder, hf_weights_files) elif use_safetensors: weights_iterator = safetensors_weights_iterator(hf_weights_files) @@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator): xm.mark_step() weights_iterator = _xla_weights_iterator(weights_iterator) - return weights_iterator + + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True)) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast(Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ())) + for source in secondary_weights: + yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, @@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig, model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - model.load_weights( - self._get_weights_iterator(model_config.model, - model_config.revision, - fall_back_to_pt=getattr( - model, - "fall_back_to_pt_during_load", - True)), ) + + model.load_weights(self._get_all_weights(model_config, model)) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -829,12 +863,12 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err hf_weights_files, use_safetensors = self._prepare_weights( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 9e4476943326d..c8fb4bf5b8b81 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -102,6 +102,8 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index ad1ab0231d861..13811d33768a6 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: self.model = model_cls(self.config.model, *args, **kwargs) self.fc = nn.Linear(config.model.hidden_size * 2, config.model.hidden_size, - bias=False) + bias=getattr(self.config, "bias", False)) self.orig_vocab_size = config.vocab_size self.truncated_vocab_size = config.truncated_vocab_size @@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if self.config.truncated_vocab_size < self.config.vocab_size: self.token_map = nn.Parameter(loaded_weight, requires_grad=False) - elif name.startswith("fc."): + elif name.startswith("fc.weight"): weight_loader = getattr(self.fc.weight, "weight_loader", default_weight_loader) weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + raise ValueError("Found bias in the loaded weights " + "but the model config doesn't have bias") elif name.startswith("model.lm_head.") or name.startswith( "model.model."): model_weights[name.split("model.", 1)[-1]] = loaded_weight diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 5f365bbc30670..d4853fd790098 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -25,6 +25,7 @@ import torch from torch import nn +from transformers import GraniteConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -48,7 +49,6 @@ default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.utils import is_hip from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 005a24f10aa17..b1748700d481a 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -19,7 +19,7 @@ from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -230,8 +230,9 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): else: raise TypeError(f"Invalid image type: {type(image_data)}") - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] @@ -278,8 +279,9 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): use_thumbnail=use_thumbnail) for img in data ] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_token_id = tokenizer.encode(IMG_CONTEXT, add_special_tokens=False, return_tensors="pt")[0] @@ -298,8 +300,9 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int, model_config = ctx.model_config hf_config = ctx.get_hf_config() vision_config = hf_config.vision_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) seq_data = dummy_seq_data_for_clip( vision_config, @@ -376,6 +379,11 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + if hasattr(self.language_model, "sampler"): + self.sampler = self.language_model.sampler + else: + self.sampler = Sampler() + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c0fb6fef78bab..7da7991b4f849 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -33,6 +33,7 @@ from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig +from typing_extensions import NotRequired from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -52,6 +53,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData @@ -64,6 +66,17 @@ } +class MiniCPMVImageInput(TypedDict): + """Input mapper input with auxiliary data for computing image bounds.""" + image: Image.Image + + # Image bounds token ids in 0-dim scaler tensor. + im_start_id: torch.Tensor + im_end_id: torch.Tensor + slice_start_id: NotRequired[torch.Tensor] + slice_end_id: NotRequired[torch.Tensor] + + class MiniCPMVImagePixelInputs(TypedDict): pixel_values: List[torch.Tensor] """ @@ -88,8 +101,6 @@ class MiniCPMVImagePixelInputs(TypedDict): """ -MiniCPMVImageInputs = MiniCPMVImagePixelInputs - DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) @@ -234,6 +245,25 @@ def forward(self, x: torch.Tensor, return x +def _build_image_input(ctx: InputContext, + image: Image.Image) -> MiniCPMVImageInput: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + trust_remote_code=ctx.model_config.trust_remote_code) + if hasattr(tokenizer, "slice_start_id"): + return MiniCPMVImageInput( + image=image, + im_start_id=torch.tensor(tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id), + slice_start_id=torch.tensor(tokenizer.slice_start_id), + slice_end_id=torch.tensor(tokenizer.slice_end_id)) + else: + return MiniCPMVImageInput(image=image, + im_start_id=torch.tensor( + tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id)) + + def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: version_float = getattr(config, "version", None) @@ -257,10 +287,13 @@ def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): return SequenceData.from_token_counts((0, seq_len)) -def dummy_image_for_minicpmv(hf_config: PretrainedConfig, num_images: int): +def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, + num_images: int): width = height = hf_config.image_size - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} + image = _build_image_input(ctx, + image=Image.new("RGB", (width, height), + color=0)) + return {"image": [image] if num_images == 1 else [image] * num_images} def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, @@ -269,7 +302,7 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, num_images = mm_counts["image"] seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) - mm_data = dummy_image_for_minicpmv(hf_config, num_images) + mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) return seq_data, mm_data @@ -280,8 +313,9 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs model_config = ctx.model_config version = get_version_by_config(model_config.hf_config) - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_processor = cached_get_image_processor(model_config.tokenizer) def get_placeholder(image_size: Tuple[int, int], num_image: int): @@ -317,6 +351,10 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): new_prompt = "".join(new_prompt_chunks) new_token_ids = tokenizer.encode(new_prompt) + multi_modal_data["image"] = [ + _build_image_input(ctx, image) for image in images + ] + llm_inputs = LLMInputs( prompt_token_ids=new_token_ids, prompt=new_prompt, @@ -325,6 +363,32 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int): return llm_inputs +def input_mapper_for_minicpmv(ctx: InputContext, data: object): + model_config = ctx.model_config + + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + if not isinstance(data, list): + raise ValueError( + "Image input must be list of MiniCPMVImageInput, got (%s)", data) + batch_data = image_processor \ + .preprocess([img["image"] for img in data], return_tensors="pt") \ + .data + + if len(data) > 0: + batch_data["im_start_id"] = data[0]["im_start_id"] + batch_data["im_end_id"] = data[0]["im_end_id"] + if "slice_start_id" in data[0]: + batch_data["slice_start_id"] = data[0]["slice_start_id"] + batch_data["slice_end_id"] = data[0]["slice_end_id"] + + return MultiModalInputs(batch_data) + + class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): """ The abstract class of MiniCPMV can only be inherited, but cannot be @@ -365,7 +429,7 @@ def __init__( def get_embedding( self, input_ids: torch.Tensor, - image_inputs: Optional[MiniCPMVImageInputs], + image_inputs: Optional[MiniCPMVImagePixelInputs], ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) if hasattr(self.config, "scale_emb"): @@ -393,14 +457,20 @@ def get_embedding( return vlm_embedding, vision_hidden_states - def _get_image_bounds(self, input_ids: torch.Tensor) -> torch.Tensor: - tokenizer = cached_get_tokenizer(self.config._name_or_path, - trust_remote_code=True) - start_cond = input_ids == tokenizer.im_start_id - end_cond = input_ids == tokenizer.im_end_id - if hasattr(tokenizer, "slice_start_id"): - start_cond |= (input_ids == tokenizer.slice_start_id) - end_cond |= (input_ids == tokenizer.slice_end_id) + def _get_image_bounds( + self, + input_ids: torch.Tensor, + im_start_id: torch.Tensor, + im_end_id: torch.Tensor, + slice_start_id: Optional[torch.Tensor] = None, + slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor: + # All the images in the batch should share the same special image + # bound token ids. + start_cond = input_ids == im_start_id[0] + end_cond = input_ids == im_end_id[0] + if slice_start_id is not None: + start_cond |= (input_ids == slice_start_id[0]) + end_cond |= (input_ids == slice_end_id[0]) image_start_tokens, = torch.where(start_cond) image_start_tokens += 1 @@ -419,7 +489,7 @@ def _parse_and_validate_inputs( self, input_ids: torch.Tensor, **kwargs: object, - ) -> Optional[MiniCPMVImageInputs]: + ) -> Optional[MiniCPMVImagePixelInputs]: pixel_values = kwargs.pop("pixel_values", []) tgt_sizes = kwargs.pop("tgt_sizes", []) @@ -456,8 +526,17 @@ def _parse_and_validate_inputs( if len(pixel_values_flat) == 0: return None - return MiniCPMVImageInputs( - image_bounds=self._get_image_bounds(input_ids), + im_start_id = kwargs.pop("im_start_id", None) + im_end_id = kwargs.pop("im_end_id", None) + slice_start_id = kwargs.pop("slice_start_id", None) + slice_end_id = kwargs.pop("slice_end_id", None) + if im_start_id is None: + return None + + return MiniCPMVImagePixelInputs( + image_bounds=self._get_image_bounds(input_ids, im_start_id, + im_end_id, slice_start_id, + slice_end_id), pixel_values=pixel_values_flat, tgt_sizes=torch.stack(tgt_sizes_flat), ) @@ -564,8 +643,8 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: raise NotImplementedError def is_default_weight_loading(self, name: str) -> bool: @@ -654,8 +733,8 @@ def get_vision_embedding( res.append(self.resampler(vision_embedding, tgt_size)) return torch.vstack(res) - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] return self.get_vision_embedding(pixel_values) @@ -713,8 +792,8 @@ def get_vision_embedding( vision_embedding = self.resampler(vision_embedding, tgt_sizes) return vision_embedding - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -807,8 +886,8 @@ def get_vision_embedding( ).last_hidden_state return vision_embedding - def get_vision_hidden_states(self, - data: MiniCPMVImageInputs) -> torch.Tensor: + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: pixel_values = data["pixel_values"] tgt_sizes = data["tgt_sizes"] @@ -851,7 +930,7 @@ def is_default_weight_loading(self, name: str) -> bool: } -@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 0000000000000..45d6ad3c0efa5 --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1142 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" +import math +from array import array +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +from PIL import Image +from torch import nn +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) + +import vllm.distributed.parallel_state as ps +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP + +logger = init_logger(__name__) +MLLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN = "<|image|>" + + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: """ + """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + + +# TODO: support LlamaImageEmbeddingInputs + + +def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + # move encoder_prompt to prompt + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + + # process multi-modal data + assert "decoder_multi_modal_data" not in llm_inputs, \ + "multi-modal data should be put in encoder message of mllama" + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + + if multi_modal_data is None or "image" not in multi_modal_data \ + or multi_modal_data["image"] is None: + # text-only + llm_inputs["encoder_prompt"] = "" + llm_inputs["encoder_prompt_token_ids"] = [] + llm_inputs["encoder_multi_modal_data"] = {} + return llm_inputs + + # get num_tiles + if isinstance(multi_modal_data['image'], Image.Image): + multi_modal_data['image'] = [multi_modal_data['image']] + hf_config = ctx.model_config.hf_config + num_tiles = 0 + for image in multi_modal_data["image"]: + width, height = image.size + tile_size = hf_config.vision_config.image_size + canvas_height, canvas_width = get_optimal_tiled_canvas( + image_height=height, + image_width=width, + max_image_tiles=hf_config.vision_config.max_num_tiles, + tile_size=tile_size, + ) + num_tiles_height = canvas_height // tile_size + num_tiles_width = canvas_width // tile_size + num_tiles += num_tiles_height * num_tiles_width + + # set encoder prompt based on num_tiles + assert hf_config.vision_config.image_size % 14 == 0, \ + "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + num_tokens = num_tiles * token_per_chunk + llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID + ] * num_tokens + + return llm_inputs + + +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk + + +def dummy_decoder_seq_data(seq_len: int, num_images: int): + # <|image|> * num_images + 0 * (seq_len - num_images) + assert seq_len >= num_images, \ + "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_encoder_seq_data(ctx: InputContext, num_images: int): + num_tokens = get_max_mllama_image_tokens(ctx) * num_images + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens + return SequenceData(token_ids) + + +def dummy_image(num_images: int, ): + width = height = 1024 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_decoder_seq_data(seq_len, num_images), None + + +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# TODO: support other attention backends for attention in vision model +class MllamaVisionSdpaAttention(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = CLIPMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, is_gated) + for _ in range(num_layers) + ]) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices) + self.global_transformer = MllamaVisionEncoder(config, + config.num_global_layers, + is_gated=True) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.model_parallel_size + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + # TODO: change to Q/KV separate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, + quant_config: Optional[QuantizationConfig]) \ + -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + quant_config=quant_config, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer( + config, layer_idx, quant_config=quant_config)) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config)) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + residual=None, + ) + hidden_states = hidden_states + residual + else: + raise ValueError( + f"Unknown decoder layer type {type(decoder_layer)}") + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama) +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) +@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, + config: config_mllama.MllamaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = \ + config.pad_token_id if config.pad_token_id is not None else -1 + self.image_size = config.vision_config.image_size + + self.vision_model = MllamaVisionModel(config.vision_config) + self.language_model = MllamaForCausalLM( + config.text_config, + cache_config=cache_config, + quant_config=quant_config, + ) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.logits_processor = LogitsProcessor(config.output_hidden_states, + config.text_config.vocab_size) + self.sampler = Sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.language_model.lm_head, + hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def _parse_and_validate_image_input(self, **kwargs: object): + # tensor with the same shape will be batched together by + # MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: + # with shape (num_tiles, 3, image_res, image_res) + # - List[torch.Tensor]: + # with shape (num_image, num_tiles, 3, image_res, image_res) + # - torch.Tensor: + # with shape (bs, num_image, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + assert aspect_ratio_ids is not None + assert aspect_ratio_mask is not None + max_num_images = max([len(x[0]) for x in pixel_values]) + if max_num_images == 0: + raise ValueError("No images provided.") + max_num_tiles = max( + max([len(x) for x in y[0]]) for y in pixel_values) + device = self.multi_modal_projector.weight.device + bsz = len(pixel_values) + out_num_tiles = [] + out_images = torch.zeros( + bsz, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + device=device, + ) + out_ar_ids = torch.ones(bsz, + max_num_images, + dtype=torch.int64, + device=device) + out_ar_mask = torch.zeros(bsz, + max_num_images, + max_num_tiles, + dtype=torch.int64, + device=device) + for b in range(len(pixel_values)): + _num_tiles = [] + for i in range(len(pixel_values[b][0])): + img = pixel_values[b][0][i] + out_images[b, i, :img.shape[0]] = img + out_ar_ids[b, i] = aspect_ratio_ids[b][0][i] + out_ar_mask[b, i] = aspect_ratio_mask[b][0][i] + _num_tiles.append(img.shape[0]) + out_num_tiles.append(_num_tiles) + + return MllamaImagePixelInputs( + type="pixel_values", + data=out_images, + aspect_ratio_ids=out_ar_ids, + aspect_ratio_mask=out_ar_mask, + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def flat_encoder_result(self, cross_attention_states: torch.Tensor, + attn_metadata: AttentionMetadata): + + cross_attention_states_flat = torch.zeros( + sum(attn_metadata.encoder_seq_lens), + cross_attention_states.shape[-1], + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip( + attn_metadata.encoder_seq_lens, cross_attention_states): + end_pos = start_pos + seq_len + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip( + attn_metadata.seq_lens_tensor.cpu(), + attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + cross_attention_states.device) + + return cross_attention_states, full_text_row_masked_out_mask + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") + image_inputs = self._parse_and_validate_image_input(**kwargs) + if image_inputs is None: + cross_attention_mask = None + full_text_row_masked_out_mask = ( + attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( + input_ids.device) + cross_attention_states = None + skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 + else: + # NOTE: llama's reference implementation runs vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states, full_text_row_masked_out_mask = \ + self.flat_encoder_result(cross_attention_states, attn_metadata) + skip_cross_attention = False + # TODO: support multi-image by this mask + cross_attention_mask = None + + outputs = self.language_model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + + return outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params = set() + for name, loaded_weight in weights: + if 'patch_embedding.weight' in name: + name = name.replace('patch_embedding.weight', + 'patch_embedding._linear.weight') + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccaea..245381518a7f8 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): transposed = False if width < height: width, height = height, width @@ -337,8 +337,10 @@ def get_phi3v_image_feature_size( *, input_height: int, input_width: int, + num_crops: int, ) -> int: - num_crops = hf_config.get("num_crops", 16) + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -347,20 +349,26 @@ def get_phi3v_image_feature_size( + (new_height // 336 + 1) * 12 -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): return get_phi3v_image_feature_size( ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx) + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = [ get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h) + input_height=h, + num_crops=num_crops) ] image_data = [image_data] elif is_list_of(image_data, Image.Image): @@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size.append( get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h)) + input_height=h, + num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape elif is_list_of(image_data, torch.Tensor): diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index a3555a294bb66..487d9fc2f4337 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -321,13 +321,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=True, - quant_config=None, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=True, - quant_config=None, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e62a841485f2d..761c1370b9776 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -674,8 +674,9 @@ def input_processor_for_qwen(ctx: InputContext, prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_data = multi_modal_data["image"] if isinstance(image_data, torch.Tensor): num_dims = len(image_data.shape) @@ -735,8 +736,9 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: return MultiModalInputs() model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) image_pair_tok = tokenizer.encode(IMG_START + IMG_END, add_special_tokens=False, @@ -824,8 +826,9 @@ def dummy_data_for_qwen( # We have a visual component - use images to warm up num_images = mm_counts["image"] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, - trust_remote_code=True) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) # Build the image prompts with no imgpads; the tokenizer will add img pads image_prompt = ''.join( diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 9f72210c60bf9..f895e693b7107 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -31,12 +31,9 @@ import torch.nn.functional as F from einops import rearrange, repeat from PIL import Image -from transformers import Qwen2VLConfig from transformers.image_utils import (get_image_size, infer_channel_dimension_format, to_numpy_array) -from transformers.models.qwen2_vl.configuration_qwen2_vl import ( - Qwen2VLVisionConfig) from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( make_batched_images, make_batched_videos, smart_resize) @@ -66,7 +63,10 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, + Qwen2VLVisionConfig) from vllm.transformers_utils.processor import get_processor +from vllm.utils import is_cpu from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -281,6 +281,21 @@ def forward( context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size) + elif is_cpu(): + seq_length = q.size(1) + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + context_layer = rearrange(output, "b h s d -> b s h d ") else: from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalMask diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 32a0e895005cb..71808eb4c2719 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (flatten_bn, @@ -334,14 +335,23 @@ def __init__(self, self.multi_modal_config = multimodal_config assert self.multi_modal_config + self.secondary_weights = [] + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: - self.audio_tower = ModifiedWhisperEncoder.from_pretrained( - config.audio_model_id) - else: - self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=config.audio_model_id, + revision=None, + prefix="audio_tower.", + )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + if config.text_model_id is not None: + self.secondary_weights.append( + DefaultModelLoader.Source(model_or_path=config.text_model_id, + revision=None, + prefix="language_model.")) def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: @@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components weights_group = group_weights_with_prefix(weights) + # load audio tower weights + audio_tower_weights = weights_group["audio_tower"] + audio_tower_params_dict = dict( + self.audio_tower.named_parameters( + prefix=self.audio_tower.base_model_prefix)) + for name, loaded_weight in audio_tower_weights: + if name in audio_tower_params_dict: + param = audio_tower_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # load projector weights projector_weights = weights_group["multi_modal_projector"] projector_params_dict = dict( diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 9ffb339ffeab3..7a6d7c90f34d5 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -328,6 +328,64 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_(param: BasevLLMParameter, input_dim: int, + output_dim: int, **kwargs) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 87d3a4576f332..8bcb38ef241ed 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -54,6 +54,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: if isinstance(nested_tensors, torch.Tensor): return nested_tensors + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): # Only tensors (not lists) can be stacked. diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 31b1c3f93411a..d3a230e40477e 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -2,6 +2,7 @@ import torch from PIL import Image +from transformers.image_processing_base import BatchFeature from vllm.config import ModelConfig from vllm.inputs.registry import InputContext @@ -39,6 +40,10 @@ def _default_input_mapper( ) -> MultiModalInputs: model_config = ctx.model_config + # Processed by input processor + if isinstance(data, BatchFeature): + return MultiModalInputs(data.data) + # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) diff --git a/vllm/outputs.py b/vllm/outputs.py index 85ea9196b25df..44cde6b561d85 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -114,17 +114,28 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, - seq_group: SequenceGroup) -> Optional["RequestOutput"]: + def from_seq_group(cls, seq_group: SequenceGroup, + use_cache: bool) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( not finished): return None + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( # type: ignore + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -149,29 +160,66 @@ def from_seq_group(cls, outputs = [] include_prompt = True - for seq in top_n_seqs: + for i, seq in enumerate(top_n_seqs): output_text = seq.get_output_text_to_return( text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + output_logprobs = seq.output_logprobs if include_logprobs else None if delta: # Slice logprobs delta if applicable if output_logprobs: - output_logprobs = output_logprobs[-len(output_token_ids):] + output_logprobs = output_logprobs[-num_output_tokens:] # Don't include prompt if this is after the first output # containing decode token ids - if include_prompt and seq.get_output_len() > len( - output_token_ids): + if include_prompt and seq.get_output_len() > num_output_tokens: include_prompt = False - outputs.append( - CompletionOutput( - seqs.index(seq), output_text, output_token_ids, + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs # type: ignore + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + assert output.index == i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( + seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason)) + seq.stop_reason) + + outputs.append(output) # Every sequence in the sequence group should have the same prompt. if include_prompt: @@ -188,16 +236,20 @@ def from_seq_group(cls, prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - return cls(seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + + init_args = (seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.lora_request, encoder_prompt, + encoder_prompt_token_ids) + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__(*init_args) # type: ignore + + else: + request_output = cls(*init_args) + + return request_output def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -261,10 +313,10 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group): + def create(seq_group: SequenceGroup, use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group) + return RequestOutput.from_seq_group(seq_group, use_cache) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 914035bc78f1c..cb0c79d535a27 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -8,6 +8,7 @@ import torch from typing_extensions import Annotated +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -272,6 +273,10 @@ def __post_init__(self) -> None: self._verify_args() if self.use_beam_search: + if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: + raise ValueError( + "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa + ) self._verify_beam_search() else: self._verify_non_beam_search() diff --git a/vllm/sequence.py b/vllm/sequence.py index d8e54ff1fc708..49a198df045bd 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,6 +13,7 @@ import msgspec import torch +from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -21,11 +22,12 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import LLMInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" +VLLM_INVALID_TOKEN_ID = -1 + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. @@ -436,7 +438,7 @@ def __init__( self.stop_reason: Union[int, str, None] = None # These are used to keep track of delta outputs - self._last_token_ids_offset: int = 0 + self._last_output_token_ids_offset: int = 0 self._last_output_text_offset: int = 0 # Used for incremental detokenization @@ -469,7 +471,15 @@ def prompt_token_ids(self) -> List[int]: @property def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.get("multi_modal_data") or {} + if self.inputs.get("multi_modal_data") and self.inputs.get( + "encoder_multi_modal_data"): + raise ValueError( + "Multi-modal data in both encoder and decoder is not supported." + ) + inputs = self.inputs + return self.inputs.get("multi_modal_data") or (cast( + EncoderDecoderLLMInputs, + inputs).get("encoder_multi_modal_data")) or {} @property def lora_int_id(self) -> int: @@ -499,18 +509,26 @@ def get_output_text_to_return(self, buffer_length: int, return self.output_text[last_offset:length] return "" - def get_output_token_ids_to_return(self, - delta: bool) -> GenericSequence[int]: + def get_output_token_ids_to_return( + self, delta: bool) -> Union[GenericSequence[int], int]: """If delta is True, only new tokens since the last call to this method are returned""" if not delta: return self.get_output_token_ids() - length = self.get_output_len() - last_offset = self._last_token_ids_offset - if last_offset < length: - self._last_token_ids_offset = length - return self.data._output_token_ids[last_offset:] - return () + + output_len = self.get_output_len() + + # Get the number of new tokens + num_new_tokens = output_len - self._last_output_token_ids_offset + self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] + + return self.data._cached_all_token_ids[-num_new_tokens:] def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -636,6 +654,7 @@ class SequenceGroup: unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. """ def __init__( @@ -650,9 +669,11 @@ def __init__( encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.seqs = seqs + self.arrival_time = arrival_time self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -670,6 +691,9 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.priority = priority + + self.cached_request_output = None @property def prompt(self) -> Optional[str]: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index b2204e8b27afd..9eb8bbfc54076 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,9 +6,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SequenceData, SequenceGroupMetadata, - get_all_seq_ids) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len @@ -69,10 +69,10 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() - # Filter the list to ignore -1 proposals. + # Filter the list to ignore invalid proposals. proposal_token_ids_list_without_skips = [ proposals for proposals in proposal_token_ids_list - if -1 not in proposals + if VLLM_INVALID_TOKEN_ID not in proposals ] (spec_indices, non_spec_indices, target_seq_group_metadata_list, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9e645a49f699c..dbf880a8f475c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -13,9 +13,10 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) -from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_ids, get_all_seq_ids_and_request_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -28,7 +29,8 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -436,8 +438,8 @@ def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, sampler_output: SamplerOutput) -> SamplerOutput: """ - Creates and returns a `SamplerOutput` with only the sampled token IDs - being serialized to CPU & populated in `CompletionSequenceGroupOutput`. + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. All other parameters in `CompletionSequenceGroupOutput` related to log probabilities are skipped. @@ -449,14 +451,46 @@ def _serialize_sampler_output_no_logprobs( Returns: SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only sampled token - IDs populated. + `CompletionSequenceGroupOutput` objects with only token IDs + populated. """ - seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) - sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, seq_id in enumerate(seq_ids): + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( create_sequence_group_output( token_id=sampled_token_ids_list[index][0], @@ -465,7 +499,7 @@ def _serialize_sampler_output_no_logprobs( seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - )) + prompt_logprobs=prompt_logprobs)) return SamplerOutput(outputs=completion_seq_group_output_list) @nvtx_range("spec_decode_worker._run_no_spec") @@ -485,6 +519,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # Store hidden states from target model execution. hidden_states = sampler_output.hidden_states if hidden_states is not None: + # remove hidden_states for prompt tokens + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list): + hidden_states = hidden_states[ + torch.where(sampler_output.sampled_token_ids - + VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 54e718bc49017..193ef870dfceb 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput) + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) SeqId = int @@ -49,21 +50,19 @@ def get_sampled_token_logprobs( return sampled_token_ids_ranks, selected_logprobs -def create_sequence_group_output( +def create_logprobs_output( token_id: int, token_id_logprob_rank: int, token_id_logprob: float, - seq_id: SeqId, topk_token_ids: List[Optional[int]], topk_logprobs: List[Optional[float]], -) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. Args: token_id (int): The sampled token for the sequence. token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_logprobs (List[Optional[float]]): The list of top-k logprobs. """ @@ -85,14 +84,44 @@ def create_sequence_group_output( if topk_token_id is not None }) + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs=logprobs) ], - # TODO add prompt logprobs support. - prompt_logprobs=None, + prompt_logprobs=prompt_logprobs, ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 6a47dcc7b6e98..0f808eb904432 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -20,12 +20,12 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, - GraniteConfig, Grok1Config, - InternVLChatConfig, JAISConfig, - MedusaConfig, MLPSpeculatorConfig, + Grok1Config, InternVLChatConfig, + JAISConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, - RWConfig, SolarConfig, - UltravoxConfig) + Qwen2VLConfig, RWConfig, + SolarConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -38,6 +38,10 @@ logger = init_logger(__name__) +_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { + "mllama": MllamaConfig +} + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -53,15 +57,17 @@ "nemotron": NemotronConfig, "solar": SolarConfig, "ultravox": UltravoxConfig, - # Granite can be removed from here once we have upgraded to - # transformers 4.45+ - "granite": GraniteConfig, + "qwen2_vl": Qwen2VLConfig, "grok-1": Grok1Config, + **_CONFIG_REGISTRY_OVERRIDE_HF } for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): - AutoConfig.register(name, cls) + if name in _CONFIG_REGISTRY_OVERRIDE_HF: + AutoConfig.register(name, cls, exist_ok=True) + else: + AutoConfig.register(name, cls) class ConfigFormat(str, enum.Enum): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 60bd533c2a855..6228c6020a9ab 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -6,14 +6,16 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig -from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.transformers_utils.configs.grok1 import Grok1Config from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, + Qwen2VLVisionConfig) from vllm.transformers_utils.configs.solar import SolarConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig @@ -27,12 +29,12 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", + "MllamaConfig", "MLPSpeculatorConfig", "NemotronConfig", "SolarConfig", "UltravoxConfig", - # Granite can be removed from here once we have upgraded to - # transformers 4.45+ - "GraniteConfig", + "Qwen2VLConfig", + "Qwen2VLVisionConfig", "Grok1Config", ] diff --git a/vllm/transformers_utils/configs/granite.py b/vllm/transformers_utils/configs/granite.py deleted file mode 100644 index c12838be5d385..0000000000000 --- a/vllm/transformers_utils/configs/granite.py +++ /dev/null @@ -1,199 +0,0 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Granite model configuration""" - -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_rope_utils import rope_config_validation -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class GraniteConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of - a [`GraniteModel`]. It is used to instantiate an Granite - model according to the specified arguments, defining the model architecture. - Instantiating a configuration with the defaults will yield a similar - configuration to that of the Granite-3B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to - control the model outputs. Read the documentation from [`PretrainedConfig`] - for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the Granite model. Defines the number of - different tokens that can be represented by the `inputs_ids` - passed when calling [`GraniteModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the - Transformer decoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to - implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi - Head Attention (MHA), if `num_key_value_heads=1` the model will use - Multi Query Attention (MQA) otherwise GQA is used. When converting - a multi-head checkpoint to a GQA checkpoint, each group key and - value head should be constructed by meanpooling all the original - heads within that group. For more details checkout - [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not - specified, will default to `num_attention_heads`. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the - decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for - initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values - attentions (not used by all models). Only relevant if - `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 10000.0): - The base period of the RoPE embeddings. - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE - embeddings. Currently supports two scaling strategies: linear and - dynamic. Their scaling factor must be a float greater than 1. The - expected format is - `{"type": strategy name, "factor": scaling factor}`. - When using this flag, don't update `max_position_embeddings` to - the expected new maximum. See the following thread for more - information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. - This is an experimental feature, subject to breaking API changes - in future versions. - attention_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output - projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - mlp_bias (`bool`, *optional*, defaults to `False`): - Whether to use a bias in up_proj, down_proj and gate_proj layers - in the MLP layers. - embedding_multiplier (`float`, *optional*, defaults to 1.0): - embedding multiplier - logits_scaling (`float`, *optional*, defaults to 1.0): - divisor for output logits - residual_multiplier (`float`, *optional*, defaults to 1.0): - residual multiplier - attention_multiplier (`float`, *optional*, defaults to 1.0): - attention multiplier - - ```python - >>> from transformers import GraniteModel, GraniteConfig - - >>> # Initializing a Granite granite-3b style configuration - >>> configuration = GraniteConfig() - - >>> # Initializing a model from the granite-7b style configuration - >>> model = GraniteModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "granite" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act="silu", - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=None, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - rope_theta=10000.0, - rope_scaling=None, - attention_bias=False, - attention_dropout=0.0, - mlp_bias=False, - embedding_multiplier=1.0, - logits_scaling=1.0, - residual_multiplier=1.0, - attention_multiplier=1.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.rope_scaling = rope_scaling - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.mlp_bias = mlp_bias - - self.embedding_multiplier = embedding_multiplier - self.logits_scaling = logits_scaling - self.residual_multiplier = residual_multiplier - self.attention_multiplier = attention_multiplier - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - rope_config_validation(self) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py new file mode 100644 index 0000000000000..49e766d7fa1f4 --- /dev/null +++ b/vllm/transformers_utils/configs/mllama.py @@ -0,0 +1,28 @@ +from transformers.models.mllama import configuration_mllama as mllama_hf_config + + +class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): + ''' + Use this class to override is_encoder_decoder: + - transformers regards mllama as is_encoder_decoder=False + - vllm needs is_encoder_decoder=True to enable cross-attention + ''' + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.is_encoder_decoder = True + + +class MllamaConfig(mllama_hf_config.MllamaConfig): + + def __init__( + self, + text_config=None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = MllamaTextConfig(**text_config) + super().__init__(text_config=text_config, **kwargs) diff --git a/vllm/transformers_utils/configs/qwen2vl.py b/vllm/transformers_utils/configs/qwen2vl.py new file mode 100644 index 0000000000000..92dd962790bc8 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen2vl.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers import PretrainedConfig + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # NOTE: the following section from original transformers config + # for Qwen2-VL is commented out to address rope config loading issue + # + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # if self.rope_scaling["type"] == "mrope": + # self.rope_scaling["type"] = "default" + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # rope_config_validation(self) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67bb..2b418f3603a0b 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,13 +1,11 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup -# Used eg. for marking rejected tokens in spec decoding. -INVALID_TOKEN_ID = -1 - class Detokenizer: """Provides methods to decode the output of a model into text.""" @@ -61,7 +59,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): prompt_token_ids_with_token = ( prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, @@ -143,7 +141,7 @@ def decode_sequence_inplace(self, seq: Sequence, continue if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( tokenizer=tokenizer, @@ -282,14 +280,14 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: + if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) if isinstance(new_tokens, str): new_tokens = [new_tokens] + else: + new_tokens = [""] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index f9fb8d1e103b7..2a2d74382e37a 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -111,7 +111,6 @@ def get_tokenizer( 'encoding and decoding.', FutureWarning, stacklevel=2) - if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) diff --git a/vllm/utils.py b/vllm/utils.py index 465e7a1ac4ec3..c74fd098ec6ee 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -5,6 +5,7 @@ import enum import gc import inspect +import ipaddress import os import random import socket @@ -630,6 +631,14 @@ def get_ip() -> str: return "0.0.0.0" +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + def get_distributed_init_method(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 diff --git a/vllm/version.py b/vllm/version.py index 0ddc7fb99ad45..66e189dcedf71 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,13 +1,11 @@ -import warnings - try: - import vllm.commit_id - - __commit__ = vllm.commit_id.__commit__ + from ._version import __version__, __version_tuple__ except Exception as e: + import warnings + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) - __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1.post2" + __version__ = "dev" + __version_tuple__ = (0, 0, __version__) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d7d7d65659b73..cebb0f36a2b28 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -12,11 +12,13 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, @@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU: query_lens=seq_lens, ) + def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, + computed_len: int): + mm_kwargs = self.multi_modal_input_mapper(mm_data) + + # special processing for mrope position deltas. + mrope_positions = None + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + return mm_kwargs, mrope_positions + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -153,6 +187,8 @@ def _prepare_prompt( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] + slot_mapping: List[int] = [] seq_lens: List[int] = [] multi_modal_inputs_list: List[MultiModalInputs] = [] @@ -171,14 +207,20 @@ def _prepare_prompt( seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids + mrope_positions = None + if (mm_data := seq_group_metadata.multi_modal_data): + mm_kwargs, mrope_positions = self._compute_multi_modal_input( + seq_data, mm_data, computed_len) + multi_modal_inputs_list.append(mm_kwargs) + # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, seq_len))) - - if (mm_data := seq_group_metadata.multi_modal_data): - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) + if mrope_positions: + for idx in range(3): + input_mrope_positions[idx].extend(mrope_positions[idx]) + else: + input_positions.extend(list(range(computed_len, seq_len))) # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] @@ -202,12 +244,18 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + num_prompt_tokens = len(input_tokens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) # type: ignore slot_mapping = torch.tensor(slot_mapping, @@ -238,6 +286,7 @@ def _prepare_decode( assert len(seq_group_metadata_list) > 0 input_tokens: List[int] = [] input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] slot_mapping: List[int] = [] seq_lens: List[int] = [] block_tables: List[List[int]] = [] @@ -255,7 +304,17 @@ def _prepare_decode( seq_len = seq_data.get_len() position = seq_len - 1 - input_positions.append(position) + if seq_data.mrope_position_delta is not None: + context_len = seq_data.get_num_computed_tokens() + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + input_mrope_positions[idx].extend(next_pos[idx]) + else: + input_positions.append(position) seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) @@ -273,12 +332,18 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + max_decode_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) - input_positions = torch.tensor(input_positions, + input_positions = torch.tensor(input_positions + or input_mrope_positions, dtype=torch.long, device=self.device) slot_mapping = torch.tensor(slot_mapping, @@ -373,6 +438,15 @@ def __init__( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 709efdc8b9d57..bd716ac3e7ec3 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) @@ -52,6 +53,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -194,6 +196,8 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -202,6 +206,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, @@ -288,8 +294,7 @@ def profile_run(self) -> None: max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: - raise NotImplementedError( - "Multi-modal encoder-decoder models are not supported yet") + logger.info("Starting profile run for multi-modal models.") batch_size = 0 for group_id in range(max_num_seqs): @@ -297,24 +302,39 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, _ = self.input_registry \ - .dummy_data_for_profiling(self.model_config, + decoder_seq_data, decoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, seq_len, - self.mm_registry) + self.mm_registry, + is_encoder_data=False) + encoder_seq_data, encoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( + assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + f"but got: {len(decoder_seq_data.prompt_token_ids)}") + + assert decoder_dummy_multi_modal_data is None or \ + encoder_dummy_multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: decoder_seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=seq_data, + encoder_seq_data=encoder_seq_data, cross_block_table=None, + multi_modal_data=decoder_dummy_multi_modal_data + or encoder_dummy_multi_modal_data, ) seqs.append(seq) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 975b88c0e79a2..86883cf152449 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -137,7 +137,15 @@ def _wrapper(*args, **kwargs): for t in kv_caches if is_tensor(t)] - pickle.dump(dumped_inputs, filep) + try: + pickle.dump(dumped_inputs, filep) + except Exception as pickle_err: + logger.warning( + "Failed to pickle inputs of failed execution: %s", + str(pickle_err)) + raise type(err)(f"Error in model execution: " + f"{str(err)}") from err + logger.info( "Completed writing input of failed execution to %s.", filename) diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index a58b80e4f2adb..a07395dfc61d8 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.is_multimodal_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) - if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC'])