diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py
new file mode 100644
index 0000000000000..90a5e54736cf3
--- /dev/null
+++ b/.buildkite/check-wheel-size.py
@@ -0,0 +1,36 @@
+import os
+import zipfile
+
+MAX_SIZE_MB = 100
+
+
+def print_top_10_largest_files(zip_file):
+    with zipfile.ZipFile(zip_file, 'r') as z:
+        file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()]
+        file_sizes.sort(key=lambda x: x[1], reverse=True)
+        for f, size in file_sizes[:10]:
+            print(f"{f}: {size/(1024*1024)} MBs uncompressed.")
+
+
+def check_wheel_size(directory):
+    for root, _, files in os.walk(directory):
+        for f in files:
+            if f.endswith(".whl"):
+                wheel_path = os.path.join(root, f)
+                wheel_size = os.path.getsize(wheel_path)
+                wheel_size_mb = wheel_size / (1024 * 1024)
+                if wheel_size_mb > MAX_SIZE_MB:
+                    print(
+                        f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) "
+                        f"compared to the allowed size ({MAX_SIZE_MB} MB).")
+                    print_top_10_largest_files(wheel_path)
+                    return 1
+                else:
+                    print(f"Wheel {wheel_path} is within the allowed size "
+                          f"({wheel_size_mb} MB).")
+    return 0
+
+
+if __name__ == "__main__":
+    import sys
+    sys.exit(check_wheel_size(sys.argv[1]))
diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh
index 38aff57a410dc..c04e05a994894 100644
--- a/.buildkite/run-amd-test.sh
+++ b/.buildkite/run-amd-test.sh
@@ -1,10 +1,11 @@
-# This script build the ROCm docker image and run the API server inside the container.
-# It serves a sanity check for compilation and basic model usage.
+# This script builds the ROCm docker image and runs tests inside it.
 set -ex
 
 # Print ROCm version
+echo "--- ROCm info"
 rocminfo
 
+echo "--- Resetting GPUs"
 echo "reset" > /opt/amdgpu/etc/gpu_state
 
@@ -16,37 +17,28 @@ while true; do
   fi
 done
 
+echo "--- Building container"
+sha=$(git rev-parse --short HEAD)
+container_name=rocm_${sha}
+docker build \
+       -t ${container_name} \
+       -f Dockerfile.rocm \
+       --progress plain \
+       .
+
+remove_docker_container() {
+   docker rm -f ${container_name} || docker image rm -f ${container_name} || true
+}
+trap remove_docker_container EXIT
 
+echo "--- Running container"
 
-# Try building the docker image
-docker build -t rocm -f Dockerfile.rocm .
-
-# Setup cleanup
-remove_docker_container() { docker rm -f rocm || true; }
-trap remove_docker_container EXIT
-remove_docker_container
-
-# Run the image
-export HIP_VISIBLE_DEVICES=1
-docker run --device /dev/kfd --device /dev/dri --network host -e HIP_VISIBLE_DEVICES --name rocm rocm python3 -m vllm.entrypoints.api_server &
-
-# Wait for the server to start
-wait_for_server_to_start() {
-  timeout=300
-  counter=0
-
-  while [ "$(curl -s -o /dev/null -w ''%{http_code}'' localhost:8000/health)" != "200" ]; do
-    sleep 1
-    counter=$((counter + 1))
-    if [ $counter -ge $timeout ]; then
-      echo "Timeout after $timeout seconds"
-      break
-    fi
-  done
-}
-wait_for_server_to_start
+docker run \
+       --device /dev/kfd --device /dev/dri \
+       --network host \
+       --rm \
+       -e HF_TOKEN \
+       --name ${container_name} \
+       ${container_name} \
+       /bin/bash -c $(echo $1 | sed "s/^'//" | sed "s/'$//")
 
-# Test a simple prompt
-curl -X POST -H "Content-Type: application/json" \
-    localhost:8000/generate \
-    -d '{"prompt": "San Francisco is a"}'
diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh
index f6a542afe1a3d..7fbad1c4bd950 100644
--- a/.buildkite/run-benchmarks.sh
+++ b/.buildkite/run-benchmarks.sh
@@ -53,6 +53,11 @@ echo '```' >> benchmark_results.md
 tail -n 20 benchmark_serving.txt >> benchmark_results.md # last 20 lines
 echo '```' >> benchmark_results.md
 
+# if the agent binary is not found, skip uploading the results, exit 0
+if [ ! -f /workspace/buildkite-agent ]; then
+    exit 0
+fi
+
 # upload the results to buildkite
 /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 11cda053260ec..e49a5650c44ea 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -17,27 +17,38 @@ steps:
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
+  - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
 
 - label: Core Test
+  mirror_hardwares: [amd]
   command: pytest -v -s core
 
 - label: Distributed Comm Ops Test
   command: pytest -v -s test_comm_ops.py
   working_dir: "/vllm-workspace/tests/distributed"
-  num_gpus: 2 # only support 1 or 2 for now.
+  num_gpus: 2
 
 - label: Distributed Tests
   working_dir: "/vllm-workspace/tests/distributed"
+  num_gpus: 2 # only support 1 or 2 for now.
+ mirror_hardwares: [amd] + commands: - - pytest -v -s test_pynccl.py - pytest -v -s test_pynccl_library.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py +- label: Distributed Tests (Multiple Groups) + working_dir: "/vllm-workspace/tests/distributed" + num_gpus: 4 + commands: + - pytest -v -s test_pynccl.py + - label: Engine Test + mirror_hardwares: [amd] command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py - label: Entrypoints Test @@ -48,6 +59,7 @@ steps: - label: Examples Test working_dir: "/vllm-workspace/examples" + mirror_hardwares: [amd] commands: # install aws cli for llava_example.py - pip install awscli @@ -61,16 +73,19 @@ steps: parallelism: 4 - label: Models Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py - label: Llava Test + mirror_hardwares: [amd] commands: - bash ../.buildkite/download-images.sh - pytest -v -s models/test_llava.py - label: Prefix Caching Test + mirror_hardwares: [amd] commands: - pytest -v -s prefix_caching @@ -78,12 +93,15 @@ steps: command: pytest -v -s samplers - label: LogitsProcessor Test + mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py - label: Worker Test + mirror_hardwares: [amd] command: pytest -v -s worker - label: Speculative decoding tests + mirror_hardwares: [amd] command: pytest -v -s spec_decode - label: LoRA Test %N @@ -101,6 +119,7 @@ steps: - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" + mirror_hardwares: [amd] commands: - pip install aiohttp - bash run-benchmarks.sh diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 index 5c9515840bb03..919a09e1cc064 100644 --- a/.buildkite/test-template.j2 +++ b/.buildkite/test-template.j2 @@ -14,20 +14,33 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 - wait - - label: "AMD Test" - agents: - queue: amd - command: bash .buildkite/run-amd-test.sh + - group: "AMD Tests" + depends_on: ~ + steps: + {% for step in steps %} + {% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + agents: + queue: amd + command: bash .buildkite/run-amd-test.sh "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'" + env: + DOCKER_BUILDKIT: "1" + {% endif %} + {% endfor %} - label: "Neuron Test" + depends_on: ~ agents: queue: neuron command: bash .buildkite/run-neuron-test.sh soft_fail: true - - label: "CPU Test" + - label: "Intel Test" + depends_on: ~ command: bash .buildkite/run-cpu-test.sh {% for step in steps %} @@ -42,9 +55,14 @@ steps: automatic: - exit_status: -1 # Agent was lost limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 plugins: - kubernetes: podSpec: + {% if step.num_gpus %} + priorityClassName: gpu-priority-cls-{{ step.num_gpus }} + {% endif %} volumes: - name: dshm emptyDir: diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index a19be8525f902..a20753d8a7702 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -33,6 +33,7 @@ jobs: - name: Mypy run: | mypy 
vllm/attention --config-file pyproject.toml + mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml @@ -42,9 +43,8 @@ jobs: mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml - mypy vllm/lora --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml - - # TODO(sang): Fix nested dir - mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/lora --config-file pyproject.toml + mypy vllm/logging --config-file pyproject.toml + mypy vllm/model_executor --config-file pyproject.toml diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d79681f03b003..9c35ede5f6781 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -58,6 +58,9 @@ jobs: - name: Setup ccache uses: hendrikmuhs/ccache-action@v1.2 + with: + create-symlink: true + key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - name: Set up Linux Env if: ${{ runner.os == 'Linux' }} @@ -79,6 +82,8 @@ jobs: - name: Build wheel shell: bash + env: + CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size run: | bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} wheel_name=$(ls dist/*whl | xargs -n 1 basename) diff --git a/.github/workflows/scripts/create_release.js b/.github/workflows/scripts/create_release.js index 0f25624b4c21c..475742118afeb 100644 --- a/.github/workflows/scripts/create_release.js +++ b/.github/workflows/scripts/create_release.js @@ -8,7 +8,7 @@ module.exports = async (github, context, core) => { generate_release_notes: true, name: process.env.RELEASE_TAG, owner: context.repo.owner, - prerelease: false, + prerelease: true, repo: context.repo.repo, tag_name: process.env.RELEASE_TAG, }); diff --git a/Dockerfile b/Dockerfile index e471a6e93b963..90be3a30f89b1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,13 @@ # The vLLM Dockerfile is used to construct vLLM image that can be directly used # to run the OpenAI compatible server. +# Please update any changes made here to +# docs/source/dev/dockerfile/dockerfile.rst and +# docs/source/assets/dev/dockerfile-stages-dependency.png + #################### BASE BUILD IMAGE #################### # prepare basic build environment -FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev RUN apt-get update -y \ && apt-get install -y python3-pip git @@ -12,7 +16,7 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ WORKDIR /workspace @@ -71,6 +75,10 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ python3 setup.py bdist_wheel --dist-dir=dist +# check the size of the wheel, we cannot upload wheels larger than 100MB +COPY .buildkite/check-wheel-size.py check-wheel-size.py +RUN python3 check-wheel-size.py dist + # the `vllm_nccl` package must be installed from source distribution # pip is too smart to store a wheel in the cache, and other CI jobs # will directly use the wheel from the cache, which is not what we want. 
@@ -98,7 +106,7 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \ #################### vLLM installation IMAGE #################### # image with vLLM installed -FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base +FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base WORKDIR /vllm-workspace RUN apt-get update -y \ @@ -108,7 +116,7 @@ RUN apt-get update -y \ # https://github.com/pytorch/pytorch/issues/107960 -- hopefully # this won't be needed for future versions of this docker image # or future versions of triton. -RUN ldconfig /usr/local/cuda-12.1/compat/ +RUN ldconfig /usr/local/cuda-12.4/compat/ # install vllm wheel first, so that torch etc will be installed RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 3f84b949481d1..d04bb9915e2ab 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -46,7 +46,7 @@ RUN apt-get update && apt-get install -y \ ### Mount Point ### # When launching the container, mount the code directory to /app -ARG APP_MOUNT=/app +ARG APP_MOUNT=/vllm-workspace VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} @@ -89,15 +89,16 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ && cd ../..; \ fi -COPY ./ /app/vllm +WORKDIR /vllm-workspace +COPY . . RUN python3 -m pip install --upgrade pip numba -RUN cd /app \ - && cd vllm \ - && pip install -U -r requirements-rocm.txt \ - && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -U -r requirements-rocm.txt \ + && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ && cd .. 
RUN python3 -m pip install --upgrade pip diff --git a/MANIFEST.in b/MANIFEST.in index d385f194c6c0f..82be639ef4d73 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,9 @@ include LICENSE include requirements-common.txt include requirements-cuda.txt +include requirements-rocm.txt +include requirements-neuron.txt +include requirements-cpu.txt include CMakeLists.txt recursive-include cmake * diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 1f3274a28cad5..089966986984f 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -16,20 +16,22 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): def main(args): - llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat", + llm = LLM(model=args.model, tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, + use_v2_block_manager=args.use_v2_block_manager, + tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) num_prompts = 100 prompts = [PROMPT] * num_prompts - sampling_params = SamplingParams(temperature=0, max_tokens=100) + sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) print("------warm up------") test_prefix( llm=llm, - prompts=prompts[:1], + prompts=prompts, sampling_params=sampling_params, ) @@ -45,8 +47,16 @@ def main(args): parser = argparse.ArgumentParser( description='Benchmark the performance with or without automatic ' 'prefix caching.') + parser.add_argument('--model', + type=str, + default='baichuan-inc/Baichuan2-13B-Chat') + parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) + parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py index 8e976fbcb3028..5280b214144c9 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe.py +++ b/benchmarks/kernels/benchmark_mixtral_moe.py @@ -1,3 +1,4 @@ +import argparse import json import os import sys @@ -5,6 +6,7 @@ import torch import torch.nn.functional as F import triton +from tqdm import tqdm from vllm.model_executor.layers.fused_moe import (fused_moe, get_config_file_name) @@ -12,16 +14,16 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0' -def main(): +def main(dtype: str): method = fused_moe for bs in [ 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ]: - run_grid(bs, method=method) + run_grid(bs, method=method, dtype=dtype) -def run_grid(bs, method): +def run_grid(bs, method, dtype: str): d_model = 4096 num_total_experts = 8 top_k = 2 @@ -34,39 +36,29 @@ def run_grid(bs, method): num_trials = 1 configs = [] - if bs <= 16: - BLOCK_SIZES_M = [16] - elif bs <= 32: - BLOCK_SIZES_M = [16, 32] - elif bs <= 64: - BLOCK_SIZES_M = [16, 32, 64] - elif bs <= 128: - BLOCK_SIZES_M = [16, 32, 64, 128] - else: - BLOCK_SIZES_M = [16, 32, 64, 128, 256] for block_size_n in [32, 64, 128, 256]: - for block_size_m in BLOCK_SIZES_M: + for block_size_m in [16, 32, 64, 128, 256]: for block_size_k in [64, 128, 256]: for group_size_m in [1, 16, 32, 64]: for num_warps in [4, 8]: - configs.append({ - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_warps": num_warps, - 
"num_stages": 4, - }) + for num_stages in [2, 3, 4, 5]: + configs.append({ + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) best_config = None best_time_us = 1e20 - for config in configs: - print(f'{tp_size=} {bs=}') - print(f'{config}') + print(f'{tp_size=} {bs=}') + + for config in tqdm(configs): # warmup - print('warming up') try: for _ in range(num_warmup_trials): run_timing( @@ -79,12 +71,12 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) except triton.runtime.autotuner.OutOfResources: continue # trial - print('benchmarking') for _ in range(num_trials): kernel_dur_ms = run_timing( num_calls=num_calls, @@ -96,6 +88,7 @@ def run_grid(bs, method): model_intermediate_size=model_intermediate_size, method=method, config=config, + dtype=dtype, ) kernel_dur_us = 1000 * kernel_dur_ms @@ -105,16 +98,18 @@ def run_grid(bs, method): best_config = config best_time_us = kernel_dur_us - print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - f'{d_model=} {model_intermediate_size=} {num_layers=}') + tqdm.write( + f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' + f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' + f'{d_model=} {model_intermediate_size=} {num_layers=}') print("best_time_us", best_time_us) print("best_config", best_config) # holds Dict[str, Dict[str, int]] filename = get_config_file_name(num_total_experts, - model_intermediate_size // tp_size) + model_intermediate_size // tp_size, + "float8" if dtype == "float8" else None) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): @@ -128,27 +123,48 @@ def run_grid(bs, method): def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, top_k: int, tp_size: int, model_intermediate_size: int, method, - config) -> float: + config, dtype: str) -> float: shard_intermediate_size = model_intermediate_size // tp_size hidden_states = torch.rand( (bs, d_model), device="cuda:0", - dtype=torch.bfloat16, + dtype=torch.float16, ) - ws = torch.rand( + w1 = torch.rand( (num_total_experts, 2 * shard_intermediate_size, d_model), device=hidden_states.device, dtype=hidden_states.dtype, ) - w2s = torch.rand( + w2 = torch.rand( (num_total_experts, d_model, shard_intermediate_size), device=hidden_states.device, dtype=hidden_states.dtype, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if dtype == "float8": + w1 = w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + w1_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + w2_scale = torch.ones(num_total_experts, + device=hidden_states.device, + dtype=torch.float32) + a1_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + a2_scale = torch.ones(1, + device=hidden_states.device, + dtype=torch.float32) + gating_output = F.softmax(torch.rand( (num_calls, bs, num_total_experts), device=hidden_states.device, @@ -163,13 +179,18 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, for i in range(num_calls): hidden_states = method( hidden_states=hidden_states, - w1=ws, - w2=w2s, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, gating_output=gating_output[i], topk=2, renormalize=True, inplace=True, override_config=config, + 
use_fp8=dtype == "float8", ) end_event.record() end_event.synchronize() @@ -179,4 +200,16 @@ def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int, if __name__ == "__main__": - sys.exit(main()) + parser = argparse.ArgumentParser( + prog='benchmark_mixtral_moe', + description='Benchmark and tune the fused_moe kernel', + ) + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['float8', 'float16'], + help='Data type used for fused_moe kernel computations', + ) + args = parser.parse_args() + sys.exit(main(args.dtype)) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 5c3650fa72d17..ca7967c1ab0d2 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -16,7 +16,7 @@ def main( version: str, num_seqs: int, - context_len: int, + seq_len: int, num_query_heads: int, num_kv_heads: int, head_size: int, @@ -48,12 +48,12 @@ def main( dtype=torch.float, device=device) - context_lens = [context_len for _ in range(num_seqs)] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) + seq_lens = [seq_len for _ in range(num_seqs)] + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -77,8 +77,7 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +109,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -129,9 +128,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -166,7 +165,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) - parser.add_argument("--context-len", type=int, default=4096) + parser.add_argument("--seq_len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--head-size", @@ -199,7 +198,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: main( version=args.version, num_seqs=args.batch_size, - context_len=args.context_len, + seq_len=args.seq_len, num_query_heads=args.num_query_heads, num_kv_heads=args.num_kv_heads, head_size=args.head_size, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index f3a5bbfd3098d..8b1b5e098015f 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -104,7 +104,7 @@ __device__ void paged_attention_kernel( const int num_kv_heads, // [num_heads] const float scale, 
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -115,23 +115,23 @@ __device__ void paged_attention_kernel( const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { // No work to do. Terminate the thread block. return; } - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; // [start_block_idx, end_block_idx) is the range of blocks to process. const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); const int num_blocks = end_block_idx - start_block_idx; // [start_token_idx, end_token_idx) is the range of tokens to process. const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); const int num_tokens = end_token_idx - start_token_idx; constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -245,12 +245,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= context_len; + const bool mask = token_idx >= seq_len; logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); @@ -364,14 +364,14 @@ __device__ void paged_attention_kernel( } else { v_vec = *reinterpret_cast(v_ptr + offset); } - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); #pragma unroll for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value; + v_vec_ptr[j] = token_idx + j < seq_len ? 
v_vec_ptr[j] : zero_value; } } accs[i] += dot(logits_vec, v_vec); @@ -457,7 +457,7 @@ __global__ void paged_attention_v1_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -466,7 +466,7 @@ __global__ void paged_attention_v1_kernel( const float kv_scale) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -489,7 +489,7 @@ __global__ void paged_attention_v2_kernel( const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, @@ -498,7 +498,7 @@ __global__ void paged_attention_v2_kernel( const float kv_scale) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale); } @@ -513,13 +513,13 @@ __global__ void paged_attention_v2_reduce_kernel( const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] + const int* __restrict__ seq_lens, // [num_seqs] const int max_num_partitions) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); if (num_partitions == 1) { // No need to reduce. Only copy tmp_out to out. 
scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; @@ -616,7 +616,7 @@ __global__ void paged_attention_v2_reduce_kernel( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -639,8 +639,8 @@ void paged_attention_v1_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -664,11 +664,11 @@ void paged_attention_v1_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! @@ -715,8 +715,8 @@ void paged_attention_v1_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -746,9 +746,9 @@ void paged_attention_v1( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { @@ -790,7 +790,7 @@ void paged_attention_v1( num_kv_heads, \ scale, \ block_tables_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_blocks_per_seq, \ alibi_slopes_ptr, \ q_stride, \ @@ -803,7 +803,7 @@ void paged_attention_v1( exp_sums_ptr, \ max_logits_ptr, \ tmp_out_ptr, \ - context_lens_ptr, \ + seq_lens_ptr, \ max_num_partitions); template< @@ -824,8 +824,8 @@ void paged_attention_v2_launcher( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, + torch::Tensor& seq_lens, + int max_seq_len, const c10::optional& alibi_slopes, float kv_scale) { int num_seqs = query.size(0); @@ -852,10 +852,10 @@ void paged_attention_v2_launcher( CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -909,8 +909,8 @@ void paged_attention_v2_launcher( num_kv_heads, \ scale, \ block_tables, \ - context_lens, \ - max_context_len, \ + seq_lens, \ + max_seq_len, \ alibi_slopes, \ kv_scale); @@ -943,9 +943,9 @@ void 
paged_attention_v2( int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& seq_lens, // [num_seqs] int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale) { diff --git a/csrc/cache.h b/csrc/cache.h index 718a5f6cfd7f7..10871b3670bac 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -13,7 +13,7 @@ void swap_blocks( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping); + torch::Tensor& block_mapping); void reshape_and_cache( torch::Tensor& key, @@ -24,6 +24,14 @@ void reshape_and_cache( const std::string& kv_cache_dtype, const float kv_scale); +void reshape_and_cache_flash( + torch::Tensor& key, + torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype); + // Just for unittest void convert_fp8( torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 24aaa2ff3e263..1e02f7fcbae4c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -97,7 +97,7 @@ __global__ void copy_blocks_kernel( void copy_blocks( std::vector& key_caches, std::vector& value_caches, - const std::map>& block_mapping) { + torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { @@ -114,17 +114,9 @@ void copy_blocks( key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr()); value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr()); } - // Create block mapping array. - std::vector block_mapping_vec; - for (const auto& pair : block_mapping) { - int64_t src_block_number = pair.first; - for (int64_t dst_block_number : pair.second) { - block_mapping_vec.push_back(src_block_number); - block_mapping_vec.push_back(dst_block_number); - } - } - int64_t* block_mapping_array = block_mapping_vec.data(); - int num_pairs = block_mapping_vec.size() / 2; + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); // Move the data structures to the GPU. // NOTE: This synchronizes the CPU and GPU. @@ -132,8 +124,6 @@ void copy_blocks( key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); torch::Tensor value_cache_ptrs_tensor = torch::from_blob( value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device); - torch::Tensor block_mapping_tensor = torch::from_blob( - block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device); // Launch the kernel. 
const int numel_per_block = key_caches[0][0].numel(); @@ -146,7 +136,7 @@ void copy_blocks( vllm::copy_blocks_kernel<<>>( key_cache_ptrs_tensor.data_ptr(), value_cache_ptrs_tensor.data_ptr(), - block_mapping_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); })); } @@ -215,6 +205,41 @@ __global__ void reshape_and_cache_kernel( } } +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size] + scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, + const int key_stride, + const int value_stride, + const int num_heads, + const int head_size, + const int block_size) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + + head_offset; + k_cache[tgt_value_idx] = key[src_key_idx]; + v_cache[tgt_value_idx] = value[src_value_idx]; + } +} } // namespace vllm #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \ @@ -275,6 +300,51 @@ void reshape_and_cache( } } +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype) +{ + // FIXME: only support auto datatype, does not support fp8 + if (kv_cache_dtype != "auto") { + TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); + } + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = k_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = k_cache.stride(0); + TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + key.scalar_type(), + "reshape_and_cache_flash", + [&] { + vllm::reshape_and_cache_flash_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + k_cache.data_ptr(), + v_cache.data_ptr(), + slot_mapping.data_ptr(), + block_stride, + key_stride, + value_stride, + num_heads, + head_size, + block_size); + }); +} + namespace vllm { template diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 365bbd5e23728..c1d765be05598 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -70,11 +70,11 @@ template FORCE_INLINE std::pair 
reduceSoftmaxAlibi(T *data, const int size, const int capacity, const float alibi_slope, const int start_index, - const int context_len) { - data[0] += alibi_slope * (start_index - context_len + 1); + const int seq_len) { + data[0] += alibi_slope * (start_index - seq_len + 1); T max = data[0]; for (int i = 1; i < size; ++i) { - T qk = data[i] + alibi_slope * (start_index + i - context_len + 1); + T qk = data[i] + alibi_slope * (start_index + i - seq_len + 1); data[i] = qk; max = max >= qk ? max : qk; } @@ -225,7 +225,7 @@ struct paged_attention_v1_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -235,32 +235,32 @@ struct paged_attention_v1_impl { static_assert(BLOCK_SIZE == 16); - int max_context_len = max_num_blocks_per_seq * BLOCK_SIZE; - int max_context_len_padded = (max_context_len + 15) & 0xFFFFFFF0; - TORCH_CHECK((max_context_len_padded * sizeof(float)) % 64 == 0); + int max_seq_len = max_num_blocks_per_seq * BLOCK_SIZE; + int max_seq_len_padded = (max_seq_len + 15) & 0xFFFFFFF0; + TORCH_CHECK((max_seq_len_padded * sizeof(float)) % 64 == 0); const int parallel_work_item_num = omp_get_max_threads(); size_t logits_bytes = - parallel_work_item_num * max_context_len_padded * sizeof(float); + parallel_work_item_num * max_seq_len_padded * sizeof(float); float *logits = (float *)std::aligned_alloc( 64, logits_bytes); // Cacheline alignment for each context token. - // [parallel_work_item_num, max_context_len_padded] + // [parallel_work_item_num, max_seq_len_padded] #pragma omp parallel for collapse(2) schedule(dynamic, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - int context_len = context_lens[seq_idx]; + int seq_len = seq_lens[seq_idx]; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx; - const int block_num = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + const int block_num = (seq_len + BLOCK_SIZE - 1) / BLOCK_SIZE; const int64_t kv_head_idx = head_idx / num_queries_per_kv; const scalar_t *__restrict__ q_vec_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const int last_block_token_num = - context_len - (block_num - 1) * BLOCK_SIZE; + seq_len - (block_num - 1) * BLOCK_SIZE; float *__restrict__ thread_block_logits = - logits + omp_get_thread_num() * max_context_len_padded; + logits + omp_get_thread_num() * max_seq_len_padded; // Compute logits for (int block_idx = 0; block_idx < block_num; ++block_idx) { @@ -278,11 +278,11 @@ struct paged_attention_v1_impl { // Compute softmax if (alibi_slopes) { - reduceSoftmaxAlibi(thread_block_logits, context_len, + reduceSoftmaxAlibi(thread_block_logits, seq_len, block_num * BLOCK_SIZE, alibi_slopes[head_idx], 0, - context_len); + seq_len); } else { - reduceSoftmax(thread_block_logits, context_len, + reduceSoftmax(thread_block_logits, seq_len, block_num * BLOCK_SIZE); } @@ -340,7 +340,7 @@ struct paged_attention_v1_impl { #define LAUNCH_V1_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE) \ paged_attention_v1_impl::call( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ - block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + block_tables_ptr, seq_lens_ptr, 
max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, num_seqs, \ num_heads); @@ -348,8 +348,8 @@ template void paged_attention_v1_impl_launcher( torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -369,7 +369,7 @@ void paged_attention_v1_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -399,7 +399,7 @@ void paged_attention_v1_impl_launcher( #define CALL_V1_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v1_impl_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - context_lens, max_context_len, alibi_slopes); + seq_lens, max_seq_len, alibi_slopes); #define CALL_V1_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -416,8 +416,8 @@ void paged_attention_v1(torch::Tensor &out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); @@ -448,7 +448,7 @@ struct paged_attention_v2_impl { const int num_kv_heads, const float scale, const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int *__restrict__ context_lens, // [num_seqs] + const int *__restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float *__restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, @@ -465,22 +465,22 @@ struct paged_attention_v2_impl { for (int partition_idx = 0; partition_idx < max_num_partitions; ++partition_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int start_token_idx = partition_idx * PARTITION_SIZE; - if (start_token_idx >= context_len) + if (start_token_idx >= seq_len) continue; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; const bool no_reduce = (partition_num == 1); - const int context_token_num = - (std::min(context_len, start_token_idx + PARTITION_SIZE) - + const int token_num = + (std::min(seq_len, start_token_idx + PARTITION_SIZE) - start_token_idx); const int block_num = - (context_token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; + (token_num + BLOCK_SIZE - 1) / BLOCK_SIZE; const int last_block_token_num = - context_token_num - (block_num - 1) * BLOCK_SIZE; + token_num - (block_num - 1) * BLOCK_SIZE; const int *seq_block_table = block_tables + max_num_blocks_per_seq * seq_idx + start_token_idx / BLOCK_SIZE; @@ -507,10 +507,10 @@ struct paged_attention_v2_impl { std::pair max_and_sum; if (alibi_slopes) { max_and_sum = reduceSoftmaxAlibi( - logits, 
context_token_num, block_num * BLOCK_SIZE, - alibi_slopes[head_idx], start_token_idx, context_len); + logits, token_num, block_num * BLOCK_SIZE, + alibi_slopes[head_idx], start_token_idx, seq_len); } else { - max_and_sum = reduceSoftmax(logits, context_token_num, + max_and_sum = reduceSoftmax(logits, token_num, block_num * BLOCK_SIZE); } @@ -583,9 +583,9 @@ struct paged_attention_v2_impl { #pragma omp parallel for collapse(2) schedule(static, 1) for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -612,9 +612,9 @@ struct paged_attention_v2_impl { for (int seq_idx = 0; seq_idx < num_seqs; ++seq_idx) { for (int head_idx = 0; head_idx < num_heads; ++head_idx) { for (int group_idx = 0; group_idx < head_group_num; ++group_idx) { - const int context_len = context_lens[seq_idx]; + const int seq_len = seq_lens[seq_idx]; const int partition_num = - (context_len + PARTITION_SIZE - 1) / PARTITION_SIZE; + (seq_len + PARTITION_SIZE - 1) / PARTITION_SIZE; if (partition_num == 1) continue; @@ -649,7 +649,7 @@ struct paged_attention_v2_impl { paged_attention_v2_impl::call( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, \ key_cache_ptr, value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ kv_block_stride, kv_head_stride, num_seqs, num_heads, \ max_num_partitions); @@ -658,8 +658,8 @@ void paged_attention_v2_impl_launcher( torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &max_logits, torch::Tensor &tmp_out, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, - torch::Tensor &block_tables, torch::Tensor &context_lens, int block_size, - int max_context_len, const c10::optional &alibi_slopes) { + torch::Tensor &block_tables, torch::Tensor &seq_lens, int block_size, + int max_seq_len, const c10::optional &alibi_slopes) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -683,7 +683,7 @@ void paged_attention_v2_impl_launcher( T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int *block_tables_ptr = block_tables.data_ptr(); - int *context_lens_ptr = context_lens.data_ptr(); + int *seq_lens_ptr = seq_lens.data_ptr(); switch (head_size) { case 64: @@ -713,8 +713,8 @@ void paged_attention_v2_impl_launcher( #define CALL_V2_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ paged_attention_v2_impl_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, block_size, \ - max_context_len, alibi_slopes); + num_kv_heads, scale, block_tables, seq_lens, block_size, \ + max_seq_len, alibi_slopes); #define CALL_V2_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ @@ -732,8 +732,8 @@ void paged_attention_v2(torch::Tensor &out, torch::Tensor &exp_sums, torch::Tensor &query, torch::Tensor &key_cache, torch::Tensor &value_cache, int num_kv_heads, float scale, torch::Tensor &block_tables, - torch::Tensor &context_lens, int block_size, - int max_context_len, + torch::Tensor &seq_lens, int block_size, + int 
max_seq_len, const c10::optional &alibi_slopes, const std::string &kv_cache_dtype, float kv_scale) { TORCH_CHECK(kv_scale == 1.0f); diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 7849a5df991b1..95e3f11900fde 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -8,16 +8,16 @@ template void copy_blocks_cpu_impl( std::vector &key_caches, std::vector &value_caches, - const std::vector> mapping_pairs, + const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { - const size_t pair_num = mapping_pairs.size(); + const size_t pair_num = mapping_pairs.size(0); const size_t block_bytes = sizeof(scalar_t) * element_num_per_block; #pragma omp parallel for collapse(2) for (int layer = 0; layer < layer_num; ++layer) { for (size_t pair = 0; pair < pair_num; ++pair) { - int64_t source_offset = element_num_per_block * mapping_pairs[pair].first; + int64_t source_offset = element_num_per_block * mapping_pairs[pair][0].item(); int64_t target_offset = - element_num_per_block * mapping_pairs[pair].second; + element_num_per_block * mapping_pairs[pair][1].item(); scalar_t *key_cache_ptr = key_caches[layer].data_ptr(); scalar_t *source_ptr = key_cache_ptr + source_offset; scalar_t *target_ptr = key_cache_ptr + target_offset; @@ -83,26 +83,18 @@ void reshape_and_cache_cpu_impl( void copy_blocks(std::vector &key_caches, std::vector &value_caches, - const std::map> &block_mapping) { + torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); if (num_layers == 0) { return; } - std::vector> mapping_pairs; - mapping_pairs.reserve(block_mapping.size()); - for (const auto &pair : block_mapping) { - for (const auto &dst : pair.second) { - mapping_pairs.emplace_back(pair.first, dst); - } - } - const int element_num_per_block = key_caches[0][0].numel(); VLLM_DISPATCH_FLOATING_TYPES( key_caches[0].scalar_type(), "copy_blocks_cpu_impl", [&] { CPU_KERNEL_GUARD_IN(copy_blocks_cpu_impl) - copy_blocks_cpu_impl(key_caches, value_caches, mapping_pairs, + copy_blocks_cpu_impl(key_caches, value_caches, block_mapping, element_num_per_block, num_layers); CPU_KERNEL_GUARD_OUT(copy_blocks_cpu_impl) }); diff --git a/csrc/ops.h b/csrc/ops.h index 04b97d1784cd2..9541adcb3de88 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -10,9 +10,9 @@ void paged_attention_v1( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -28,9 +28,9 @@ void paged_attention_v2( int num_kv_heads, float scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, + torch::Tensor& seq_lens, int block_size, - int max_context_len, + int max_seq_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, float kv_scale); @@ -132,6 +132,7 @@ torch::Tensor gptq_marlin_gemm( torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, + int64_t num_bits, int64_t size_m, int64_t size_n, int64_t size_k, @@ -141,7 +142,8 @@ torch::Tensor gptq_marlin_repack( torch::Tensor &b_q_weight, torch::Tensor &perm, int64_t size_k, - int64_t size_n); + int64_t size_n, + int64_t num_bits); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 9839bfc0331c4..173e0b1732e13 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "reshape_and_cache", &reshape_and_cache, 
"Reshape the key and value tensors and cache them"); + cache_ops.def( + "reshape_and_cache_flash", + &reshape_and_cache_flash, + "Reshape the key and value tensors and cache them"); cache_ops.def( "convert_fp8", &convert_fp8, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index 2477051eb60d7..b9c5d39277ca5 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -17,6 +17,15 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } +#define FP8_E4M3_MAX std::numeric_limits::max() + +template +__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(const scalar_t val, const float scale) { + float x = static_cast(val) / scale; + float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + // Compute the absolute maximum m of the input tensor and store // m / float8_e4m3::max() in *scale. Each thread block performs a // reduction tree and the memory in scale is atomically updated. @@ -67,7 +76,7 @@ __global__ void scaled_fp8_quant_kernel( int64_t num_elems) { int i = blockDim.x * blockIdx.x + threadIdx.x; while (i < num_elems) { - out[i] = static_cast(input[i] / *scale); + out[i] = scaled_fp8_conversion(input[i], *scale); i += blockDim.x * gridDim.x; } } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9902f55167d89..fd0837f0cb39c 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -32,7 +32,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, int4 *__restrict__ out_int4_ptr, int size_m, int size_k, int block_rows) {} -template = 8.0"); return torch::empty({1, 1}); @@ -114,11 +115,21 @@ template __device__ inline int lop3(int a, int b, int c) { return res; } +// Constructs destination register by taking bytes from 2 sources (based on mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + // Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 // values. We mostly follow the strategy in the link below, with some small // changes: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h -__device__ inline FragB dequant(int q) { +__device__ inline FragB dequant_4bit(int q) { const int LO = 0x000f000f; const int HI = 0x00f000f0; const int EX = 0x64006400; @@ -139,6 +150,24 @@ __device__ inline FragB dequant(int q) { return frag_b; } +__device__ inline FragB dequant_8bit(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. 
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { @@ -162,6 +191,13 @@ __device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, frag_b[1] = __hmul2(frag_b[1], s_val_3_4); } +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float *c, FragS &s) { + __half *s_ptr = reinterpret_cast<__half *>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + // Wait until barrier reaches `count`, then lock for current threadblock. __device__ inline void barrier_acquire(int *lock, int count) { if (threadIdx.x == 0) { @@ -250,7 +286,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, } } -template ( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } }; bool is_same_group[stages]; int same_group_id[stages]; auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); @@ -767,10 +828,23 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // dequantization and matmul operations. #pragma unroll for (int j = 0; j < 4; j++) { - int b_quant = frag_b_quant[k % 2][j]; - int b_quant_shift = b_quant >> 8; + FragB frag_b0; + FragB frag_b1; + if constexpr (num_bits == 4) { + int b_quant = frag_b_quant[k % 2][0][j]; + int b_quant_shift = b_quant >> 8; + + frag_b0 = dequant_4bit(b_quant); + frag_b1 = dequant_4bit(b_quant_shift); - FragB frag_b0 = dequant(b_quant); + } else { + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + + frag_b0 = dequant_8bit(b_quant_0); + frag_b1 = dequant_8bit(b_quant_1); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -782,8 +856,6 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk } } - FragB frag_b1 = dequant(b_quant_shift); - // Apply scale to frag_b1 if constexpr (has_act_order) { scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], @@ -808,13 +880,13 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // multiple warps that accumulate their partial sums of the same output // location; which we have to reduce over in the end. We do in shared memory. auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride / 2; + constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride; - constexpr int red_sh_stride = b_sh_stride * 4 * 2; - constexpr int red_sh_delta = b_sh_stride; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + - (threadIdx.x % b_sh_stride); + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); // Parallel logarithmic shared memory reduction. 
We make sure to avoid any // unnecessary read or write iterations, e.g., for two warps we write only @@ -861,7 +933,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk }; // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped portioning + // finally have to globally reduce over the results. As the striped partitioning // minimizes the number of such reductions and our outputs are usually rather // small, we perform this reduction serially in L2 cache. auto global_reduce = [&](bool first = false, bool last = false) { @@ -951,13 +1023,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk auto write = [&](int idx, float c0, float c1, FragS &s) { half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - // For per-column quantization we finally apply the scale here - if constexpr (!has_act_order && group_blocks == -1) { + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) { res = __hmul2(res, s[0]); } ((half2 *)sh)[idx] = res; }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1023,6 +1097,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // ensure all shared memory accesses are static. Note that both pipelines // have even length meaning that the next iteration will always start at // index 0. + #pragma unroll for (int pipe = 0; pipe < stages;) { #pragma unroll @@ -1070,23 +1145,63 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // For per-column scales, we only fetch them here in the final step before // write-out if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { if (s_sh_wr_pred) { - cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } } } thread_block_reduce(); if constexpr (!has_act_order && group_blocks == -1) { - if (last) { + if constexpr (num_bits == 8) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } } } @@ -1125,28 +1240,25 @@ Marlin(const int4 *__restrict__ A, // 
fp16 input matrix of shape mxk s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("Move\n"); - // } start_pipes(); } } } } -#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ +#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ - else if (thread_m_blocks == THREAD_M_BLOCKS && \ + else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ num_threads == NUM_THREADS) { \ cudaFuncSetAttribute( \ - Marlin, \ + Marlin, \ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ + Marlin \ <<>>( \ A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \ prob_k, locks); \ @@ -1158,28 +1270,92 @@ typedef struct { int num_threads; } thread_config_t; -thread_config_t small_batch_thread_configs[] = { +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N + {64, 256, 256}, // Default (max cache usage) + {64, 128, 128}, // Reduce N, reduce warps + {128, 64, 128}, // Reduce N more, but increase K + }; -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 128, 128}, // Reduce N 2X, same K - // {128, 64, 128}, // Reduce N 4X, increase K 2X -}; + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; -bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, - int prob_k) { + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + TORCH_CHECK(max_shared_mem / 2 > 
scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const &th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1201,62 +1377,79 @@ bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, return false; } + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + return true; } -thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { - - // TODO: Enable if needed after some more testing - if (prob_m <= 0) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + for (auto th_config : thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; } } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { - return th_config; - } - } + printf("WARNING: Marlin kernel is reducing max_m_blocks due to small SM " + "GPU cache. This may " + "hurt performance. 
Consider upgrading your GPU.\n"); + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage } - return thread_config_t{-1, -1, -1}; + return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ +#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, - void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, - void *workspace, bool has_act_order, bool is_k_full, - int num_groups, int group_size, int dev = 0, - cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, - int sms = -1, int max_par = 16) { + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_f16i4(const void *A, const void *B, void *C, void *s, + void *g_idx, void *perm, void *a_tmp, int prob_m, + int prob_n, int prob_k, void *workspace, int num_bits, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int dev, cudaStream_t stream, 
int thread_k, + int thread_n, int sms, int max_par) { + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1274,25 +1467,34 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, TORCH_CHECK(max_shared_mem > 0); // Set thread config - thread_config_t th_config; + exec_config_t exec_cfg; if (thread_k != -1 && thread_n != -1) { // User-defined config - th_config = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; } else { // Auto config - th_config = determine_thread_config(prob_m, prob_n, prob_k); + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); } - TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), - "Invalid thread config: thread_k = " + str(th_config.thread_k) + - ", thread_n = " + str(th_config.thread_n) + - ", num_threads = " + str(th_config.num_threads) + - " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + - str(prob_n) + "]"); - - int num_threads = th_config.num_threads; - thread_k = th_config.thread_k; - thread_n = th_config.thread_n; + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; @@ -1352,28 +1554,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, } // Main loop - for (int i = 0; i < tot_m_blocks; i += 4) { + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { int thread_m_blocks = tot_m_blocks - i; prob_m = tot_m - 16 * i; int par = 1; - if (thread_m_blocks > 4) { + if (thread_m_blocks > exec_cfg.max_m_blocks) { // Note that parallel > 1 currently only works for inputs without any // padding - par = (16 * thread_m_blocks - pad) / 64; + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); if (par > max_par) par = max_par; - prob_m = 64 * par; - i += 4 * (par - 1); - thread_m_blocks = 4; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; } // Define kernel configurations if (false) { } - CALL_IF(16, 4, 256) - CALL_IF(8, 8, 256) - CALL_IF(8, 4, 128) - CALL_IF(4, 8, 128) + CALL_IF(4, 32, 2, 256) + CALL_IF(4, 16, 4, 256) + CALL_IF(4, 8, 4, 128) + CALL_IF(4, 4, 8, 128) + CALL_IF(8, 32, 2, 256) + CALL_IF(8, 16, 4, 256) + CALL_IF(8, 8, 4, 128) + CALL_IF(8, 4, 8, 128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + @@ -1395,33 +1601,32 @@ void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, torch::Tensor 
gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, torch::Tensor &b_scales, torch::Tensor &g_idx, torch::Tensor &perm, torch::Tensor &workspace, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full) { + int64_t num_bits, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full) { + // Verify num_bits + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. Got = ", num_bits); + int pack_factor = 32 / num_bits; + // Verify A - TORCH_CHECK(a.size(0) == size_m, - "Shape mismatch: a.size(0) = " + str(a.size(0)) + - ", size_m = " + str(size_m)); - TORCH_CHECK(a.size(1) == size_k, - "Shape mismatch: a.size(1) = " + str(a.size(1)) + - ", size_k = " + str(size_k)); + TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), + ", size_m = ", size_m); + TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), + ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, - "size_k = " + str(size_k) + " is not divisible by tile_size = " + - str(gptq_marlin::tile_size)); + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", gptq_marlin::tile_size); TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = " + - str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + - ", tile_size = " + str(gptq_marlin::tile_size)); - TORCH_CHECK( - b_q_weight.size(1) % gptq_marlin::tile_size == 0, - "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + - " is not divisible by tile_size = " + str(gptq_marlin::tile_size)); - int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * - gptq_marlin::pack_factor_4bit; - TORCH_CHECK(size_n == actual_size_n, - "size_n = " + str(size_n) + - ", actual_size_n = " + str(actual_size_n)); + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size); + TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", gptq_marlin::tile_size); + int actual_size_n = + (b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor; + TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, + ", actual_size_n = ", actual_size_n); // Verify device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); @@ -1457,9 +1662,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, // Verify g_idx and perm TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + - " and perm.size(0) = " + str(perm.size(0)) + - ", where size_k = " + str(size_k)); + "Unexpected g_idx.size(0) = ", g_idx.size(0), + " and perm.size(0) = ", perm.size(0), + ", where size_k = ", size_k); // Detect groupsize and act_order int num_groups = -1; @@ -1475,9 +1680,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by num_groups = " + str(num_groups)); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); group_size = size_k / num_groups; } else { group_size = 0; @@ -1485,10 +1689,9 @@ torch::Tensor 
gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } else { if (num_groups > 1) { - TORCH_CHECK(size_k % num_groups == 0, - "size_k = " + str(size_k) + - ", is not divisible by b_scales.size(0) = " + - str(b_scales.size(0))); + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); group_size = size_k / num_groups; } else { group_size = -1; @@ -1496,23 +1699,22 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, } // Verify workspace size - TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, - "size_n = " + str(size_n) + - ", is not divisible by min_thread_n = " + - str(gptq_marlin::min_thread_n)); + TORCH_CHECK( + size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n); int min_workspace_size = (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = " + str(workspace.numel()) + - " is below min_workspace_size = " + str(min_workspace_size)); + "workspace.numel = ", workspace.numel(), + " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); - gptq_marlin::marlin_cuda( + gptq_marlin::marlin_mm_f16i4( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, - size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, - group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, - sms, gptq_marlin::max_par); + size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full, + num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_n, sms, gptq_marlin::max_par); return c; } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index 8cfce6b2575d5..35ea48aaba310 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -24,8 +24,6 @@ static constexpr int min_thread_k = 64; static constexpr int tile_size = 16; static constexpr int max_par = 16; -static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit - template struct Vec { T elems[n]; @@ -51,13 +49,11 @@ __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); asm volatile("{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" "}\n" ::"r"(smem), "l"(glob_ptr), "n"(BYTES)); } diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index fa45ce68a0c77..0d3da6240dbca 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -11,7 +11,7 @@ static constexpr int tile_n_size = tile_k_size * 4; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, @@ -20,7 +20,8 @@ 
marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { TORCH_CHECK_NOT_IMPLEMENTED( false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -28,11 +29,13 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, #else -template +template __global__ void marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, uint32_t const *__restrict__ perm_ptr, uint32_t *__restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + int k_tiles = size_k / tile_k_size; int n_tiles = size_n / tile_n_size; int block_k_tiles = div_ceil(k_tiles, gridDim.x); @@ -64,9 +67,10 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, sh_pipe_ptr += perm_size; } + constexpr int tile_ints = tile_k_size / pack_factor; + constexpr int stage_n_threads = tile_n_size / 4; - constexpr int stage_k_threads = - has_perm ? tile_k_size : tile_k_size / pack_factor_4bit; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; constexpr int stage_size = stage_k_threads * stage_n_threads; auto load_perm_to_shared = [&](int k_tile_id) { @@ -99,9 +103,9 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, reinterpret_cast(sh_perm_ptr); int src_k = sh_perm_int_ptr[k_id]; - int src_k_packed = src_k / pack_factor_4bit; + int src_k_packed = src_k / pack_factor; - cp_async4_stream( + cp_async4( &sh_ptr[k_id * stage_n_threads + n_id], reinterpret_cast(&( b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); @@ -113,12 +117,12 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int n_id = threadIdx.x % stage_n_threads; int first_k = k_tile_id * tile_k_size; - int first_k_packed = first_k / pack_factor_4bit; + int first_k_packed = first_k / pack_factor; - cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], - reinterpret_cast( - &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + - first_n + (n_id * 4)]))); + cp_async4(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); } } @@ -145,26 +149,27 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, int cur_n = warp_id * 16 + tc_col; constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); - uint32_t vals[pack_factor_4bit]; + uint32_t vals[8]; if constexpr (has_perm) { for (int i = 0; i < 4; i++) { int k_idx = tc_row + tc_offsets[i]; uint32_t src_k = sh_perm_int_ptr[k_idx]; - uint32_t src_k_pos = src_k % pack_factor_4bit; + uint32_t src_k_pos = src_k % pack_factor; uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; - uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; - uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; vals[i] = b1_cur_val; vals[4 + i] = b2_cur_val; @@ -172,41 +177,56 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } else { - uint32_t b1_val_1 = 
sh_stage_int_ptr[cur_n]; - uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; - - uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8]; - uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8]; + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; #pragma unroll - for (int i = 0; i < 2; i++) { - int cur_elem = tc_row + tc_offsets[i]; - vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } #pragma unroll - for (int i = 2; i < 4; i++) { - int cur_elem = tc_row + tc_offsets[i] - 8; - vals[i] = (b1_val_2 >> (cur_elem * 4)) & 0xf; - vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; } } + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + // Result of: // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h - constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; - uint32_t res = 0; + uint32_t res = 0; #pragma unroll - for (int i = 0; i < pack_factor_4bit; i++) { - res |= vals[pack_idx[i]] << (i * 4); - } + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } - constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; - int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + out_ptr[out_offset + th_id * 4 + warp_id] = res; - out_ptr[out_offset + th_id * 4 + warp_id] = res; + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { @@ -242,19 +262,35 @@ marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, } // namespace gptq_marlin +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + gptq_marlin::marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + gptq_marlin::marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } + torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, - int64_t size_k, int64_t size_n) { + int64_t size_k, int64_t size_n, + int64_t num_bits) { // Verify compatibility with marlin tile of 16x64 TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + TORCH_CHECK(num_bits == 4 || num_bits == 8, + "num_bits must be 4 or 8. 
Got = ", num_bits); + int const pack_factor = 32 / num_bits; + // Verify B - TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), + TORCH_CHECK((size_k / pack_factor) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, - ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit); + ", size_k = ", size_k, ", pack_factor = ", pack_factor); TORCH_CHECK(b_q_weight.size(1) == size_n, "b_q_weight.size(1) = ", b_q_weight.size(1), " is not size_n = ", size_n); @@ -273,10 +309,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, auto options = torch::TensorOptions() .dtype(b_q_weight.dtype()) .device(b_q_weight.device()); - torch::Tensor out = torch::empty( - {size_k / gptq_marlin::tile_size, - size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, - options); + torch::Tensor out = + torch::empty({size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / pack_factor}, + options); // Detect if there is act_order bool has_perm = perm.size(0) != 0; @@ -299,23 +335,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); TORCH_CHECK(max_shared_mem > 0); - if (has_perm) { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); - - } else { - cudaFuncSetAttribute( - gptq_marlin::marlin_repack_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_mem); - gptq_marlin::marlin_repack_kernel - <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + if (false) { + } + CALL_IF(4, false) + CALL_IF(4, true) + CALL_IF(8, false) + CALL_IF(8, true) + else { + TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits, + ", has_perm = ", has_perm); } return out; diff --git a/docs/source/assets/dev/dockerfile-stages-dependency.png b/docs/source/assets/dev/dockerfile-stages-dependency.png new file mode 100644 index 0000000000000..b016531f1e0a0 Binary files /dev/null and b/docs/source/assets/dev/dockerfile-stages-dependency.png differ diff --git a/docs/source/dev/dockerfile/dockerfile.rst b/docs/source/dev/dockerfile/dockerfile.rst new file mode 100644 index 0000000000000..a07463392dbe8 --- /dev/null +++ b/docs/source/dev/dockerfile/dockerfile.rst @@ -0,0 +1,50 @@ +Dockerfile +==================== + +See `here `_ for the main Dockerfile to construct +the image for running an OpenAI compatible server with vLLM. + +- Below is a visual representation of the multi-stage Dockerfile. The build graph contains the following nodes: + + - All build stages + - The default build target (highlighted in grey) + - External images (with dashed borders) + + The edges of the build graph represent: + + - FROM ... dependencies (with a solid line and a full arrow head) + - COPY --from=... dependencies (with a dashed line and an empty arrow head) + - RUN --mount=(.*)from=... dependencies (with a dotted line and an empty diamond arrow head) + + .. figure:: ../../assets/dev/dockerfile-stages-dependency.png + :alt: query + :width: 100% + :align: center + + Made using: https://github.com/patrickhoefler/dockerfilegraph + + Commands to regenerate the build graph (make sure to run it **from the `root` directory of the vLLM repository** where the dockerfile is present): + + .. 
code:: bash + + dockerfilegraph -o png --legend --dpi 200 --max-label-length 50 --filename Dockerfile + + or in case you want to run it directly with the docker image: + + .. code:: bash + + docker run \ + --rm \ + --user "$(id -u):$(id -g)" \ + --workdir /workspace \ + --volume "$(pwd)":/workspace \ + ghcr.io/patrickhoefler/dockerfilegraph:alpine \ + --output png \ + --dpi 200 \ + --max-label-length 50 \ + --filename Dockerfile \ + --legend + + (To run it for a different file, you can pass in a different argument to the flag `--filename`.) + + \ No newline at end of file diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index e7826114ffa9d..0c81f7ec6d2a9 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -53,6 +53,7 @@ You can also build and install vLLM from source: $ git clone https://github.com/vllm-project/vllm.git $ cd vllm + $ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability $ pip install -e . # This may take 5-10 minutes. .. tip:: diff --git a/docs/source/index.rst b/docs/source/index.rst index e8daa5f052754..4022c590843e6 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -75,6 +75,7 @@ Documentation serving/deploying_with_docker serving/distributed_serving serving/metrics + serving/env_vars serving/usage_stats serving/integrations @@ -86,6 +87,7 @@ Documentation models/adding_model models/engine_args models/lora + models/performance .. toctree:: :maxdepth: 1 @@ -102,6 +104,7 @@ Documentation dev/sampling_params dev/engine/engine_index dev/kernel/paged_attention + dev/dockerfile/dockerfile Indices and tables ================== diff --git a/docs/source/models/performance.rst b/docs/source/models/performance.rst new file mode 100644 index 0000000000000..067757699f32a --- /dev/null +++ b/docs/source/models/performance.rst @@ -0,0 +1,38 @@ +.. _performance: + +Performance and Tuning +====================== + +Chunked Prefill +--------------- +vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. + +You can enable the feature by specifying + +.. code-block:: python + + llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) + # Set max_num_batched_tokens to tune performance. + # NOTE: 512 is the default max_num_batched_tokens for chunked prefill. + # llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=512) + +By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. This policy optimizes the TTFT (time to thefirst token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. + +Once chunked prefill is enabled, the policy is changed to + +- prioritize decode requests. It batches all pending decode requests to the batch before scheduling any prefill. +- When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. + +This policy has two benefits. + +- It improves ITL (inter token latency) and generation decode because decode requests are prioritized. +- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. + +You can tune the performance by changing `max_num_batched_tokens`. 
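+
+For illustration, a throughput-oriented configuration might look like the
+sketch below (the model name and the 4096 budget are placeholders; the value
+simply follows the recommendation later in this section to set
+`max_num_batched_tokens > 2048` for throughput):
+
+.. code-block:: python
+
+    from vllm import LLM
+
+    # A larger per-step token budget improves TTFT and throughput,
+    # at the cost of ITL (see the discussion below).
+    llm = LLM(model="meta-llama/Llama-2-7b-hf",
+              enable_chunked_prefill=True,
+              max_num_batched_tokens=4096)
+    outputs = llm.generate(["San Francisco is a"])
+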
+By default, it is set to 512, which has the best ITL on A100 in the initial benchmark. +Smaller batch size achieves better ITL because there are fewer prefills interrupting decodes. +Higher batch size achieves better TTFT as you can put more prefill to the batch. +If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). +Note that the default batch size (512) is optimized for ITL, and it may have lower throughput than the default scheduler. We recommend you set `max_num_batched_tokens > 2048` for throughput. + +See related papers for more details (https://arxiv.org/pdf/2401.08671 or https://arxiv.org/pdf/2308.16369). diff --git a/docs/source/serving/env_vars.rst b/docs/source/serving/env_vars.rst new file mode 100644 index 0000000000000..0ce1374a3967b --- /dev/null +++ b/docs/source/serving/env_vars.rst @@ -0,0 +1,9 @@ +Environment Variables +======================== + +vLLM uses the following environment variables to configure the system: + +.. literalinclude:: ../../../vllm/envs.py + :language: python + :start-after: begin-env-vars-definition + :end-before: end-env-vars-definition diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 388b5daa79a92..c157d8ba998da 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat You can start the server using Python, or using [Docker](deploying_with_docker.rst): ```bash -python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123 +python -m vllm.entrypoints.openai.api_server --model NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 ``` To call the server, you can use the official OpenAI Python client library, or any other HTTP client. @@ -16,7 +16,7 @@ client = OpenAI( ) completion = client.chat.completions.create( - model="mistralai/Mistral-7B-Instruct-v0.2", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ {"role": "user", "content": "Hello!"} ] @@ -37,7 +37,7 @@ Or directly merge them into the JSON payload if you are using HTTP call directly ```python completion = client.chat.completions.create( - model="mistralai/Mistral-7B-Instruct-v0.2", + model="NousResearch/Meta-Llama-3-8B-Instruct", messages=[ {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} ], @@ -87,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode a chat template in its tokenizer configuration. The chat template is a Jinja2 template that specifies how are roles, messages, and other chat-specific tokens are encoded in the input. -An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format) +An example chat template for `NousResearch/Meta-Llama-3-8B-Instruct` can be found [here](https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models) Some models do not provide a chat template even though they are instruction/chat fine-tuned. 
For those model, you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md new file mode 100644 index 0000000000000..75b4b31a80462 --- /dev/null +++ b/examples/logging_configuration.md @@ -0,0 +1,178 @@ +# Logging Configuration + +vLLM leverages Python's `logging.config.dictConfig` functionality to enable +robust and flexible configuration of the various loggers used by vLLM. + +vLLM offers two environment variables that can be used to accommodate a range +of logging configurations that range from simple-and-inflexible to +more-complex-and-more-flexible. + +- No vLLM logging (simple and inflexible) + - Set `VLLM_CONFIGURE_LOGGING=0` (leaving `VLLM_LOGGING_CONFIG_PATH` unset) +- vLLM's default logging configuration (simple and inflexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` +- Fine-grained custom logging configuration (more complex, more flexible) + - Leave `VLLM_CONFIGURE_LOGGING` unset or set `VLLM_CONFIGURE_LOGGING=1` and + set `VLLM_LOGGING_CONFIG_PATH=` + + +## Logging Configuration Environment Variables + +### `VLLM_CONFIGURE_LOGGING` + +`VLLM_CONFIGURE_LOGGING` controls whether or not vLLM takes any action to +configure the loggers used by vLLM. This functionality is enabled by default, +but can be disabled by setting `VLLM_CONFIGURE_LOGGING=0` when running vLLM. + +If `VLLM_CONFIGURE_LOGGING` is enabled and no value is given for +`VLLM_LOGGING_CONFIG_PATH`, vLLM will use built-in default configuration to +configure the root vLLM logger. By default, no other vLLM loggers are +configured and, as such, all vLLM loggers defer to the root vLLM logger to make +all logging decisions. + +If `VLLM_CONFIGURE_LOGGING` is disabled and a value is given for +`VLLM_LOGGING_CONFIG_PATH`, an error will occur while starting vLLM. + +### `VLLM_LOGGING_CONFIG_PATH` + +`VLLM_LOGGING_CONFIG_PATH` allows users to specify a path to a JSON file of +alternative, custom logging configuration that will be used instead of vLLM's +built-in default logging configuration. The logging configuration should be +provided in JSON format following the schema specified by Python's [logging +configuration dictionary +schema](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details). + +If `VLLM_LOGGING_CONFIG_PATH` is specified, but `VLLM_CONFIGURE_LOGGING` is +disabled, an error will occur while starting vLLM. + + +## Examples + +### Example 1: Customize vLLM root logger + +For this example, we will customize the vLLM root logger to use +[`python-json-logger`](https://github.com/madzak/python-json-logger) to log to +STDOUT of the console in JSON format with a log level of `INFO`. 
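+
+As background, the configuration file is simply a standard Python
+`logging.config` dictionary serialized as JSON; vLLM applies it in essentially
+the way the following standalone sketch does (illustrative only, not vLLM's
+actual startup code):
+
+```python
+import json
+import logging.config
+
+# Load a dictConfig-style configuration from disk and apply it.
+with open("/path/to/logging_config.json") as f:
+    logging.config.dictConfig(json.load(f))
+
+logging.getLogger("vllm").info("vllm loggers are now configured")
+```
+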
+ +To begin, first, create an appropriate JSON logging configuration file: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "json": { + "class": "pythonjsonlogger.jsonlogger.JsonFormatter" + } + }, + "handlers": { + "console": { + "class" : "logging.StreamHandler", + "formatter": "json", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["console"], + "level": "INFO", + "propagate": false + } + }, + "version": 1 +} +``` + +Next, install the `python-json-logger` package if it's not already installed: + +```bash +pip install python-json-logger +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 2: Silence a particular vLLM logger + +To silence a particular vLLM logger, it is necessary to provide custom logging +configuration for the target logger that configures the logger so that it won't +propagate its log messages to the root vLLM logger. + +When custom configuration is provided for any logger, it is also necessary to +provide configuration for the root vLLM logger since any custom logger +configuration overrides the built-in default logging configuration used by vLLM. + +First, create an appropriate JSON logging configuration file that includes +configuration for the root vLLM logger and for the logger you wish to silence: + +**/path/to/logging_config.json:** + +```json +{ + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": "%m-%d %H:%M:%S", + "format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" + } + }, + "handlers": { + "vllm": { + "class" : "logging.StreamHandler", + "formatter": "vllm", + "level": "INFO", + "stream": "ext://sys.stdout" + } + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagage": false + }, + "vllm.example_noisy_logger": { + "propagate": false + } + }, + "version": 1 +} +``` + +Finally, run vLLM with the `VLLM_LOGGING_CONFIG_PATH` environment variable set +to the path of the custom logging configuration JSON file: + +```bash +VLLM_LOGGING_CONFIG_PATH=/path/to/logging_config.json \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +### Example 3: Disable vLLM default logging configuration + +To disable vLLM's default logging configuration and silence all vLLM loggers, +simple set `VLLM_CONFIGURE_LOGGING=0` when running vLLM. This will prevent vLLM +for configuring the root vLLM logger, which in turn, silences all other vLLM +loggers. 
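+
+If you are embedding vLLM in your own Python program rather than launching the
+server from a shell, the same effect can be achieved by exporting the variable
+before vLLM is first imported (a sketch that assumes the variable is read when
+vLLM first sets up its logging at import time):
+
+```python
+import os
+
+# Must be set before vLLM is imported for the first time.
+os.environ["VLLM_CONFIGURE_LOGGING"] = "0"
+
+# Assumes vLLM reads the variable at import time, so it will skip
+# configuring its loggers.
+import vllm  # noqa: E402
+```
+
+When launching the server from the command line, set the variable inline
+instead:
+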
+ +```bash +VLLM_CONFIGURE_LOGGING=0 \ + python3 -m vllm.entrypoints.openai.api_server \ + --max-model-len 2048 \ + --model mistralai/Mistral-7B-v0.1 +``` + + +## Additional resources + +- [`logging.config` Dictionary Schema Details](https://docs.python.org/3/library/logging.config.html#dictionary-schema-details) diff --git a/format.sh b/format.sh index bd12e61d77806..233e6af0c9479 100755 --- a/format.sh +++ b/format.sh @@ -95,7 +95,7 @@ echo 'vLLM yapf: Done' # Run mypy echo 'vLLM mypy:' mypy vllm/attention --config-file pyproject.toml -mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/core --config-file pyproject.toml mypy vllm/distributed --config-file pyproject.toml mypy vllm/entrypoints --config-file pyproject.toml mypy vllm/executor --config-file pyproject.toml @@ -107,6 +107,8 @@ mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml +mypy vllm/logging --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml CODESPELL_EXCLUDES=( diff --git a/requirements-common.txt b/requirements-common.txt index 3abb828116680..bd779d5acb68e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -14,7 +14,7 @@ pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken == 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer == 0.9.8 +lm-format-enforcer == 0.10.1 outlines == 0.0.34 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 diff --git a/requirements-dev.txt b/requirements-dev.txt index 324039186142b..e6d375cbafa39 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,7 +14,7 @@ types-setuptools # testing pytest -tensorizer==2.9.0a0 +tensorizer==2.9.0 pytest-forked pytest-asyncio pytest-rerunfailures diff --git a/setup.py b/setup.py index 6ba36b85ea318..3768daf9d6fab 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import importlib.util import io import logging import os @@ -13,10 +14,23 @@ from setuptools.command.build_ext import build_ext from torch.utils.cpp_extension import CUDA_HOME + +def load_module_from_path(module_name, path): + spec = importlib.util.spec_from_file_location(module_name, path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) -# Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] -VLLM_TARGET_DEVICE = os.getenv("VLLM_TARGET_DEVICE", "cuda") + +# cannot import envs directly because it depends on vllm, +# which is not installed yet +envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) + +VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE # vLLM only supports Linux platform assert sys.platform.startswith( @@ -60,7 +74,7 @@ class cmake_build_ext(build_ext): def compute_num_jobs(self): # `num_jobs` is either the value of the MAX_JOBS environment variable # (if defined) or the number of CPUs available. - num_jobs = os.environ.get("MAX_JOBS", None) + num_jobs = envs.MAX_JOBS if num_jobs is not None: num_jobs = int(num_jobs) logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs) @@ -78,7 +92,7 @@ def compute_num_jobs(self): # environment variable (if defined) or 1. 
# when it is set, we reduce `num_jobs` to avoid # overloading the system. - nvcc_threads = os.getenv("NVCC_THREADS", None) + nvcc_threads = envs.NVCC_THREADS if nvcc_threads is not None: nvcc_threads = int(nvcc_threads) logger.info( @@ -104,7 +118,7 @@ def configure(self, ext: CMakeExtension) -> None: # Select the build type. # Note: optimization level + debug info are set by the build type default_cfg = "Debug" if self.debug else "RelWithDebInfo" - cfg = os.getenv("CMAKE_BUILD_TYPE", default_cfg) + cfg = envs.CMAKE_BUILD_TYPE or default_cfg # where .so files will be written, should be the same for all extensions # that use the same CMakeLists.txt. @@ -118,7 +132,7 @@ def configure(self, ext: CMakeExtension) -> None: '-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE), ] - verbose = bool(int(os.getenv('VERBOSE', '0'))) + verbose = envs.VERBOSE if verbose: cmake_args += ['-DCMAKE_VERBOSE_MAKEFILE=ON'] @@ -205,8 +219,7 @@ def _is_neuron() -> bool: subprocess.run(["neuron-ls"], capture_output=True, check=True) except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): torch_neuronx_installed = False - return torch_neuronx_installed or os.environ.get("VLLM_BUILD_WITH_NEURON", - False) + return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON def _is_cpu() -> bool: @@ -214,7 +227,7 @@ def _is_cpu() -> bool: def _install_punica() -> bool: - return bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) + return envs.VLLM_INSTALL_PUNICA_KERNELS def get_hipcc_rocm_version(): @@ -377,7 +390,8 @@ def _read_requirements(filename: str) -> List[str]: package_data = { "vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"] } -if os.environ.get("VLLM_USE_PRECOMPILED"): +if envs.VLLM_USE_PRECOMPILED: + ext_modules = [] package_data["vllm"].append("*.so") setup( @@ -403,12 +417,12 @@ def _read_requirements(filename: str) -> List[str]: "Topic :: Scientific/Engineering :: Artificial Intelligence", ], packages=find_packages(exclude=("benchmarks", "csrc", "docs", "examples", - "tests")), + "tests*")), python_requires=">=3.8", install_requires=get_requirements(), ext_modules=ext_modules, extras_require={ - "tensorizer": ["tensorizer==2.9.0a1"], + "tensorizer": ["tensorizer==2.9.0"], }, cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {}, package_data=package_data, diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 8d6ad6706fb0e..64bcba67c3437 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -60,12 +60,13 @@ class MockServingChat: tokenizer: MockTokenizer -def test_load_chat_template(): +@pytest.mark.asyncio +async def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=chatml_jinja_path) + await OpenAIServingChat._load_chat_template( + mock_serving_chat, chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -76,7 +77,8 @@ def test_load_chat_template(): {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501 -def test_no_load_chat_template_filelike(): +@pytest.mark.asyncio +async def test_no_load_chat_template_filelike(): # Testing chatml template template = "../../examples/does_not_exist" tokenizer = MockTokenizer() @@ -84,18 +86,19 @@ def test_no_load_chat_template_filelike(): mock_serving_chat = 
MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) -def test_no_load_chat_template_literallike(): +@pytest.mark.asyncio +async def test_no_load_chat_template_literallike(): # Testing chatml template template = "{{ messages }}" tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) template_content = tokenizer.chat_template assert template_content == template @@ -110,8 +113,8 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + await OpenAIServingChat._load_chat_template(mock_serving_chat, + chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 97cff623c5e1d..d75279dd9cfa9 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -2,12 +2,15 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ +import os + import pytest MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.parametrize("model", MODELS) @@ -23,11 +26,18 @@ def test_models( max_tokens: int, enforce_eager: bool, ) -> None: + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER" and enforce_eager is False: + pytest.skip("Skipping non-eager test for FlashInferBackend.") + hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager) + vllm_model = vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index d83416eb51b43..47d582c726c66 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -55,7 +55,6 @@ def test_models( ) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model - print(vllm_outputs[0]) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py new file mode 100644 index 0000000000000..ffb0717b3bfdb --- /dev/null +++ b/tests/basic_correctness/test_preemption.py @@ -0,0 +1,223 @@ +"""Compare the short outputs of HF and vLLM when using greedy sampling. + +VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 has to be set before running this test. + +Run `VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 +pytest tests/basic_correctness/test_preemption.py`. 
+""" +import pytest + +from vllm import SamplingParams +from vllm.core.scheduler import (ARTIFICIAL_PREEMPTION_MAX_CNT, + ENABLE_ARTIFICIAL_PREEMPT) + +MODELS = [ + "facebook/opt-125m", +] + +assert ENABLE_ARTIFICIAL_PREEMPT is True, ( + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "tests/basic_correctness/test_preemption.py`") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("chunked_prefill_token_size", [16]) +def test_chunked_prefill_recompute( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, +) -> None: + """Ensure that chunked prefill works with preemption.""" + max_num_seqs = min(chunked_prefill_token_size, 256) + enable_chunked_prefill = False + max_num_batched_tokens = None + if chunked_prefill_token_size != -1: + enable_chunked_prefill = True + max_num_batched_tokens = chunked_prefill_token_size + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """By default, recompute preemption is enabled""" + + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner( + model, + dtype=dtype, + ) + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Use beam search enables swapping.""" + example_prompts = example_prompts[:1] + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + del 
hf_model + + vllm_model = vllm_runner(model, dtype=dtype, swap_space=10) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + + for i in range(len(example_prompts)): + hf_output_ids, _ = hf_outputs[i] + vllm_output_ids, _ = vllm_outputs[i] + assert len(hf_output_ids) == len(vllm_output_ids) + for j in range(len(hf_output_ids)): + assert hf_output_ids[j] == vllm_output_ids[j], ( + f"Test{i} output{j}:\nHF: {hf_output_ids}\n" + f"vLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +@pytest.mark.parametrize("beam_width", [4]) +def test_swap_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + """Verify infeasible swap request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + example_prompts = example_prompts[:1] + + vllm_model = vllm_runner( + model, + dtype=dtype, + swap_space=10, + block_size=BLOCK_SIZE, + # Since beam search have more than 1 sequence, prefill + decode blocks + # are not enough to finish. + num_gpu_blocks_override=prefill_blocks + decode_blocks, + max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + ) + sampling_params = SamplingParams(n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. + assert req_outputs[0].outputs[0].finish_reason == "length" + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) +def test_preemption_infeasible( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + """Verify infeasible preemption request will be ignored.""" + BLOCK_SIZE = 16 + prefill_blocks = 2 + decode_blocks = max_tokens // BLOCK_SIZE + vllm_model = vllm_runner( + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. + num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + ) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + del vllm_model + # Verify the request is ignored and not hang. 
+ for req_output in req_outputs: + outputs = req_output.outputs + assert len(outputs) == 1 + assert outputs[0].finish_reason == "length" diff --git a/tests/conftest.py b/tests/conftest.py index 5c50fc2d1bab6..671326915b22b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -296,6 +296,7 @@ def __init__( tensor_parallel_size: int = 1, block_size: int = 16, enable_chunked_prefill: bool = False, + swap_space=4, **kwargs, ) -> None: self.model = LLM( @@ -303,7 +304,7 @@ def __init__( tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, - swap_space=0, + swap_space=swap_space, disable_log_stats=disable_log_stats, tensor_parallel_size=tensor_parallel_size, max_model_len=max_model_len, diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index 0ee78a9b0a8ea..c3666da7542b5 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -300,6 +300,152 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Enable prefill cache + "enable_prefix_caching": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "use_v2_block_manager": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( + baseline_llm_generator, test_llm_generator, batch_size): + """Verify block manager v2 produces same outputs as block manager v1, even + when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that the KV + cache is not corrupted in the v2 block manager. + + NOTE: We want a significant number of generated tokens so that any incorrect + KV mapping has time to build up error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). 
+ + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids from block manager v1') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids from block manager v2') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + "model": "facebook/opt-125m", + + # skip cuda graph creation for fast test. + "enforce_eager": True, + + # Allow only 5 sequences of ~1024 tokens in worst case. + "block_size": 16, + "num_gpu_blocks_override": 5 * (64 + 1), + + # Test APC in v2 block + "use_v2_block_manager": True, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{ + "enable_prefix_caching": False +}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_prefix_caching": True}]) +@pytest.mark.parametrize("batch_size", [10]) +@pytest.mark.parametrize("seed", [1]) +def test_auto_prefix_caching_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager v2 with auto prefix caching enabled produces same + outputs as auto prefix caching disabled, even when there is preemption. + + This constructs two LLM, each with limited number of GPU blocks. The limit + is decided such that as the sequences in the batch grow, sequences must be + preempted and removed from cache. + + If the output token ids are equivalent, then we have confidence that auto + prefix caching itself at least don't cause result error. + """ + output_len = 1024 + temperature = 0.0 + + # We want to ensure equality even with preemption. + # We force the total block size to be 1 + cdiv(output_len, block_size) + # so that only one sequence can fit at a time (once the sequences grow). 
+ prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + sampling_params = SamplingParams( + max_tokens=output_len, + ignore_eos=True, + temperature=temperature, + ) + + print('Getting token ids with APC disabled') + baseline_token_ids = get_token_ids_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + print('Getting token ids with APC enabled') + test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, + prompts, sampling_params) + + for expected_token_ids, actual_token_ids in zip(baseline_token_ids, + test_token_ids): + assert expected_token_ids == actual_token_ids + + assert baseline_token_ids == test_token_ids + + def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params): for llm in llm_generator: outputs = llm.generate(prompts, sampling_params, use_tqdm=True) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 5f4d58dd5fd39..c4c680e109a84 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -358,6 +358,131 @@ def test_get_num_free_blocks_shared(num_blocks: int, block_size: int, i) allocator.free(block) + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_get_common_computed_block_ids(num_blocks: int, block_size: int, + seed: int): + """Verify get_common_computed_block_ids could get correct result + by create two immutable chain sharing prefix at specified pos, + and compare whether we also could get right result + from get_common_computed_block_ids. + """ + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks * 2, + block_size=block_size) + num_blocks_to_consume = random.randint(1, num_blocks - 1) + + # Create token ids that will exhaust all blocks. 
+ token_ids = list(range(num_blocks_to_consume * block_size)) + blocks = list(range(num_blocks_to_consume)) + + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + # mark all blocks in first chain as computed + allocator.mark_blocks_as_computed(blocks) + + # After zero_point, second_chain's token_ids would be set -1, which + # make it different from here comparing with first_chain + zero_point = random.randint(1, len(token_ids) - 1) + zero_point_blocks = zero_point // block_size + token_ids[zero_point:] = [-1] * (len(token_ids) - zero_point) + + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids, + allocator=allocator, + ) + + first_computed_ids = [ + first_chain[i].block_id for i in range(num_blocks_to_consume) + ] + second_computed_ids = [ + second_chain[i].block_id for i in range(num_blocks_to_consume) + ] + res = allocator.get_common_computed_block_ids( + [first_computed_ids, second_computed_ids]) + + assert (len(res) == zero_point_blocks) + + # Test case where two last accessed times are equal + @staticmethod + @pytest.mark.parametrize("num_blocks", [1024]) + @pytest.mark.parametrize("block_size", [16]) + @pytest.mark.parametrize("seed", list(range(20))) + def test_eviction_order(num_blocks: int, block_size: int, seed: int): + """This test case simulate the two chain created and free in order, + and together they would exhaust the initial freed blocks. + + So the next block created after those two chain shall use the block + from the first chain as that block has long access time. + While first chain has two blocks, it shall pick up the last one, as + it has larger token number. + """ + + random.seed(seed) + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + num_blocks_to_consume = num_blocks + 1 + + token_ids = list(range(num_blocks_to_consume * block_size)) + + num_blocks_in_first_chain = 2 + num_tokens_in_first_chain = block_size * num_blocks_in_first_chain + # First chain takes the first block + first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[:num_tokens_in_first_chain], + allocator=allocator, + ) + # There should only be one block allocated at this point + assert allocator.get_num_free_blocks() == (num_blocks - + num_blocks_in_first_chain) + + # Set the last accessed time of the first block to 1 + blocks_ids = [block.block_id for block in first_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 1) + + # Second chain takes the rest of the blocks + second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[num_tokens_in_first_chain:-block_size], + allocator=allocator, + ) + + # There shouldn't be any blocks left at this point + assert allocator.get_num_free_blocks() == (0) + + assert len(first_chain) == num_blocks_in_first_chain + last_block_id = first_chain[-1].block_id + # Free each block in the first chain. + for i, block in enumerate(first_chain): + allocator.free(block) + + # Set the last accessed time on all of the blocks in the second chain + # to 2 + blocks_ids = [block.block_id for block in second_chain] + allocator.mark_blocks_as_accessed(blocks_ids, 2) + + # Free each block in the second chain. 
+ for i, block in enumerate(second_chain): + allocator.free(block) + + # Allocate a new block and check that it's the least recently used block + # from the first chain. + new_block = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=token_ids[-block_size:], + allocator=allocator, + ) + + assert new_block[0].block_id == last_block_id + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4caabb..9f9a6180add78 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -224,7 +224,7 @@ def test_swap(): # Swap seq group from CPU -> GPU. cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) + assert block_manager.can_swap_in(seq_group) == AllocStatus.OK before_cpu_blocks = block_manager.get_num_free_cpu_blocks() before_gpu_blocks = block_manager.get_num_free_gpu_blocks() mapping = block_manager.swap_in(seq_group) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index cce396bf4953c..92498c0014666 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,6 +4,7 @@ import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup @@ -410,7 +411,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Add 1 more task. Swap is not possible, so prefill is running. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER _, seq_group2 = create_dummy_prompt("2", prompt_length=60) scheduler.add_seq_group(seq_group2) @@ -423,7 +424,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.scheduled_seq_groups[0].seq_group == seq_group2 # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = True + scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 # 3 decodes. It is swapped in. diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index ab471d206618b..348169035ae97 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -568,7 +568,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == {} # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_decode_swap_beam_search(): @@ -618,7 +618,7 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): # Both should be preempted, not swapped. assert output.blocks_to_swap_out == expected_swap_mapping # Nothing is copied. - assert output.blocks_to_copy == {} + assert output.blocks_to_copy == [] def test_schedule_decode_blocks_to_copy_update(): @@ -650,7 +650,7 @@ def test_schedule_decode_blocks_to_copy_update(): assert output.blocks_to_swap_out == {} # Since append_slot returns the source -> dist mapping, it should # applied. 
- assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_schedule_swapped_simple(): @@ -791,7 +791,7 @@ def test_schedule_swapped_cannot_swap_in(): # The last request should be swapped out. scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = False + scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER # Since we cannot swap in, none of the requests are swapped in. budget = create_token_budget() remaining_swapped, output = scheduler._schedule_swapped( @@ -803,6 +803,34 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 +def test_infeasible_swap(): + scheduler = initialize_scheduler() + swapped = deque() + policy = PolicyFactory.get_policy(policy_name="fcfs") + curr_loras = None + blocks_to_swap_out = {} + for _ in range(2): + _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + scheduler._allocate_and_set_running(seq_group) + append_new_token_seq_group(60, seq_group, 1) + scheduler._swap_out(seq_group, blocks_to_swap_out) + swapped.append(seq_group) + + # The last request should be swapped out. + scheduler.block_manager.can_swap_in = MagicMock() + scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER + # Since we cannot swap in, none of the requests are swapped in. + budget = create_token_budget() + remaining_swapped, output = scheduler._schedule_swapped( + swapped, budget, curr_loras, policy) + assert len(remaining_swapped) == 0 + assert len(output.infeasible_seq_groups) == 2 + assert budget.num_batched_tokens == 0 + assert budget.num_curr_seqs == 0 + assert len(output.decode_seq_groups) == 0 + assert len(output.prefill_seq_groups) == 0 + + def test_schedule_swapped_blocks_to_copy(): scheduler = initialize_scheduler() swapped = deque() @@ -825,7 +853,7 @@ def test_schedule_swapped_blocks_to_copy(): assert len(remaining_swapped) == 0 assert len(output.decode_seq_groups) == 1 assert len(output.prefill_seq_groups) == 0 - assert output.blocks_to_copy == {2: [3]} + assert output.blocks_to_copy == [(2, 3)] def test_scheduling_budget(): diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 77aa90b12bf8f..527452630c9f5 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -18,6 +18,7 @@ MODELS = [ os.environ["TEST_DIST_MODEL"], ] +VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" @pytest.mark.skipif(torch.cuda.device_count() < 2, @@ -33,16 +34,19 @@ def test_models( dtype: str, max_tokens: int, ) -> None: + enforce_eager = False + backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + if backend_by_env_var == "FLASHINFER": + enforce_eager = True hf_model = hf_runner(model, dtype=dtype) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) del hf_model - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - ) + vllm_model = vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) del vllm_model diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 6d7d4a5806bd0..b6f461b76ed03 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -3,9 +3,13 @@ import pytest import torch +import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils +from 
vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator, ncclGetUniqueId) -from vllm.distributed.parallel_state import init_distributed_environment +from vllm.distributed.parallel_state import ( + ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group, + init_distributed_environment, with_pynccl_for_all_reduce) from vllm.utils import update_environment_variables @@ -58,6 +62,65 @@ def test_pynccl(): distributed_run(worker_fn, 2) +@worker_fn_wrapper +def multiple_tp_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + groups = [ + torch.distributed.new_group(ranks=[0, 1], backend="gloo"), + torch.distributed.new_group(ranks=[2, 3], backend="gloo") + ] + group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] + comm = NCCLCommunicator(group=group, device=device) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + comm.all_reduce(tensor) + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + comm.all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_tp(): + # this tests pynccl for multiple tp groups, in a standalone way + # i.e. call `comm.all_reduce` directly + distributed_run(multiple_tp_worker_fn, 4) + + +@worker_fn_wrapper +def multiple_tp_with_vllm_worker_fn(): + device = torch.device(f"cuda:{torch.distributed.get_rank()}") + torch.cuda.set_device(torch.distributed.get_rank()) + ensure_model_parallel_initialized(2, 2) + pynccl_utils.init_process_group( + group=get_tensor_model_parallel_cpu_group()) + tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) + with with_pynccl_for_all_reduce(): + # two tp groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + tensor = tensor_model_parallel_all_reduce(tensor) + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 4 + else: + tensor = tensor_model_parallel_all_reduce(tensor) + result = tensor.mean().cpu().item() + assert result == 2 + + +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_multiple_tp_with_vllm(): + # this tests pynccl for multiple tp groups, together with vllm + # i.e. 
call `tensor_model_parallel_all_reduce` + distributed_run(multiple_tp_with_vllm_worker_fn, 4) + + @worker_fn_wrapper def worker_fn_with_cudagraph(): with torch.no_grad(): diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py new file mode 100644 index 0000000000000..610ad9732fb91 --- /dev/null +++ b/tests/engine/test_multiproc_workers.py @@ -0,0 +1,176 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import sleep +from typing import Any, List, Tuple + +import pytest + +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) + + +class DummyWorker: + """Dummy version of vllm.worker.worker.Worker""" + + def __init__(self, rank: int): + self.rank = rank + + def worker_method(self, worker_input: Any) -> Tuple[int, Any]: + sleep(0.05) + + if isinstance(worker_input, Exception): + # simulate error case + raise worker_input + + return self.rank, input + + +def _start_workers() -> Tuple[List[ProcessWorkerWrapper], WorkerMonitor]: + result_handler = ResultHandler() + workers = [ + ProcessWorkerWrapper(result_handler, partial(DummyWorker, rank=rank)) + for rank in range(8) + ] + + worker_monitor = WorkerMonitor(workers, result_handler) + assert not worker_monitor.is_alive() + + result_handler.start() + worker_monitor.start() + assert worker_monitor.is_alive() + + return workers, worker_monitor + + +def test_local_workers() -> None: + """Test workers with sync task submission""" + + workers, worker_monitor = _start_workers() + + def execute_workers(worker_input: str) -> None: + worker_outputs = [ + worker.execute_method("worker_method", worker_input) + for worker in workers + ] + + for rank, output in enumerate(worker_outputs): + assert output.get() == (rank, input) + + executor = ThreadPoolExecutor(max_workers=4) + + # Test concurrent submission from different threads + futures = [ + executor.submit(partial(execute_workers, f"thread {thread_num}")) + for thread_num in range(4) + ] + + for future in futures: + future.result() + + # Test error case + exception = ValueError("fake error") + result = workers[0].execute_method("worker_method", exception) + try: + result.get() + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +def test_local_workers_clean_shutdown() -> None: + """Test clean shutdown""" + + workers, worker_monitor = _start_workers() + + assert worker_monitor.is_alive() + assert all(worker.process.is_alive() for worker in workers) + + # Clean shutdown + worker_monitor.close() + + worker_monitor.join(5) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = workers[0].execute_method("worker_method", "test") + pytest.fail("task should fail once workers 
have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) + + +@pytest.mark.asyncio +async def test_local_workers_async() -> None: + """Test local workers with async task submission""" + + workers, worker_monitor = _start_workers() + + async def execute_workers(worker_input: str) -> None: + worker_coros = [ + worker.execute_method_async("worker_method", worker_input) + for worker in workers + ] + + results = await asyncio.gather(*worker_coros) + for rank, result in enumerate(results): + assert result == (rank, input) + + tasks = [ + asyncio.create_task(execute_workers(f"task {task_num}")) + for task_num in range(4) + ] + + for task in tasks: + await task + + # Test error case + exception = ValueError("fake error") + try: + _result = await workers[0].execute_method_async( + "worker_method", exception) + pytest.fail("task should have failed") + except Exception as e: + assert isinstance(e, ValueError) + assert str(e) == "fake error" + + # Test cleanup when a worker fails + assert worker_monitor.is_alive() + workers[3].process.kill() + + # Other workers should get shut down here + worker_monitor.join(2) + + # Ensure everything is stopped + assert not worker_monitor.is_alive() + assert all(not worker.process.is_alive() for worker in workers) + + # Further attempts to submit tasks should fail + try: + _result = await workers[0].execute_method_async( + "worker_method", "test") + pytest.fail("task should fail once workers have been shut down") + except Exception as e: + assert isinstance(e, ChildProcessError) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py new file mode 100644 index 0000000000000..269b0823fec05 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -0,0 +1,37 @@ +import asyncio +from dataclasses import dataclass + +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat + +MODEL_NAME = "openai-community/gpt2" +CHAT_TEMPLATE = "Dummy chat template for testing {}" + + +@dataclass +class MockModelConfig: + tokenizer = MODEL_NAME + trust_remote_code = False + tokenizer_mode = "auto" + max_model_len = 100 + tokenizer_revision = None + + +@dataclass +class MockEngine: + + async def get_model_config(self): + return MockModelConfig + + +async def _async_serving_chat_init(): + serving_completion = OpenAIServingChat(MockEngine(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE) + return serving_completion + + +def test_async_serving_chat_init(): + serving_completion = asyncio.run(_async_serving_chat_init()) + assert serving_completion.tokenizer is not None + assert serving_completion.tokenizer.chat_template == CHAT_TEMPLATE diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 30f0ad5d8272f..41c871ca40bc8 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -57,7 +57,9 @@ def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) - json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + json_LP = JSONLogitsProcessor(TEST_SCHEMA, + tokenizer, + whitespace_pattern=None) regex_LP.init_state() token_ids = tokenizer.encode( diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 68332228ace08..e53e64a0c1ff8 
100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -13,6 +13,7 @@ # and debugging. import ray import requests +import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -149,7 +150,7 @@ def server(zephyr_lora_files): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", @@ -786,6 +787,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +async def test_complex_message_content(server, client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": [{ + "type": + "text", + "text": + "what is 1+1? please provide the result without any other text." + }] + }], + temperature=0, + seed=0) + content = resp.choices[0].message.content + assert content == "2" + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -851,5 +871,24 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +async def test_long_seed(server, client: openai.AsyncOpenAI): + for seed in [ + torch.iinfo(torch.long).min - 1, + torch.iinfo(torch.long).max + 1 + ]: + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + }], + temperature=0, + seed=seed) + + assert ("greater_than_equal" in exc_info.value.message + or "less_than_equal" in exc_info.value.message) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/kernels/conftest.py b/tests/kernels/conftest.py index d26da2c7fe4ee..4f2f9cc3dac7d 100644 --- a/tests/kernels/conftest.py +++ b/tests/kernels/conftest.py @@ -1,8 +1,14 @@ import pytest -from vllm.utils import create_kv_caches_with_random +from vllm.utils import (create_kv_caches_with_random, + create_kv_caches_with_random_flash) @pytest.fixture() def kv_cache_factory(): return create_kv_caches_with_random + + +@pytest.fixture() +def kv_cache_factory_flashinfer(): + return create_kv_caches_with_random_flash diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 9b1f3e30b6dca..84539205e0ae3 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -61,7 +61,7 @@ def ref_single_query_cached_kv_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> None: @@ -72,15 +72,15 @@ def ref_single_query_cached_kv_attention( num_seqs = query.shape[0] block_tables = block_tables.cpu().tolist() - context_lens = context_lens.cpu().tolist() + seq_lens = seq_lens.cpu().tolist() for i in range(num_seqs): q = query[i].unsqueeze(0) block_table = block_tables[i] - context_len = int(context_lens[i]) + seq_len = int(seq_lens[i]) keys = [] values = [] - for j in range(context_len): + for j in range(seq_len): block_number = int(block_table[j // block_size]) block_offset = j % block_size @@ -100,8 +100,8 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
- position_ids = torch.arange(context_len).int() - alibi_bias = (position_ids - context_len + 1).float() + position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -149,13 +149,13 @@ def test_paged_attention( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) - context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - context_lens[-1] = MAX_SEQ_LEN - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int) + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) # Create the block tables. - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = [] for _ in range(num_seqs): block_table = [ @@ -186,16 +186,15 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, ) elif version == "v2": - num_partitions = ((max_context_len + PARTITION_SIZE - 1) // - PARTITION_SIZE) + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape tmp_output = torch.empty( @@ -218,9 +217,9 @@ def test_paged_attention( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -255,7 +254,7 @@ def test_paged_attention( key_cache, value_cache, block_tables, - context_lens, + seq_lens, scale, alibi_slopes, ) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index d1051fd7e2f4d..94a577139596e 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -5,6 +5,7 @@ import torch from vllm import _custom_ops as ops +from vllm._C import cache_ops from vllm.utils import is_hip COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] @@ -62,12 +63,13 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - block_mapping = {} + block_mapping = [] for i in range(num_mappings): src = src_blocks[i] dst1 = dst_blocks[2 * i] dst2 = dst_blocks[2 * i + 1] - block_mapping[src] = [dst1, dst2] + block_mapping.append((src, dst1)) + block_mapping.append((src, dst2)) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -80,15 +82,17 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - ops.copy_blocks(key_caches, value_caches, block_mapping) + block_mapping_tensor = torch.tensor(block_mapping, + dtype=torch.int64, + device=device).view(-1, 2) + ops.copy_blocks(key_caches, value_caches, block_mapping_tensor) # Run the reference implementation. 
- for src, dsts in block_mapping.items(): - for dst in dsts: - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dst in block_mapping: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): @@ -191,6 +195,82 @@ def test_reshape_and_cache( assert torch.allclose(value_cache, cloned_value_cache) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@torch.inference_mode() +def test_reshape_and_cache_flash( + kv_cache_factory_flashinfer, + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, + kv_cache_dtype: str, +) -> None: + if kv_cache_dtype == "fp8": + pytest.skip() + random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device='cuda') + + qkv = torch.randn(num_tokens, + 3, + num_heads, + head_size, + dtype=dtype, + device=device) + _, key, value = qkv.unbind(dim=1) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory_flashinfer( + num_blocks, + block_size, + 1, + num_heads, + head_size, + kv_cache_dtype, + dtype, + ) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Clone the KV caches. + cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Call the reshape_and_cache kernel. + cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + # Run the reference implementation. 
+ block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor') + block_indicies = block_indicies.cpu().tolist() + block_offsets = slot_mapping % block_size + block_offsets = block_offsets.cpu().tolist() + for i in range(num_tokens): + block_idx = block_indicies[i] + block_offset = block_offsets[i] + cloned_key_cache[block_idx, block_offset, :, :] = key[i] + cloned_value_cache[block_idx, block_offset, :, :] = value[i] + + assert torch.allclose(key_cache, cloned_key_cache) + assert torch.allclose(value_cache, cloned_value_cache) + + @pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_heads", NUM_HEADS) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 046f11d957bdd..2356b9ec18b0d 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype): for i in range(config.num_local_experts): weights = (hf_moe.experts[i].w1.weight.data, hf_moe.experts[i].w3.weight.data) - vllm_moe.ws[i][:] = torch.cat(weights, dim=0) - vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data + vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0) + vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data # Generate input batch of dimensions [batch_size, seq_len, hidden_dim] hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index ad31b0a7c2a19..5a5987e2242fa 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -15,6 +15,7 @@ CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048] @pytest.mark.parametrize("num_heads", NUM_HEADS) @@ -22,11 +23,13 @@ @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, num_queries_per_kv: int, head_size: int, + sliding_window: int, dtype: torch.dtype, device: str, ) -> None: @@ -48,12 +51,12 @@ def test_contexted_kv_attention( cache_size = 640 block_size = 32 max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)] num_kv_heads = num_heads // num_queries_per_kv - num_tokens = sum(subquery_lens) + num_tokens = sum(query_lens) query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) @@ -72,15 +75,15 @@ def test_contexted_kv_attention( num_kv_heads, head_size, dtype=dtype) - k = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) - v = torch.zeros(sum(subquery_lens), num_kv_heads, head_size, dtype=dtype) + k = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) + v = torch.zeros(sum(query_lens), num_kv_heads, head_size, dtype=dtype) values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) b_seq_len 
= torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -89,7 +92,7 @@ def test_contexted_kv_attention( dtype=torch.long), dim=0) for i in range(BS): - for j in range(subquery_lens[i]): + for j in range(query_lens[i]): k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + @@ -123,12 +126,32 @@ def test_contexted_kv_attention( # Warm up the Triton kernel by calling it once before actually measuring # generation time - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) + context_attention_fwd(query, + k, + v, + output, + k_cache, + v_cache, + block_table, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -155,7 +178,10 @@ def test_contexted_kv_attention( value = value.unsqueeze(0) attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) + query_lens, seq_lens) + if sliding_window > 0: + attn_bias = attn_bias.make_local_attention_from_bottomright( + sliding_window) output_ref = xops.memory_efficient_attention_forward( query, key, diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 0ab9c63ce4377..e0aa14f165c2d 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,4 +1,12 @@ +from typing import List + import pytest +from prometheus_client import REGISTRY + +from vllm import EngineArgs, LLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams MODELS = [ "facebook/opt-125m", @@ -68,3 +76,119 @@ def test_metric_counter_generation_tokens( assert vllm_generation_count == metric_count, ( f"generation token count: {vllm_generation_count!r}\n" f"metric: {metric_count!r}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize( + "served_model_name", + [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) +def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, + served_model_name: List[str]) -> None: + vllm_model = vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.3, + served_model_name=served_model_name) + stat_logger = vllm_model.model.llm_engine.stat_logger + metrics_tag_content = stat_logger.labels["model_name"] + + del vllm_model + + if served_model_name is None or served_model_name == []: + assert metrics_tag_content == model, ( + f"Metrics tag model_name is wrong! expect: {model!r}\n" + f"actual: {metrics_tag_content!r}") + else: + assert metrics_tag_content == served_model_name[0], ( + f"Metrics tag model_name is wrong! 
expect: " + f"{served_model_name[0]!r}\n" + f"actual: {metrics_tag_content!r}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +@pytest.mark.asyncio +async def test_async_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + """ + Regression test ensuring async engine generates metrics + when disable_log_stats=False + (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678) + """ + engine_args = AsyncEngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + async_engine = AsyncLLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + results = async_engine.generate( + prompt, + SamplingParams(max_tokens=max_tokens), + f"request-id-{i}", + ) + # Exhaust the async iterator to make the async engine work + async for _ in results: + pass + + assert_metrics(async_engine.engine, disable_log_stats, + len(example_prompts)) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("disable_log_stats", [True, False]) +def test_engine_log_metrics_regression( + example_prompts, + model: str, + dtype: str, + max_tokens: int, + disable_log_stats: bool, +) -> None: + engine_args = EngineArgs(model=model, + dtype=dtype, + disable_log_stats=disable_log_stats) + engine = LLMEngine.from_engine_args(engine_args) + for i, prompt in enumerate(example_prompts): + engine.add_request( + f"request-id-{i}", + prompt, + SamplingParams(max_tokens=max_tokens), + ) + while engine.has_unfinished_requests(): + engine.step() + + assert_metrics(engine, disable_log_stats, len(example_prompts)) + + +def assert_metrics(engine: LLMEngine, disable_log_stats: bool, + num_requests: int) -> None: + if disable_log_stats: + with pytest.raises(AttributeError): + _ = engine.stat_logger + else: + assert (engine.stat_logger + is not None), "engine.stat_logger should be set" + # Ensure the count bucket of request-level histogram metrics matches + # the number of requests as a simple sanity check to ensure metrics are + # generated + labels = {'model_name': engine.model_config.model} + request_histogram_metrics = [ + "vllm:e2e_request_latency_seconds", + "vllm:request_prompt_tokens", + "vllm:request_generation_tokens", + "vllm:request_params_best_of", + "vllm:request_params_n", + ] + for metric_name in request_histogram_metrics: + metric_value = REGISTRY.get_sample_value(f"{metric_name}_count", + labels) + assert ( + metric_value == num_requests), "Metrics should be collected" diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index 504eaad43c8d7..3dde498bcd639 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -43,3 +43,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. 
+ print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py new file mode 100644 index 0000000000000..e87a1783a83f1 --- /dev/null +++ b/tests/models/test_fp8.py @@ -0,0 +1,90 @@ +# flake8: noqa +"""Tests fp8 models against ground truth generation +Note: these tests will only pass on L4 GPU. +""" +import os + +import pytest +import torch +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = [ + "nm-testing/Meta-Llama-3-8B-Instruct-FP8", + "meta-llama/Meta-Llama-3-8B-Instruct", +] + +EXPECTED_STRS_MAP = { + "nm-testing/Meta-Llama-3-8B-Instruct-FP8": [ + 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya tori, nemuri nemuri)\n\n**' + ], + "meta-llama/Meta-Llama-3-8B-Instruct": [ + 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', + 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', + 'In the year 2154, the robotics lab at NeuroSpark Industries was on the cusp of', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' + ], +} + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +fp8_not_supported = (capability < + QUANTIZATION_METHODS["fp8"].get_min_capability()) + + +@pytest.mark.skipif(fp8_not_supported, + reason="fp8 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +def test_models( + example_prompts, + model_name, +) -> None: + model = LLM(model=model_name, + max_model_len=MAX_MODEL_LEN, + enforce_eager=True, + quantization="fp8") + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + + params = SamplingParams(max_tokens=20, temperature=0) + generations = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. 
+ for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(generations) + expected_strs = EXPECTED_STRS_MAP[model_name] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index dc027697ffd4d..4d73843f970c4 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -39,6 +39,13 @@ ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), # act_order==True, group_size=32 ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), + + # 8-bit, act_order==True, group_size=channelwise + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit--1g-actorder_True"), + # 8-bit, act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-128g-actorder_True"), + # 8-bit, act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-8bit-32g-actorder_True"), ] @@ -65,8 +72,7 @@ def test_models( dtype=dtype, quantization="marlin", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) @@ -78,8 +84,7 @@ def test_models( dtype=dtype, quantization="gptq", max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1, - disable_custom_all_reduce=True) + tensor_parallel_size=1) gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index cfe2539e3a052..e4609620387fa 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -49,3 +49,18 @@ def test_models( f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") assert hf_output_ids == vllm_output_ids, ( f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + vllm_model = vllm_runner(model, dtype=dtype) + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) + del vllm_model diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py new file mode 100644 index 0000000000000..864657a3c2b28 --- /dev/null +++ b/tests/samplers/test_ignore_eos.py @@ -0,0 +1,31 @@ +"""Make sure ignore_eos works. + +Run `pytest tests/samplers/test_ignore_eos.py`. 
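+
+With ignore_eos=True the engine keeps sampling past the EOS token, so the
+generated length is expected to reach max_tokens (the assertions allow a
+small slack).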
+""" + +import pytest + +from vllm import SamplingParams + +MODELS = ["facebook/opt-125m"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [1024]) +def test_beam_search_single_input( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + example_prompts = "1 + 1 is" + + vllm_model = vllm_runner(model, dtype=dtype) + sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + ignore_eos_output = vllm_model.model.generate( + example_prompts, sampling_params=sampling_params) + print(len(ignore_eos_output[0].outputs[0].token_ids)) + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10 + assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0 diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 7859f0b21812f..e4fea165a4d46 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -58,7 +58,7 @@ def _do_sample( device: str, ): seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -68,12 +68,12 @@ def _do_sample( sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) return sampler(logits=input_tensor, sampling_metadata=sampling_metadata) @@ -421,7 +421,7 @@ def run_test_case(*, "Invalid test case, need seq_group_metadata_list" batch_size = 0 - prompt_lens = [] + seq_lens = [] sampling_params_per_row = [] for sgm in seq_group_metadata_list: sampling_params = sgm.sampling_params @@ -431,7 +431,7 @@ def run_test_case(*, # a prompt seq_group has only one sequence seq_data = next(iter(sgm.seq_data.values())) prompt_len = seq_data.get_prompt_len() - prompt_lens.append(prompt_len) + seq_lens.append(prompt_len) if sgm.sampling_params.prompt_logprobs: # with prompt_logprobs each token in the prompt has a row in @@ -451,8 +451,8 @@ def run_test_case(*, _, fake_logits, sampler, model_runner = _prepare_test(batch_size) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens=prompt_lens if prompt_lens else None, - subquery_lens=prompt_lens if prompt_lens else None, + seq_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else None, device=device, pin_memory=model_runner.pin_memory) # the logits tensor is modified in-place by the sampler @@ -497,7 +497,7 @@ def test_sampler_mixed(seed: int, device: str): seq_group_metadata_list = [] expected_tokens: List[Optional[List[int]]] = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): expected: Optional[List[int]] = None sampling_type = random.randint(0, 3) @@ -532,13 +532,13 @@ def test_sampler_mixed(seed: int, device: str): sampling_params=sampling_params, block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) def test_sampling(model_runner: ModelRunner): sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, 
pin_memory=model_runner.pin_memory) sampler_output = sampler(logits=fake_logits, @@ -575,7 +575,7 @@ def test_sampling(model_runner: ModelRunner): # Shuffle the batch and resample target_index = list(range(batch_size)) for list_to_shuffle in (target_index, seq_group_metadata_list, - expected_tokens, prompt_lens): + expected_tokens, seq_lens): random.Random(seed).shuffle(list_to_shuffle) target_index = torch.tensor(target_index) input_tensor.data = input_tensor.index_select(0, target_index) @@ -620,7 +620,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): assert len(warpers) == 2 # top_p and top_k seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -634,12 +634,12 @@ def test_sampler_top_k_top_p(seed: int, device: str): ), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=device, pin_memory=model_runner.pin_memory) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 59fb8311fc5b7..b1ab8a07ca636 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,10 +1,132 @@ -from typing import List, Tuple +import asyncio +import time +from itertools import cycle +from typing import Dict, List, Optional, Tuple, Union import pytest +import ray +import torch +from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, + nvmlInit) from tests.conftest import cleanup from vllm import LLM +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.lora.request import LoRARequest from vllm.model_executor.utils import set_random_seed +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob, MultiModalData +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter, random_uuid + + +class AsyncLLM: + """AsyncLLM + + Note: Current LLM class in vllm don't support async mode, for test purpose, + we implement async one in here. Maybe we could move to + vllm/entrypoints/llm.py in future. + + Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes + to make to work in async mode. 
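+
+    Intended usage mirrors LLM, e.g. (sketch):
+
+        llm = AsyncLLM(model="JackFram/llama-68m", use_v2_block_manager=True)
+        outputs = llm.generate(prompts, SamplingParams(max_tokens=32))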
+ """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + self.engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + engine_use_ray=True, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.request_counter = Counter() + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + + llm_engine = AsyncLLMEngine.from_engine_args( + self.engine_args, usage_context=UsageContext.LLM_CLASS) + + if prompts is None: + raise ValueError("prompts must be provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + + if prompts is not None: + num_requests = len(prompts) + + if sampling_params is None: + # Use default sampling params. + sampling_params = SamplingParams() + + elif isinstance(sampling_params, + list) and len(sampling_params) != num_requests: + raise ValueError("The lengths of prompts and " + "sampling_params must be the same.") + + async def get_output(prompt, sampling_param) -> str: + request_id = random_uuid() + results_generator = llm_engine.generate(prompt, sampling_param, + request_id) + final_output = None + async for request_output in results_generator: + final_output = request_output + return final_output + + outputs = [] + try: + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + res = asyncio.run(get_output(prompt, sampling_params)) + outputs.append(res) + finally: + ray.shutdown() + return outputs @pytest.fixture @@ -35,9 +157,20 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs, test_name = request.node.name def generator_inner(): - print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}') - llm = LLM(**kwargs) + wait_for_gpu_memory_to_clear( + devices=list(range(torch.cuda.device_count())), + threshold_bytes=2 * 2**30, + timeout_s=60, + ) + + use_async = False + if "use_async" in kwargs: + use_async = kwargs.pop("use_async") + print(f'{use_async=}') + + print(f'Creating {baseline_or_test=} LLM for {test_name=}. 
{kwargs=}') + llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs) set_random_seed(seed) yield llm @@ -64,3 +197,109 @@ def get_output_from_llm_generator( del llm return tokens, token_ids + + +def get_logprobs_from_llm_generator( + llm_generator, prompts, + sampling_params) -> List[List[Dict[int, Logprob]]]: + """Returns a dict of (token_id: Logprob) for each generated position, for + each sequence in the batch. + """ + for llm in llm_generator(): + outputs = llm.generate(prompts, sampling_params, use_tqdm=True) + logprobs = [output.outputs[0].logprobs[:] for output in outputs] + del llm + + return logprobs + + +def run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + print_tokens: bool = False): + """Helper method that compares the outputs of both the baseline LLM and + the test LLM. It asserts greedy equality, e.g. that the outputs are exactly + the same when temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + ) + + spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + (baseline_batch_tokens, + baseline_batch_token_ids) = get_output_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_token_ids) == len(prompts) + assert len(spec_batch_token_ids) == len(prompts) + + for i, (baseline_token_ids, baseline_tokens, spec_token_ids, + spec_tokens) in enumerate( + zip(baseline_batch_token_ids, baseline_batch_tokens, + spec_batch_token_ids, spec_batch_tokens)): + if print_tokens: + print(f'{i=} {baseline_tokens=}') + print(f'{i=} {spec_tokens=}') + print(f'{i=} {baseline_token_ids=}') + print(f'{i=} {spec_token_ids=}') + assert baseline_token_ids == spec_token_ids + + +def wait_for_gpu_memory_to_clear(devices: List[int], + threshold_bytes: int, + timeout_s: float = 120) -> None: + # Use nvml instead of pytorch to reduce measurement error from torch cuda + # context. 
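+    # The loop below polls nvmlDeviceGetMemoryInfo for every requested device,
+    # prints the per-device usage, and returns once all devices are under
+    # threshold_bytes; it raises ValueError if memory has not cleared within
+    # timeout_s (polling every 5 seconds).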
+ nvmlInit() + start_time = time.time() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f'{gb_used:.02f}' + + print('gpu memory used (GB): ', end='') + for k, v in output.items(): + print(f'{k}={v}; ', end='') + print('') + + dur_s = time.time() - start_time + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + print(f'Done waiting for free GPU memory on devices {devices=} ' + f'({threshold_bytes/2**30=}) {dur_s=:.02f}') + break + + if dur_s >= timeout_s: + raise ValueError(f'Memory of devices {devices=} not free after ' + f'{dur_s=:.02f} ({threshold_bytes/2**30=})') + + time.sleep(5) diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index fde950c14382c..60c20ed7db7a3 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator): temperature=temperature, ) - with pytest.raises(AssertionError, - match="Speculative decoding not yet supported for "): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) + try: + with pytest.raises( + AssertionError, + match="Speculative decoding not yet supported for "): + get_output_from_llm_generator(test_llm_generator, prompts, + sampling_params) + finally: + # we need to free up ray resource, + # so that latter test could use the gpu we allocated here + import ray + ray.shutdown() @pytest.mark.parametrize( diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py new file mode 100644 index 0000000000000..9572aac7df6e0 --- /dev/null +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -0,0 +1,335 @@ +import math +from itertools import cycle + +import pytest + +from vllm import SamplingParams + +from .conftest import get_logprobs_from_llm_generator + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_equality(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify output logprobs are equal with and without speculative decoding. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True, + "max_logprobs": 6, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize("num_logprobs", [6]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 7, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int, + num_logprobs: int): + """Verify output logprobs are equal with and without spec decode. + This specifies a number of logprobs >1. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True, + logprob_rank=num_logprobs) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 6, +}]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Veriy logprob greedy equality with different speculation lens. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + + # Artificially limit the draft model max model len; this forces vLLM + # to skip speculation once the sequences grow beyond 32-k tokens. + "speculative_max_model_len": 32, + }]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_when_skip_speculation(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify logprobs greedy equality when some sequences skip speculation. + """ + run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. 
+ "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{ + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, +}]) +@pytest.mark.parametrize("batch_size", [1]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify at least one logprob result has num_logprobs+1, which tests the + case where the sampled token is not in top-k logprobs. + + Ideally, this test should validate equality with non-spec by getting + logprobs. This is left as future improvement. + """ + batch_size = 8 + max_output_len = output_len + force_output_len = True + logprob_rank = 5 + + temperature = 1.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + + num_returned_logprobs = [ + len(logprob_dict) for seq_logprobs in spec_batch_logprobs + for logprob_dict in seq_logprobs + ] + + # Assert one of the returned logprobs has > num_logprobs (indicating the + # sampled token is not in top-k). + assert any([ + num_returned > logprob_rank for num_returned in num_returned_logprobs + ]) + + +def run_greedy_logprobs_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len, + force_output_len: bool, + logprob_rank: int = 1): + """Helper method that compares the logprobs outputs of both the baseline LLM + and the test LLM. It asserts greedy equality of the logprobs when the + temperature is zero. + """ + temperature = 0.0 + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "San Francisco is know for its", + "Facebook was created in 2004 by", + "Curious George is a", + "Python 3.11 brings improvements to its", + ] + + prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] + + # If the test requires that we generated max_output_len tokens, then set the + # sampling params to ignore eos token. + ignore_eos = force_output_len + + sampling_params = SamplingParams( + max_tokens=max_output_len, + ignore_eos=ignore_eos, + temperature=temperature, + logprobs=logprob_rank, + ) + + spec_batch_logprobs = get_logprobs_from_llm_generator( + test_llm_generator, prompts, sampling_params) + baseline_batch_logprobs = get_logprobs_from_llm_generator( + baseline_llm_generator, prompts, sampling_params) + + assert len(baseline_batch_logprobs) == len(prompts) + assert len(spec_batch_logprobs) == len(prompts) + + # For each sequence in the batch. 
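+    # (Each element of *_batch_logprobs is a per-sequence list with one
+    # {token_id: Logprob} dict per generated position; Logprob carries both
+    # .logprob and .rank, which the loop below re-keys by rank so spec and
+    # baseline outputs can be compared position by position.)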
+ for i, (baseline_logprobs, spec_logprobs) in enumerate( + zip(baseline_batch_logprobs, spec_batch_logprobs)): + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # Map rank to token/logprob in spec output. + spec_rank_to_token_id = { + value.rank: key + for key, value in spec_pos_logprobs.items() + } + spec_rank_to_logprob = { + value.rank: value.logprob + for key, value in spec_pos_logprobs.items() + } + + # Map rank to token/logprob in baseline output. + baseline_rank_to_token_id = { + value.rank: key + for key, value in baseline_pos_logprobs.items() + } + baseline_rank_to_logprob = { + value.rank: value.logprob + for key, value in baseline_pos_logprobs.items() + } + + # Assert set of ranks returned is equal. + assert set(spec_rank_to_token_id.keys()) == set( + baseline_rank_to_token_id.keys()) + + # Assert each logprob/token id is correct, keyed by rank. + for rank in sorted(set(spec_rank_to_token_id.keys())): + assert spec_rank_to_token_id[ + rank] == baseline_rank_to_token_id[rank], f"{rank}" + assert math.isclose( + a=spec_rank_to_logprob[rank], + b=baseline_rank_to_logprob[rank], + abs_tol=1e-1, + ) diff --git a/tests/spec_decode/e2e/test_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py similarity index 89% rename from tests/spec_decode/e2e/test_correctness.py rename to tests/spec_decode/e2e/test_multistep_correctness.py index 0536cc4ecde76..f15fcc4746d20 100644 --- a/tests/spec_decode/e2e/test_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -35,7 +35,8 @@ from vllm import SamplingParams -from .conftest import get_output_from_llm_generator +from .conftest import (get_output_from_llm_generator, + run_greedy_equality_correctness_test) @pytest.mark.parametrize( @@ -49,7 +50,7 @@ "enforce_eager": True, # Required for spec decode. - "use_v2_block_manager": True + "use_v2_block_manager": True, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -109,6 +110,44 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, assert actual_tokens.strip() == expected_tokens.strip() +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Use a small model for a fast test. + # Note this is repeated in the test body; to initialize a tokenizer. + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Use AsyncLLM engine + "use_async": True, + }]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, + }, +]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_e2e_with_async_engine(test_llm_generator, + baseline_llm_generator, + batch_size: int): + """Verify spec decode works well with async LLM engine. 
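+
+    The async path runs through the AsyncLLM helper in conftest.py (selected
+    via the use_async kwarg above), which drives AsyncLLMEngine with
+    engine_use_ray=True.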
+ """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=32, + force_output_len=True) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -538,60 +577,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int, batch_size, max_output_len=output_len, force_output_len=True) - - -def run_greedy_equality_correctness_test(baseline_llm_generator, - test_llm_generator, - batch_size, - max_output_len, - force_output_len: bool, - print_tokens: bool = False): - """Helper method that compares the outputs of both the baseline LLM and - the test LLM. It asserts greedy equality, e.g. that the outputs are exactly - the same when temperature is zero. - """ - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - "San Francisco is know for its", - "Facebook was created in 2004 by", - "Curious George is a", - "Python 3.11 brings improvements to its", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - # If the test requires that we generated max_output_len tokens, then set the - # sampling params to ignore eos token. - ignore_eos = force_output_len - - sampling_params = SamplingParams( - max_tokens=max_output_len, - ignore_eos=ignore_eos, - temperature=temperature, - ) - - spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) - - (baseline_batch_tokens, - baseline_batch_token_ids) = get_output_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - assert len(baseline_batch_token_ids) == len(prompts) - assert len(spec_batch_token_ids) == len(prompts) - - for i, (baseline_token_ids, baseline_tokens, spec_token_ids, - spec_tokens) in enumerate( - zip(baseline_batch_token_ids, baseline_batch_tokens, - spec_batch_token_ids, spec_batch_tokens)): - if print_tokens: - print(f'{i=} {baseline_tokens=}') - print(f'{i=} {spec_tokens=}') - print(f'{i=} {baseline_token_ids=}') - print(f'{i=} {spec_token_ids=}') - assert baseline_token_ids == spec_token_ids diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py new file mode 100644 index 0000000000000..44ef400c91d34 --- /dev/null +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -0,0 +1,172 @@ +"""This docstring details important information on the testing methodology. + +Most of the tests rely on "greedy equality", where we expect the output of +speculative decoding on a sequence to exactly match the output of normal non- +speculative decoding. + +Since speculative decoding with rejection sampling guarantees that the output +distribution matches the target model's output distribution (up to hardware +numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy +equality. + +For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding, +and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775. +Since there is no model is needed for generate the proposal, we could make +the testcase much simpler than drafter multi-step one. 
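+
+Roughly, the lookup takes the trailing ngram of the running token sequence and
+searches for an earlier occurrence of it; e.g. with prompt token ids
+[11, 12, 13, 14, 15, 16, 11] and a window of 1, the trailing token 11 matches
+position 0, so the tokens that followed it, [12, 13, 14, 15, 16], are proposed.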
+ +However, we still need to verify below scenario could be passed: + * Batch size 1 greedy equality + * Batch size >1 greedy equality + * Test greedy equality under preemption + * Test greedy equality under various ngram sizes / speculative sizes + +With those tests, we can say at least, ngram spec would not break the correctess +for the target model outputs. +""" + +import pytest + +from .conftest import run_greedy_equality_correctness_test + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize("output_len", [ + 256, +]) +@pytest.mark.parametrize("batch_size", [1, 64]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness(baseline_llm_generator, + test_llm_generator, batch_size: int, + output_len: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "block_size": 8, + # 2 for small prompt, 256//8 for generated. + "num_gpu_blocks_override": 2 + 256 // 8, + "max_model_len": (2 + 256 // 8) * 8, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model": "JackFram/llama-160m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }, +]) +@pytest.mark.parametrize( + "output_len", + [ + # Use small output len for fast test. + 256, + ]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator, + test_llm_generator, + batch_size: int, + output_len: int): + """Verify greedy equality, even when some sequences are preempted mid- + generation. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize( + "test_llm_kwargs", + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 3, + } + # Try a range of common k, as well as large speculation. + for k in [1, 3, 5] + ] + [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": k, + "ngram_prompt_lookup_max": 1, + } + # Try a range of common k, as well as large speculation. 
+ for k in [1, 3, 5] + ]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_different_k(baseline_llm_generator, test_llm_generator, + batch_size: int, output_len: int): + """Verify that ngram speculative decoding produces exact equality + to without spec decode with many different values of k and + different ngram_prompt_lookup_max. + """ + run_greedy_equality_correctness_test(baseline_llm_generator, + test_llm_generator, + batch_size, + max_output_len=output_len, + force_output_len=True) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index e7aaa1ff4eff8..cb2de97a4af94 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,13 +5,12 @@ import torch from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplerOutput -from vllm.spec_decode.multi_step_worker import (DraftModelTop1Proposer, - MultiStepWorker) +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker from .utils import (assert_logprobs_dict_allclose, create_batch, - create_execute_model_data, create_seq_group_metadata_from_prompts, create_worker, patch_execute_model_with_seeds, zero_kv_cache) @@ -34,7 +33,7 @@ def test_assert_enough_kv_space(num_steps: int): list(range(block_size * 2)), ] - final_seq_lens = [ + final_prompt_lens = [ len(prompt + output) + num_steps for prompt, output in zip(prompts, prev_output_tokens) ] @@ -43,7 +42,7 @@ def test_assert_enough_kv_space(num_steps: int): prompts, num_gpu_blocks, block_size, - final_seq_lens, + final_prompt_lens, continuations=prev_output_tokens) assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access @@ -103,29 +102,34 @@ def test_same_output_for_single_step(): [6, 7, 8, 9, 10], ] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] - multi_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) - - single_step_execute_model_data = create_execute_model_data( - seq_group_metadata_list=create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, - final_seq_lens=final_seq_lens)) + multi_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - actual_output = multi_step_worker.execute_model_multi_step( - **multi_step_execute_model_data.to_dict(), num_steps=num_steps) + actual_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=multi_step_seq_group), + sample_len=num_steps) assert len(actual_output) == num_steps actual_output = actual_output[0] + single_step_seq_group = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + zero_kv_cache(worker.cache_engine) set_random_seed(seed) expected_output = worker.execute_model( - **single_step_execute_model_data.to_dict(), )[0] + 
execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=single_step_seq_group))[0] actual_token_ids = [ output.samples[0].output_token for output in actual_output @@ -181,7 +185,7 @@ def test_same_output_for_multi_step(): random.randint(0, 1000) for _ in range(random.randint(10, 20)) ] for _ in range(10)] - final_seq_lens = [len(prompt) + num_steps for prompt in prompts] + final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) multi_step_worker.execute_model = patch_execute_model_with_seeds( @@ -189,19 +193,20 @@ def test_same_output_for_multi_step(): worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) continuations = [[1] for _ in prompts] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_seq_lens=final_seq_lens), ) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) # Run multi-step. zero_kv_cache(multi_step_worker.cache_engine) set_random_seed(seed) - multi_step_output = multi_step_worker.execute_model_multi_step( - **execute_model_data.to_dict(), num_steps=num_steps) + multi_step_output, _ = multi_step_worker.sampler_output( + execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list), + sample_len=num_steps) # Run single-step repeatedly. zero_kv_cache(worker.cache_engine) @@ -211,16 +216,16 @@ def test_same_output_for_multi_step(): for _ in multi_step_output: - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_seq_lens=final_seq_lens)) + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + continuations=continuations, + final_prompt_lens=final_prompt_lens) single_step_output.extend( - worker.execute_model(**execute_model_data.to_dict(), )) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list))) # Append output tokens to new sequence data. for i, seq_group_output in enumerate(single_step_output[-1]): @@ -266,7 +271,7 @@ def test_same_output_for_multi_step(): @torch.inference_mode() def test_draft_proposals_full_speculation_len(): - """Verify DraftModelTop1Proposer correctly handles case where all sequences + """Verify Top1Proposer correctly handles case where all sequences can speculate. 
""" k = 10 @@ -275,33 +280,36 @@ def test_draft_proposals_full_speculation_len(): device = 'cuda:0' draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=2048, vocab_size=vocab_size, + max_proposal_len=2048, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(batch_size, vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(batch_size, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint(low=0, high=vocab_size, size=(batch_size, ), device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -315,7 +323,7 @@ def test_draft_proposals_full_speculation_len(): @torch.inference_mode() def test_draft_proposals_no_speculations(): - """Verify DraftModelTop1Proposer correctly handles case where no sequences + """Verify Top1Proposer correctly handles case where no sequences can speculate. """ k = 10 @@ -325,21 +333,20 @@ def test_draft_proposals_no_speculations(): prompt_len = 10 draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=prompt_len + k - 1, vocab_size=vocab_size, + max_proposal_len=prompt_len + k - 1, ) - execute_model_data, _, _ = create_batch(batch_size, - k, - prompt_len=prompt_len) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prompt_len=prompt_len) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) @@ -353,7 +360,7 @@ def test_draft_proposals_no_speculations(): @torch.inference_mode() def test_draft_proposals_mixed_k(): - """Verify DraftModelTop1Proposer correctly handles case some sequences can + """Verify Top1Proposer correctly handles case some sequences can speculate and some can't. 
""" k = 10 @@ -374,20 +381,24 @@ def test_draft_proposals_mixed_k(): for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] draft_worker = MagicMock() - proposer = DraftModelTop1Proposer( - draft_worker=draft_worker, + proposer = Top1Proposer( + worker=draft_worker, device=device, - max_model_len=long_prompt_len + prev_output_token_len + k - 1, vocab_size=vocab_size, + max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, ) - draft_worker.execute_model_multi_step.return_value = [ + draft_worker.sampler_output.return_value = [ SamplerOutput( outputs=[], sampled_token_probs=torch.rand(expected_num_proposal_seqs, vocab_size, device=device, dtype=torch.float32), + logprobs=torch.rand(expected_num_proposal_seqs, + vocab_size, + device=device, + dtype=torch.float32), sampled_token_ids=torch.randint( low=0, high=vocab_size, @@ -395,19 +406,18 @@ def test_draft_proposals_mixed_k(): device=device, dtype=torch.long), ) for _ in range(k) - ] + ], True - execute_model_data, _, _ = create_batch( + seq_group_metadata_list, _, _ = create_batch( batch_size, k, prompt_len=prompt_len, prev_output_token_len=prev_output_token_len, ) - proposals = proposer.get_proposals( - **execute_model_data.to_dict(), - max_proposal_len=k, - ) + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k), ) assert torch.is_tensor(proposals.proposal_token_ids) assert torch.is_tensor(proposals.proposal_probs) diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py new file mode 100644 index 0000000000000..de305c4030aa9 --- /dev/null +++ b/tests/spec_decode/test_ngram_worker.py @@ -0,0 +1,206 @@ +import torch + +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.top1_proposer import Top1Proposer + +from .utils import create_seq_group_metadata_from_prompts, create_worker + + +def test_ngram_algo_correctness_for_single_no_match(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario cannot find any candidate in one single batch + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([1]) + assert proposals.proposal_lens.tolist() == [0] + + +def 
test_ngram_algo_correctness_for_batches_not_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find some candidate not full in batchs + """ + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find no candidate + [1, 2, 3, 4, 5, 6, 7], + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + # shall find no candidate as exceed max_proposal_len + [ + 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37, + 38, 31, 32, 33 + ], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([5]) + + assert proposals.proposal_lens.tolist( + ) == [proposal_len for _ in range(4)] + [0] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == 0 + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] + assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] + assert proposals.proposal_token_ids[4][i] == -1 + + +def test_ngram_algo_correctness_for_batches_match_all(): + """Verify our ngram algo find the right candidate in the prompt + + For the scenario find candidate in all batchs + """ + + block_size = 32 + num_gpu_blocks = 2048 // block_size + seed = 100 + model_name = 'JackFram/llama-68m' + vocab_size = 32_000 + device = 'cuda:0' + + ngram_worker = create_worker( + NGramWorker, + model_name, + block_size, + num_gpu_blocks, + seed, + ) + + proposer = Top1Proposer( + worker=ngram_worker, + device=device, + vocab_size=vocab_size, + max_proposal_len=20, + ) + + # set ngram window (0, 3], which is window=1/2/3 + ngram_worker.set_ngram_window_size(0, 3) + + prompts = [ + # shall find candidate 12,13,14,15,16 + [11, 12, 13, 14, 15, 16, 11], + # shall find candidate 23,24,25,26,21 + [21, 21, 22, 23, 24, 25, 26, 21, 22], + # shall find candidate 34,35,36,37,38 + [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], + ] + + proposal_len = 5 + final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, + num_gpu_blocks, + block_size, + final_prompt_lens=final_prompt_lens) + + proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest( + 
seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=proposal_len), ) + + assert torch.is_tensor(proposals.proposal_token_ids) + assert torch.is_tensor(proposals.proposal_probs) + + assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len]) + assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len]) + assert proposals.proposal_lens.shape == torch.Size([3]) + + assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)] + + for i in range(proposal_len): + assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1] + assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3] + assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5] diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index d24d726c9c0cf..ef9d32f73d668 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.utils import set_random_seed -from vllm.sequence import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -15,8 +15,7 @@ from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, split_num_cache_blocks_evenly) -from .utils import (ExecuteModelData, create_batch, create_sampler_output_list, - mock_worker) +from .utils import create_batch, create_sampler_output_list, mock_worker @pytest.mark.parametrize('k', [1, 2, 6]) @@ -33,27 +32,22 @@ def test_correctly_calls_draft_model(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=execute_model_req) call_args_list = draft_worker.get_spec_proposals.call_args_list assert len(call_args_list) == 1 for args, _ in call_args_list: - (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, actual_k) = args - actual_execute_model_data = ExecuteModelData(seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy) - assert actual_execute_model_data == execute_model_data - assert actual_k == k + actual_execute_model_data = args[0] + assert actual_execute_model_data == execute_model_req @pytest.mark.parametrize('k', [1, 2, 6]) @@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, prompts, prev_output_tokens = create_batch( + seq_group_metadata_list, prompts, prev_output_tokens = create_batch( batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( @@ -101,24 +95,24 @@ def test_correctly_calls_target_model(k: int, batch_size: int): proposal_probs=proposal_probs, proposal_lens=proposal_lens) - 
exception_secret = 'artifical stop' + exception_secret = 'artificial stop' target_worker.execute_model.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) seen_contexts = [] call_args_list = target_worker.execute_model.call_args_list assert len(call_args_list) == 1 - for args, kwargs in call_args_list: - target_execute_model_data = ExecuteModelData.from_dict(kwargs) + for _, kwargs in call_args_list: + seq_group_metadata_list = kwargs[ + "execute_model_req"].seq_group_metadata_list - assert len(target_execute_model_data.seq_group_metadata_list) == ( - k + 1) * batch_size - for seq_group_metadata in ( - target_execute_model_data.seq_group_metadata_list): + assert len(seq_group_metadata_list) == (k + 1) * batch_size + for seq_group_metadata in seq_group_metadata_list: for seq_data in seq_group_metadata.seq_data.values(): seen_contexts.append(seq_data.get_token_ids()) @@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -192,17 +186,24 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] - exception_secret = 'artifical stop' + exception_secret = 'artificial stop' rejection_sampler.side_effect = ValueError(exception_secret) with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert len(rejection_sampler.call_args_list) == 1 _, kwargs = rejection_sampler.call_args_list[0] @@ -256,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -273,8 +274,14 @@ def test_correctly_formats_output(k: int, batch_size: int): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -290,15 +297,18 @@ def test_correctly_formats_output(k: int, batch_size: int): rejection_sampler.return_value = rejection_sampler_output - output = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + 
output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) expected_output = create_sampler_output_list( - rejection_sampler_output.transpose(0, 1), [None for _ in range(k + 1)]) + token_ids=rejection_sampler_output.transpose(0, 1), + probs=[None for _ in range(k + 1)], + logprobs=[None for _ in range(k + 1)]) seq_ids = [ next(iter(seq_group_metadata.seq_data.keys())) - for seq_group_metadata in execute_model_data.seq_group_metadata_list + for seq_group_metadata in seq_group_metadata_list ] actual_output_by_seq = {seq_id: [] for seq_id in seq_ids} expected_output_by_seq = {seq_id: [] for seq_id in seq_ids} @@ -328,7 +338,6 @@ def test_correctly_formats_output(k: int, batch_size: int): continue assert actual_by_step[i].output_token == expected_by_step[ i].output_token - assert actual_by_step[i].logprobs == expected_by_step[i].logprobs @pytest.mark.parametrize('k', [1, 2]) @@ -370,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): proposal_lens = torch.ones(batch_size, dtype=torch.int64, device='cuda') * k - execute_model_data, _, _ = create_batch(batch_size, k) + seq_group_metadata_list, _, _ = create_batch(batch_size, k) draft_worker.get_spec_proposals.return_value = SpeculativeProposals( proposal_token_ids=proposal_token_ids, @@ -387,8 +396,14 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): vocab_size, dtype=torch.float32, device='cuda') + target_token_logprobs = torch.rand(1, + batch_size * (k + 1), + vocab_size, + dtype=torch.float32, + device='cuda') target_output = create_sampler_output_list(target_token_ids, - target_token_probs) + target_token_probs, + target_token_logprobs) target_worker.execute_model.return_value = [target_output[0]] @@ -409,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool): metrics_collector.maybe_collect_rejsample_metrics.return_value = ( mock_rejsample_metrics) - output = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + output = worker.execute_model(execute_model_req=ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + num_lookahead_slots=k)) assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics call_args_list = ( @@ -443,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.parametrize('k', [0, 5]) @@ 
-484,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int): worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler, metrics_collector) - execute_model_data, prompts, prev_output_tokens = create_batch( - batch_size, k, prev_output_token_len=0) + seq_group_metadata_list, _, _ = create_batch(batch_size, + k, + prev_output_token_len=0) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - out = worker.execute_model(**execute_model_data.to_dict(), - num_lookahead_slots=k) + out = worker.execute_model(execute_model_req=execute_model_req) assert len(out) == 1, f"expected only one token output when {k=}" assert out[0].probs is None, "expect gpu tensor references to be None" assert out[ 0].sampled_tokens is None, "expect gpu tensor references to be None" - draft_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) - target_worker.execute_model.assert_called_once_with( - **execute_model_data.to_dict()) + draft_worker.execute_model.assert_called_once_with(execute_model_req) + target_worker.execute_model.assert_called_once_with(execute_model_req) @pytest.mark.skip_global_cleanup diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 4f8295d25cf41..f288652d51556 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, fields from itertools import count from typing import Dict, Iterable, List, Optional, Union from unittest.mock import MagicMock @@ -16,50 +15,10 @@ from vllm.worker.worker import Worker -@dataclass -class ExecuteModelData: - """Helper data structure which facilitates cleaner tests. - """ - seq_group_metadata_list: List[SequenceGroupMetadata] - blocks_to_swap_in: Dict[int, int] - blocks_to_swap_out: Dict[int, int] - blocks_to_copy: Dict[int, List[int]] - - def to_dict(self): - return dict( - (field.name, getattr(self, field.name)) for field in fields(self)) - - @classmethod - def from_dict(cls, d): - cleaned = dict((field.name, d[field.name]) for field in fields(cls)) - return cls(**cleaned) - - def round_up_to_next_block(seq_len: int, block_size: int) -> int: return (seq_len + block_size - 1) // block_size -def create_execute_model_data( - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, int]] = None, -) -> ExecuteModelData: - if blocks_to_swap_in is None: - blocks_to_swap_in = {} - if blocks_to_swap_out is None: - blocks_to_swap_out = {} - if blocks_to_copy is None: - blocks_to_copy = {} - - return ExecuteModelData( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - def mock_worker(cls=None, vocab_size: int = 30_000, max_model_len: int = 2048, @@ -144,7 +103,7 @@ def create_seq_group_metadata_from_prompts( prompts: List[List[int]], num_gpu_blocks: int, block_size: int, - final_seq_lens: List[int], + final_prompt_lens: List[int], continuations: Optional[List[List[int]]] = None, seq_ids: Optional[List[int]] = None, ) -> List[SequenceGroupMetadata]: @@ -162,7 +121,7 @@ def create_seq_group_metadata_from_prompts( free_gpu_blocks.pop() for _ in range(round_up_to_next_block(final_len, block_size)) ] - for i, final_len in enumerate(final_seq_lens) + for i, final_len in enumerate(final_prompt_lens) } return [ @@ -201,6 
+160,7 @@ def assert_logprobs_dict_allclose( def create_sampler_output_list( token_ids: torch.Tensor, probs: Iterable[Optional[torch.Tensor]], + logprobs: Iterable[Optional[torch.Tensor]], seq_ids: Optional[List[int]] = None) -> List[SamplerOutput]: num_steps, batch_size = token_ids.shape token_ids_by_step = token_ids.tolist() @@ -222,6 +182,7 @@ def create_sampler_output_list( ) for seq_index, token_id in enumerate(token_ids_by_step[step]) ], sampled_token_probs=probs[step], + logprobs=logprobs[step], sampled_token_ids=token_ids[step]) for step in range(num_steps) ] @@ -251,13 +212,12 @@ def create_batch(batch_size, prev_output_tokens = [[ next(iterator) for _ in range(prev_output_token_len) ] for _ in range(batch_size)] - final_seq_lens = [ + final_prompt_lens = [ len(prompt) + len(prev_output_token) + k + 1 for prompt, prev_output_token in zip(prompts, prev_output_tokens) ] - execute_model_data = create_execute_model_data( - create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks, - block_size, final_seq_lens, - prev_output_tokens, seq_ids), ) - return execute_model_data, prompts, prev_output_tokens + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, final_prompt_lens, + prev_output_tokens, seq_ids) + return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/test_logger.py b/tests/test_logger.py index 601f72b50811c..74f1125fb37c9 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,8 +1,19 @@ +import json +import logging import os import sys import tempfile +from json.decoder import JSONDecodeError +from tempfile import NamedTemporaryFile +from typing import Any +from unittest.mock import patch +from uuid import uuid4 -from vllm.logger import enable_trace_function_call +import pytest + +from vllm.logger import (_DATE_FORMAT, _FORMAT, _configure_vllm_root_logger, + enable_trace_function_call, init_logger) +from vllm.logging import NewLineFormatter def f1(x): @@ -25,3 +36,179 @@ def test_trace_function_call(): assert "f2" in content sys.settrace(None) os.remove(path) + + +def test_default_vllm_root_logger_configuration(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + logger = logging.getLogger("vllm") + assert logger.level == logging.DEBUG + assert not logger.propagate + + handler = logger.handlers[0] + assert handler.stream == sys.stdout + assert handler.level == logging.INFO + + formatter = handler.formatter + assert formatter is not None + assert isinstance(formatter, NewLineFormatter) + assert formatter._fmt == _FORMAT + assert formatter.datefmt == _DATE_FORMAT + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_descendent_loggers_depend_on_and_propagate_logs_to_root_logger(): + """This test presumes that VLLM_CONFIGURE_LOGGING (default: True) and + VLLM_LOGGING_CONFIG_PATH (default: None) are not configured and default + behavior is activated.""" + root_logger = logging.getLogger("vllm") + root_handler = root_logger.handlers[0] + + unique_name = f"vllm.{uuid4()}" + logger = init_logger(unique_name) + assert logger.name == unique_name + assert logger.level == logging.NOTSET + assert not logger.handlers + assert logger.propagate + + message = "Hello, world!" 
+ with patch.object(root_handler, "emit") as root_handle_mock: + logger.info(message) + + root_handle_mock.assert_called_once() + _, call_args, _ = root_handle_mock.mock_calls[0] + log_record = call_args[0] + assert unique_name == log_record.name + assert message == log_record.msg + assert message == log_record.msg + assert log_record.levelno == logging.INFO + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +@patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", None) +def test_logger_configuring_can_be_disabled(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + + with patch("logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + dict_config_mock.assert_not_called() + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@patch( + "vllm.logger.VLLM_LOGGING_CONFIG_PATH", + "/if/there/is/a/file/here/then/you/did/this/to/yourself.json", +) +def test_an_error_is_raised_when_custom_logging_config_file_does_not_exist(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == RuntimeError + assert "File does not exist" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_an_error_is_raised_when_custom_logging_config_is_invalid_json(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write("---\nloggers: []\nversion: 1") + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(JSONDecodeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == JSONDecodeError + assert "Expecting value" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +@pytest.mark.parametrize("unexpected_config", ( + "Invalid string", + [{ + "version": 1, + "loggers": [] + }], + 0, +)) +def test_an_error_is_raised_when_custom_logging_config_is_unexpected_json( + unexpected_config: Any): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however it fails before any change in behavior or + configuration occurs.""" + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(unexpected_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(ValueError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type == ValueError + assert "Invalid logging config. 
Expected Dict, got" in str(ex_info) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 1) +def test_custom_logging_config_is_parsed_and_used_when_provided(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + "propagate": False, + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name), patch( + "logging.config.dictConfig") as dict_config_mock: + _configure_vllm_root_logger() + assert dict_config_mock.called_with(valid_logging_config) + + +@patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) +def test_custom_logging_config_causes_an_error_if_configure_logging_is_off(): + """This test calls _configure_vllm_root_logger again to test custom logging + config behavior, however mocks are used to ensure no changes in behavior or + configuration occur.""" + valid_logging_config = { + "loggers": { + "vllm.test_logger.logger": { + "handlers": [], + } + }, + "version": 1 + } + with NamedTemporaryFile(encoding="utf-8", mode="w") as logging_config_file: + logging_config_file.write(json.dumps(valid_logging_config)) + logging_config_file.flush() + with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", + logging_config_file.name): + with pytest.raises(RuntimeError) as ex_info: + _configure_vllm_root_logger() + assert ex_info.type is RuntimeError + expected_message_snippet = ( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given.") + assert expected_message_snippet in str(ex_info) + + # Remember! The root logger is assumed to have been configured as + # though VLLM_CONFIGURE_LOGGING=1 and VLLM_LOGGING_CONFIG_PATH=None. 
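The config-file tests above all serialize a dictConfig-style document to a temporary JSON file and point VLLM_LOGGING_CONFIG_PATH at it. As a stdlib-only sketch of the shape of such a file (the logger name here is illustrative, not taken from vLLM):

```python
# Illustrative dictConfig-compatible logging config, similar to what the tests
# above write out for VLLM_LOGGING_CONFIG_PATH. "my.app" is a made-up name.
import json
import logging.config
from tempfile import NamedTemporaryFile

config = {
    "version": 1,
    "loggers": {
        "my.app": {
            "handlers": [],
            "propagate": False,
        },
    },
}

# Write it the way the tests do, then validate it with the stdlib directly.
with NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)
    print("would set VLLM_LOGGING_CONFIG_PATH to", f.name)

logging.config.dictConfig(config)
```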
+ root_logger = logging.getLogger("vllm") + other_logger_name = f"vllm.test_logger.{uuid4()}" + other_logger = init_logger(other_logger_name) + assert other_logger.handlers != root_logger.handlers + assert other_logger.level != root_logger.level + assert other_logger.propagate diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py index dbaeb4de18258..179e8d25a341b 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -70,7 +70,7 @@ def pick_ith(token_ids, logits): return logits seq_group_metadata_list = [] - prompt_lens = [] + seq_lens = [] for i in range(batch_size): seq_group_metadata_list.append( SequenceGroupMetadata( @@ -81,12 +81,12 @@ def pick_ith(token_ids, logits): logits_processors=[pick_ith]), block_tables={0: [1]}, )) - prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) + seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) logits_processor_output = logits_processor( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 56fe6db589f18..e7975d0ef48b9 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,14 +23,14 @@ def test_prepare_prompt(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] block_tables = {0: [1]} for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -43,29 +43,29 @@ def test_prepare_prompt(batch_size): expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += prompt_len - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) - assert return_prompt_lens == prompt_lens + seq_len - 1) + selected_token_start_idx += seq_len + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + assert return_seq_lens == seq_lens assert len(slot_mapping) == len(input_tokens) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is True - assert torch.allclose(attn_metadata.prompt_lens_tensor, - torch.tensor(prompt_lens, device=device)) - assert attn_metadata.prompt_lens == prompt_lens - assert attn_metadata.max_prompt_len == max(prompt_lens) + assert torch.allclose( + attn_metadata.seq_lens_tensor, + torch.tensor(seq_lens, device=device, dtype=torch.int)) + assert attn_metadata.seq_lens == seq_lens + assert attn_metadata.max_seq_len == max(seq_lens) # Test subquery start locs. 
start_idx = 0 start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len start_loc.append(start_idx) assert torch.allclose( attn_metadata.subquery_start_loc, @@ -75,17 +75,16 @@ def test_prepare_prompt(batch_size): # equivalent to subquery_start_loc. start_idx = 0 seq_start_loc = [start_idx] - for prompt_len in prompt_lens: - start_idx += prompt_len + for seq_len in seq_lens: + start_idx += seq_len seq_start_loc.append(start_idx) assert torch.allclose( attn_metadata.seq_start_loc, torch.tensor(start_loc, dtype=torch.int32, device=device)) - assert attn_metadata.max_context_len is None assert torch.allclose( - attn_metadata.context_lens, - torch.zeros(attn_metadata.context_lens.shape[0], + attn_metadata.context_lens_tensor, + torch.zeros(attn_metadata.context_lens_tensor.shape[0], dtype=torch.int, device=device)) @@ -96,18 +95,18 @@ def test_prepare_prompt(batch_size): # Cuda graph should not be used for prerill. assert attn_metadata.use_cuda_graph is False - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) torch.testing.assert_close(input_tokens, input_positions) sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) - assert len(input_tokens) == sum(prompt_lens) - assert len(input_positions) == sum(prompt_lens) + assert len(input_tokens) == sum(seq_lens) + assert len(input_positions) == sum(seq_lens) actual = sampling_metadata.selected_token_indices expected = torch.tensor(expected_selected_token_indices, device=actual.device, @@ -146,13 +145,13 @@ def test_prepare_decode_cuda_graph(batch_size): lora_config=None) model_runner.set_block_size(16) - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] for i in range(batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = list(range(seq_len)) seq_data = SequenceData(seq_data) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -172,14 +171,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.is_prompt is False - assert attn_metadata.prompt_lens is None - assert attn_metadata.max_prompt_len is None + assert attn_metadata.seq_lens is None assert attn_metadata.subquery_start_loc is None assert attn_metadata.seq_start_loc is None - assert attn_metadata.max_context_len == max(prompt_lens) + assert attn_metadata.max_seq_len == max(seq_lens) assert torch.allclose( - attn_metadata.context_lens[:len(prompt_lens)], - torch.tensor(prompt_lens, dtype=torch.int, device=device)) + attn_metadata.seq_lens_tensor[:len(seq_lens)], + torch.tensor(seq_lens, dtype=torch.int, device=device)) # block table's first index corresponds to each batch, meaning in # decoding it is each token. 
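The start-location checks above reduce to building exclusive prefix sums over seq_lens; a standalone sketch of that relationship with made-up lengths:

```python
# The subquery/seq start locations verified above are exclusive prefix sums of
# the per-sequence lengths, e.g. seq_lens [4, 6] -> [0, 4, 10].
seq_lens = [4, 6, 3]  # made-up batch

start_loc = [0]
for seq_len in seq_lens:
    start_loc.append(start_loc[-1] + seq_len)

print(start_loc)  # [0, 4, 10, 13]
```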
@@ -198,13 +196,13 @@ def test_prepare_decode_cuda_graph(batch_size): # Verify Sampling expected_selected_token_indices = [] selected_token_start_idx = 0 - for prompt_len in prompt_lens: + for seq_len in seq_lens: expected_selected_token_indices.append(selected_token_start_idx) selected_token_start_idx += 1 sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - subquery_lens=prompt_lens, + seq_lens, + query_lens=seq_lens, device=model_runner.device, pin_memory=model_runner.pin_memory) actual = sampling_metadata.selected_token_indices @@ -241,14 +239,13 @@ def test_empty_seq_group(): assert attn_metadata is None assert len(slot_mapping) == 0 - (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, - _, _, - slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) + (input_tokens, input_positions, attn_metadata, return_seq_lens, _, _, _, _, + _, slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list)) assert len(input_tokens) == 0 assert len(input_positions) == 0 assert attn_metadata is None assert len(slot_mapping) == 0 - assert len(return_prompt_lens) == 0 + assert len(return_seq_lens) == 0 @pytest.fixture @@ -288,7 +285,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_runner.set_block_size(16) # Add prefill requests. - prompt_lens = [] + seq_lens = [] seq_group_metadata_list = [] prefill_metadata_list = [] decode_metadata_list = [] @@ -297,9 +294,9 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_batch_size = batch_size - prefill_batch_size for i in range(prefill_batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_lens.append(prompt_len) - seq_data = SequenceData(list(range(prompt_len))) + seq_len = i % (model_runner.block_size - 1) + 1 + seq_lens.append(seq_len) + seq_data = SequenceData(list(range(seq_len))) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -314,8 +311,8 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): # Add decode requests for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block - prompt_len = i % (model_runner.block_size - 1) + 1 - prompt_toks = list(range(prompt_len)) + seq_len = i % (model_runner.block_size - 1) + 1 + prompt_toks = list(range(seq_len)) seq_data = SequenceData(prompt_toks) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -343,7 +340,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): else: assert attn_metadata.num_decode_tokens == _get_graph_batch_size( decode_batch_size) - assert attn_metadata.num_prefill_tokens == sum(prompt_lens) + assert attn_metadata.num_prefill_tokens == sum(seq_lens) # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py index 1804cf78d8003..4d2d3add27d59 100644 --- a/tests/worker/test_swap.py +++ b/tests/worker/test_swap.py @@ -1,6 +1,7 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.sequence import ExecuteModelRequest from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.worker import Worker @@ -54,10 +55,14 @@ def test_swap() -> None: # Test swap out. 
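The worker tests in this diff, including the swap test that continues below, now funnel their inputs through the single ExecuteModelRequest object rather than separate keyword arguments. A rough sketch of that call pattern, assuming vLLM 0.4.2+ is installed and using a mock in place of a real worker; the fields shown are the ones appearing in the diff and may not be exhaustive:

```python
# Sketch (not part of the PR): driving a worker through the consolidated
# ExecuteModelRequest instead of separate seq_group_metadata_list /
# blocks_to_* keyword arguments. The worker here is a stand-in mock.
from unittest.mock import MagicMock

from vllm.sequence import ExecuteModelRequest

worker = MagicMock()

execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=[],   # normally built via create_batch(...)
    blocks_to_swap_out={3: 72},   # GPU block 3 -> CPU block 72
    num_lookahead_slots=4,        # k lookahead slots for speculative decoding
)
worker.execute_model(execute_model_req=execute_model_req)

# The updated assertions simply compare this single object:
worker.execute_model.assert_called_once_with(
    execute_model_req=execute_model_req)
```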
blocks_to_swap_out = {3: 72, 56: 35, 84: 34} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in={}, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy={}) + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=[], + blocks_to_swap_in={}, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=[], + ) + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] @@ -66,14 +71,19 @@ def test_swap() -> None: assert allclose(gpu_value_cache[src], cpu_value_cache[dst]) # Test swap in. - blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71} - worker.execute_model(seq_group_metadata_list=[], - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out={}, - blocks_to_copy={}) + execute_model_req.blocks_to_swap_out = {} + execute_model_req.blocks_to_swap_in = { + 19: 45, + 67: 23, + 12: 78, + 40: 99, + 1: 71 + } + worker.execute_model(execute_model_req=execute_model_req) + for i in range(num_layers): gpu_key_cache, gpu_value_cache = gpu_cache[i] cpu_key_cache, cpu_value_cache = cpu_cache[i] - for src, dst in blocks_to_swap_in.items(): + for src, dst in execute_model_req.blocks_to_swap_in.items(): assert allclose(gpu_key_cache[dst], cpu_key_cache[src]) assert allclose(gpu_value_cache[dst], cpu_value_cache[src]) diff --git a/vllm/__init__.py b/vllm/__init__.py index ca454efd44b24..59810da3ca411 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.4.1" +__version__ = "0.4.2" __all__ = [ "LLM", diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 4af8b09b1e16c..5b56437487477 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -39,17 +39,17 @@ def paged_attention_v1( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v1(out, query, key_cache, value_cache, - num_kv_heads, scale, block_tables, - context_lens, block_size, max_context_len, - alibi_slopes, kv_cache_dtype, kv_scale) + num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, kv_scale) def paged_attention_v2( @@ -63,17 +63,17 @@ def paged_attention_v2( num_kv_heads: int, scale: float, block_tables: torch.Tensor, - context_lens: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, - max_context_len: int, + max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, kv_scale: float, ) -> None: vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, context_lens, block_size, - max_context_len, alibi_slopes, kv_cache_dtype, + block_tables, seq_lens, block_size, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale) @@ -169,18 +169,20 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, - size_k: int, size_n: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def 
gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, g_idx: torch.Tensor, - perm: torch.Tensor, workspace: torch.Tensor, size_m: int, - size_n: int, size_k: int, + perm: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, size_m, size_n, size_k, - is_k_full) + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 @@ -220,6 +222,18 @@ def reshape_and_cache( slot_mapping, kv_cache_dtype, kv_scale) +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, +) -> None: + vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, + slot_mapping, kv_cache_dtype) + + def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index be747c9900368..b2b6e7ac810e3 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, fields -from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import (Any, Dict, Generic, List, Optional, Set, Tuple, Type, + TypeVar) import torch @@ -15,7 +16,7 @@ def get_impl_cls() -> Type["AttentionImpl"]: @staticmethod @abstractmethod - def make_metadata(*args, **kwargs) -> "AttentionMetadata": + def make_metadata(*args, **kwargs) -> "AttentionMetadataPerStage": raise NotImplementedError @staticmethod @@ -41,7 +42,7 @@ def swap_blocks( @abstractmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: raise NotImplementedError @@ -50,13 +51,17 @@ def copy_blocks( class AttentionMetadataPerStage: """Attention metadata for a specific stage. I.e., prefill or decode.""" - def asdict_zerocopy(self) -> Dict[str, Any]: + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() # Note that if we add dataclasses as fields, they will need # similar handling. return { field.name: getattr(self, field.name) - for field in fields(self) + for field in fields(self) if field.name not in skip_fields } diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 12e8c4404b94e..da672d5df6161 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -48,7 +48,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -66,27 +66,24 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. 
Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -95,6 +92,9 @@ class FlashAttentionMetadata(AttentionMetadataPerStage, # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -223,8 +223,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -245,10 +245,11 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, + self.sliding_window[0], ) if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
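The renamed metadata fields follow the relationship drawn in the comment diagram above; restated as plain arithmetic with invented numbers:

```python
# context_len: tokens already computed and sitting in the KV cache.
# query_len:   new tokens processed in this iteration (1 for plain decode).
# seq_len:     context_len + query_len, i.e. everything seen so far.
context_len = 12
query_len = 4          # a chunked-prefill style step; decode would use 1
seq_len = context_len + query_len

assert seq_len == 16
assert query_len == seq_len - context_len
```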
@@ -257,8 +258,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000000000..2851cbe2396b2 --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,220 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +try: + import flashinfer + from flash_attn import flash_attn_varlen_func + from flashinfer import BatchDecodeWithPagedKVCacheWrapper +except ImportError: + flashinfer = None + flash_attn_varlen_func = None + BatchDecodeWithPagedKVCacheWrapper = None + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataPerStage) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def make_metadata(*args, **kwargs) -> "FlashInferMetadata": + return FlashInferMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + +@dataclass +class FlashInferMetadata(AttentionMetadataPerStage): + + is_prompt: bool + + use_cuda_graph: bool = False + + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage since we still + # use flash attention for prefill. + seq_start_loc: Optional[torch.Tensor] = None + max_seq_len: Optional[int] = None + block_tables: Optional[torch.Tensor] = None + + # Metadata for the decode stage + # Workspace buffer required by the kernel, the buffer should not + # be allocated/deallocated by the FlashInferMetadata object. 
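For orientation, get_kv_cache_shape above stacks the key and value caches along dim 1; a quick sketch of that layout with made-up sizes:

```python
# Sketch of the FlashInfer KV-cache layout returned by get_kv_cache_shape:
# (num_blocks, 2, block_size, num_kv_heads, head_size), with key at index 0
# and value at index 1 on dim 1 (as used by reshape_and_cache_flash below).
import torch

num_blocks, block_size, num_kv_heads, head_size = 64, 16, 8, 128
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

key_cache, value_cache = kv_cache[:, 0], kv_cache[:, 1]
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)
assert value_cache.shape == key_cache.shape
```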
+ workspace_buffer: Optional[torch.Tensor] = None + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + # When using flashinfer, we are also creating the FlashInferMetadata, + # which will also call post_init by default, here we want to skip the + # post_init if it's the prefill phase. + if not self.is_prompt: + self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD") + self.decode_wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + data_type=self.data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. 
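The paged_kv_indices/paged_kv_indptr comment above maps per-request page lists onto flat tensors; a small sketch reproducing that example:

```python
# Rebuilding the paged-KV indexing tensors from the per-request page lists in
# the comment above: requests own pages [0, 5, 8], [1, 6, 7], and [3, 4].
import torch

pages_per_request = [[0, 5, 8], [1, 6, 7], [3, 4]]

paged_kv_indices = torch.tensor(
    [page for pages in pages_per_request for page in pages])
paged_kv_indptr = torch.tensor(
    [0] + [len(pages) for pages in pages_per_request]).cumsum(dim=0)

print(paged_kv_indices.tolist())  # [0, 5, 8, 1, 6, 7, 3, 4]
print(paged_kv_indptr.tolist())   # [0, 3, 6, 8]
```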
+ skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.alibi_slopes = alibi_slopes + self.scale = scale + self.num_heads = num_heads + self.head_size = head_size + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + def forward(self, query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata[FlashInferMetadata], + kv_scale: float): + num_tokens, hidden_size = query.shape + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if attn_metadata.num_prefill_tokens > 0: + assert attn_metadata.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + if attn_metadata.num_decode_tokens > 0: + assert attn_metadata.num_prefill_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + + if kv_cache is not None: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + attn_metadata.kv_cache_dtype, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + assert prefill_meta.block_tables is not None + if kv_cache is None or prefill_meta.block_tables.numel() == 0: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + else: + raise NotImplementedError( + "Prefix caching is not supported with flashinfer yet.") + else: + assert attn_metadata.decode_metadata is not None + assert attn_metadata.decode_metadata.decode_wrapper is not None + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + output = attn_metadata.decode_metadata.decode_wrapper.forward( + query, + kv_cache, + sm_scale=self.scale, + ) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 934acea0a3d60..c3b522e63b4b8 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,10 +1,10 @@ """Attention layer ROCm GPUs.""" -import os from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionMetadataPerStage) @@ -46,7 +46,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -64,27 +64,24 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Currently, input sequences can only contain all prompts # or all decoding. 
True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. + # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -98,6 +95,9 @@ class ROCmFlashAttentionMetadata(AttentionMetadataPerStage, # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] class ROCmFlashAttentionImpl(AttentionImpl): @@ -156,8 +156,7 @@ def __init__( self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. - self.use_triton_flash_attn = (os.environ.get( - "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN if self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) @@ -248,7 +247,7 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
- assert prefill_meta.prompt_lens is not None + assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -261,8 +260,8 @@ def forward( None, prefill_meta.seq_start_loc, prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, + prefill_meta.max_seq_len, + prefill_meta.max_seq_len, True, self.scale, ) @@ -275,7 +274,7 @@ def forward( query, key, value, - prefill_meta.prompt_lens, + prefill_meta.seq_lens, self.scale, ) else: @@ -285,8 +284,8 @@ def forward( v=value, cu_seqlens_q=prefill_meta.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prompt_len, - max_seqlen_k=prefill_meta.max_prompt_len, + max_seqlen_q=prefill_meta.max_seq_len, + max_seqlen_k=prefill_meta.max_seq_len, softmax_scale=self.scale, causal=True, ) @@ -304,10 +303,11 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, + self.sliding_window[0], ) if decode_meta := attn_metadata.decode_metadata: @@ -317,8 +317,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -334,13 +334,13 @@ def _naive_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - prompt_lens: List[int], + seq_lens: List[int], scale: float, ) -> torch.Tensor: output = torch.empty_like(query) start = 0 - for _, prompt_len in enumerate(prompt_lens): - end = start + prompt_len + for _, seq_len in enumerate(seq_lens): + end = start + seq_len out = _naive_masked_attention( query[start:end], key[start:end], @@ -349,7 +349,7 @@ def _naive_attention( ) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out) - start += prompt_len + start += seq_len return output diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 55a7ce59ac6e0..03825f6023f4c 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -44,7 +44,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -58,7 +58,7 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata, # or all decoding. True if all sequences are prompts. is_prompt: bool slot_mapping: torch.Tensor - prompt_lens: Optional[List[int]] + seq_lens: Optional[List[int]] def __post_init__(self): # Set during the execution of the first attention op. 
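The naive ROCm fallback above and the SDPA/xFormers prefill loops below all walk packed tokens with the same start/end bookkeeping over seq_lens; a standalone sketch of that slicing:

```python
# Packed prefill layout: tokens of all sequences are concatenated on dim 0 and
# recovered by walking seq_lens, as the renamed per-sequence loops do.
import torch

seq_lens = [3, 5, 2]                    # made-up batch
packed = torch.arange(sum(seq_lens))    # stand-in for packed query/key/value

chunks = []
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    chunks.append(packed[start:end])
    start = end

assert [chunk.numel() for chunk in chunks] == seq_lens
```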
@@ -136,7 +136,7 @@ def forward( kv_scale) if attn_metadata.is_prompt: - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None if (kv_cache is None or attn_metadata.block_tables.numel() == 0): if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -147,13 +147,13 @@ def forward( if self.alibi_slopes is not None: att_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.prompt_lens) # type: ignore + attn_metadata.seq_lens) # type: ignore elif self.sliding_window is not None: att_masks = _make_sliding_window_bias( - attn_metadata.prompt_lens, self.sliding_window, + attn_metadata.seq_lens, self.sliding_window, query.dtype) # type: ignore else: - att_masks = [None] * len(attn_metadata.prompt_lens) + att_masks = [None] * len(attn_metadata.seq_lens) attn_metadata.attn_bias = att_masks query = query.movedim(0, query.dim() - 2) @@ -164,9 +164,9 @@ def forward( output = torch.empty( (num_tokens, self.num_heads, self.head_size), dtype=query.dtype) - for prompt_len, mask in zip(attn_metadata.prompt_lens, - attn_metadata.attn_bias): - end = start + prompt_len + for seq_len, mask in zip(attn_metadata.seq_lens, + attn_metadata.attn_bias): + end = start + seq_len sub_out = scaled_dot_product_attention( query[:, start:end, :], key[:, start:end, :], @@ -189,8 +189,8 @@ def forward( key_cache, value_cache, attn_metadata.block_tables, - attn_metadata.context_lens, - attn_metadata.max_context_len, + attn_metadata.seq_lens_tensor, + attn_metadata.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -205,13 +205,13 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -221,7 +221,7 @@ def _make_alibi_bias( bias = bias[None, :].repeat((num_heads, 1, 1)) bias.mul_(alibi_slopes[:, None, None]) inf_mask = torch.empty( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) attn_biases.append((bias + inf_mask).to(dtype)) @@ -229,14 +229,14 @@ def _make_alibi_bias( def _make_sliding_window_bias( - prompt_lens: List[int], + seq_lens: List[int], window_size: Optional[int], dtype: torch.dtype, ) -> List[torch.Tensor]: attn_biases = [] - for prompt_len in prompt_lens: + for seq_len in seq_lens: tensor = torch.full( - (1, prompt_len, prompt_len), + (1, seq_len, seq_len), dtype=dtype, fill_value=1, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 572a4dc79a719..4c7fa71a2c78e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -49,7 +49,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: PagedAttention.copy_blocks(kv_caches, src_to_dists) @@ -66,28 +66,24 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # Currently, input sequences can only contain all prompts # or all decoding. 
True if all sequences are prompts. is_prompt: bool - # (batch_size,). The prompt length per sequence. None if it is a decoding. - prompt_lens: Optional[List[int]] - # prompt_lens stored as a tensor. - prompt_lens_tensor: Optional[torch.Tensor] + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] - # NOTE(sang): Definition of context_len, subquery_len, and seqlen. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| - # |-------------------- seqlen ----------------------| - # |- subquery_len -| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| - # WARNING(sang): context_len has different definition depending on if it is - # prefill vs decoding. When it is prefill, it doesn't include new tokens. - # When it is for decoding, it includes a new token. - - # Maximum subquery length in the batch. - max_subquery_len: Optional[int] + # Maximum query length in the batch. + max_query_len: Optional[int] # FIXME: It is for flash attn. - # Maximum prompt length in the batch. - max_prompt_len: Optional[int] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. @@ -97,6 +93,9 @@ class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata): # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] # Whether or not if cuda graph is enabled. # Cuda-graph is currently enabled for decoding only. @@ -242,10 +241,11 @@ def forward( value_cache, prefill_meta.block_tables, prefill_meta.subquery_start_loc, - prefill_meta.prompt_lens_tensor, - prefill_meta.context_lens, - prefill_meta.max_subquery_len, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, self.alibi_slopes, + self.sliding_window, ) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out @@ -256,8 +256,8 @@ def forward( key_cache, value_cache, decode_meta.block_tables, - decode_meta.context_lens, - decode_meta.max_context_len, + decode_meta.seq_lens_tensor, + decode_meta.max_seq_len, attn_metadata.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -288,7 +288,7 @@ def _run_memory_efficient_xformers_forward( value: shape = [num_prefill_tokens, num_kv_heads, head_size] attn_metadata: Metadata for attention. """ - assert attn_metadata.prompt_lens is not None + assert attn_metadata.seq_lens is not None original_query = query if self.num_kv_heads != self.num_heads: # GQA/MQA requires the shape [B, M, G, H, K]. 
@@ -309,7 +309,7 @@ def _run_memory_efficient_xformers_forward( if attn_metadata.attn_bias is None: if self.alibi_slopes is None: attn_bias = BlockDiagonalCausalMask.from_seqlens( - attn_metadata.prompt_lens) + attn_metadata.seq_lens) if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) @@ -317,7 +317,7 @@ def _run_memory_efficient_xformers_forward( else: attn_metadata.attn_bias = _make_alibi_bias( self.alibi_slopes, self.num_kv_heads, query.dtype, - attn_metadata.prompt_lens) + attn_metadata.seq_lens) # No alibi slopes. # TODO(woosuk): Too many view operations. Let's try to reduce @@ -342,8 +342,8 @@ def _run_memory_efficient_xformers_forward( # one. This is inefficient, especially when we have many short prompts. output = torch.empty_like(original_query) start = 0 - for i, prompt_len in enumerate(attn_metadata.prompt_lens): - end = start + prompt_len + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len out = xops.memory_efficient_attention_forward( query[None, start:end], key[None, start:end], @@ -353,7 +353,7 @@ def _run_memory_efficient_xformers_forward( scale=self.scale) # TODO(woosuk): Unnecessary copy. Optimize. output[start:end].copy_(out.view_as(original_query[start:end])) - start += prompt_len + start += seq_len return output @@ -361,13 +361,13 @@ def _make_alibi_bias( alibi_slopes: torch.Tensor, num_kv_heads: int, dtype: torch.dtype, - prompt_lens: List[int], + seq_lens: List[int], ) -> LowerTriangularMaskWithTensorBias: attn_biases = [] - for prompt_len in prompt_lens: - bias = torch.arange(prompt_len, dtype=dtype) + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(prompt_len, 1)` + # `bias = bias[None, :].repeat(seq_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. @@ -375,16 +375,16 @@ def _make_alibi_bias( # element. bias = bias[None, :] - bias[:, None] - padded_len = (prompt_len + 7) // 8 * 8 + padded_len = (seq_len + 7) // 8 * 8 num_heads = alibi_slopes.shape[0] bias = torch.empty( 1, # batch size num_heads, - prompt_len, + seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, - )[:, :, :, :prompt_len].copy_(bias) + )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) if num_heads != num_kv_heads: bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index fc65ae108dbb1..ee7be26c0876c 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -47,3 +47,10 @@ def forward( ) -> torch.Tensor: return self.impl.forward(query, key, value, kv_cache, attn_metadata, kv_scale) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + return s diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index cd0690a4ba957..6f7fd51c774f8 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -13,12 +13,11 @@ @dataclass class PagedAttentionMetadata: """Metadata for PagedAttention.""" - # (batch_size,). The length of context (tokens stored in KV cache) per - # sequence. WARNING: When it is a prefill request, it doesn't include new - # tokens. 
When it is for decoding, it includes a new token. - context_lens: Optional[torch.Tensor] - # Maximum context length in the batch. - max_context_len: Optional[int] + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. + max_seq_len: Optional[int] # (batch_size, max_blocks_per_seq). # Block addresses per sequence. (Seq id -> list of physical block) # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks @@ -85,8 +84,8 @@ def forward_decode( key_cache: torch.Tensor, value_cache: torch.Tensor, block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, + seq_lens: torch.Tensor, + max_seq_len: int, kv_cache_dtype: str, num_kv_heads: int, scale: float, @@ -97,7 +96,7 @@ def forward_decode( block_size = value_cache.shape[3] num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) // + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE) # NOTE(woosuk): We use a simple heuristic to decide whether to use # PagedAttention V1 or V2. If the number of partitions is 1, we use @@ -106,7 +105,7 @@ def forward_decode( # to parallelize. # TODO(woosuk): Tune this heuristic. # For context len > 8192, use V2 kernel to avoid shared memory shortage. - use_v1 = (max_context_len <= 8192 + use_v1 = (max_seq_len <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)) if use_v1: # Run PagedAttention V1. @@ -118,9 +117,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -150,9 +149,9 @@ def forward_decode( num_kv_heads, scale, block_tables, - context_lens, + seq_lens, block_size, - max_context_len, + max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, @@ -168,10 +167,11 @@ def forward_prefix( value_cache: torch.Tensor, block_tables: torch.Tensor, subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, + seq_lens_tensor: torch.Tensor, context_lens: torch.Tensor, - max_subquery_len: int, + max_query_len: int, alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( @@ -184,10 +184,11 @@ def forward_prefix( block_tables, # subquery_start_loc is (batch_size + 1,) subquery_start_loc[:-1], - prompt_lens_tensor, + seq_lens_tensor, context_lens, - max_subquery_len, + max_query_len, alibi_slopes, + sliding_window, ) return output @@ -208,7 +209,7 @@ def swap_blocks( @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dists: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 4896cf3909c6e..79878b26c5294 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -50,6 +50,7 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -62,42 +63,53 @@ def _fwd_kernel( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + # start position inside 
of the query + # generally, N goes over kv, while M goes over query_len block_start_loc = BLOCK_M * start_m # initialize offsets + # [N]; starts at 0 offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd) dim_mask = tl.where( - tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] q = tl.load(Q + off_q, mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len), - other=0.0) + other=0.0) # [M,D] - # # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + # compute query against context (no causal mask here) for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, - other=0) + other=0) # [N] + # [D,N] off_k = (bn[None, :] * stride_k_cache_bs + cur_kv_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None, :]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] off_v = ( bn[:, None] * stride_v_cache_bs + cur_kv_head * stride_v_cache_h + @@ -106,23 +118,39 @@ def _fwd_kernel( k = tl.load(K_cache + off_k, mask=dim_mask[:, None] & ((start_n + offs_n[None, :]) < cur_batch_ctx_len), - other=0.0) + other=0.0) # [D,N] - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] qk += tl.dot(q, k) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
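# Illustrative sketch (not from the patch): the sliding-window condition used
# by the kernel, written in PyTorch instead of Triton. A query token i sits at
# absolute position ctx_len + i; a KV token j sits at position j. A KV entry
# is visible only if it is fewer than SLIDING_WINDOW positions behind the
# query, and the masked value is a large negative number rather than -inf so
# a fully masked row does not turn exp() into NaN.
import torch

ctx_len, query_len, sliding_window = 6, 3, 4
q_pos = ctx_len + torch.arange(query_len)[:, None]    # [M, 1] absolute query positions
kv_pos = torch.arange(ctx_len + query_len)[None, :]   # [1, N] absolute KV positions

visible = (q_pos - kv_pos < sliding_window) & (kv_pos <= q_pos)   # window + causality
qk = torch.randn(query_len, ctx_len + query_len)
qk = torch.where(visible, qk, torch.tensor(-10000.0))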
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + # -- update output accumulator -- # scale p p_scale = beta / l_i_new @@ -134,7 +162,7 @@ def _fwd_kernel( v = tl.load(V_cache + off_v, mask=dim_mask[None, :] & ((start_n + offs_n[:, None]) < cur_batch_ctx_len), - other=0.0) + other=0.0) # [N,D] p = p.to(v.dtype) acc += tl.dot(p, v) @@ -149,8 +177,10 @@ def _fwd_kernel( k_ptrs = K + off_k v_ptrs = V + off_v + # block_mask is 0 when we're already past the current query length block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + # compute query against itself (with causal mask) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- @@ -163,8 +193,13 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale + # apply causal mask qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -636,7 +671,8 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, - alibi_slopes=None): + alibi_slopes=None, + sliding_window=None): cap = torch.cuda.get_device_capability() BLOCK = 128 if cap[0] >= 8 else 64 @@ -644,7 +680,7 @@ def context_attention_fwd(q, Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv # round up Lk to a power of 2 - this is required for Triton block size - Lk_padded = 2**((Lk - 1).bit_length()) + Lk_padded = triton.next_power_of_2(Lk) sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -749,6 +785,7 @@ def context_attention_fwd(q, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window if sliding_window is not None else 0, num_warps=num_warps, num_stages=1, ) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7cc17f21dcd0e..34da0f6c6cdfc 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -1,24 +1,23 @@ import enum -import os from functools import lru_cache from typing import Type import torch +import vllm.envs as envs from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.utils import is_cpu, is_hip logger = init_logger(__name__) -VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" - class _Backend(enum.Enum): FLASH_ATTN = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() + FLASHINFER = enum.auto() @lru_cache(maxsize=None) @@ -43,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]: logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend return TorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + 
logger.warning("Eager mode is enforced for the Flashinfer backend. ") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend else: raise ValueError("Invalid attention backend.") @@ -79,7 +83,7 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend: "package is not found. Please install it for better performance.") return _Backend.XFORMERS - backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) + backend_by_env_var = envs.VLLM_ATTENTION_BACKEND if backend_by_env_var is not None: return _Backend[backend_by_env_var] diff --git a/vllm/config.py b/vllm/config.py index a5512c657e038..5c3a8615eefb4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,19 +1,16 @@ import enum import json -import os from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, ClassVar, List, Optional, Union import torch -from packaging.version import Version from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) from vllm.transformers_utils.config import get_config, get_hf_text_config -from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, - is_neuron) +from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron GPTQMarlinConfig = get_quantization_config("gptq_marlin") @@ -24,10 +21,6 @@ logger = init_logger(__name__) -# If true, will load models from ModelScope instead of Hugging Face Hub. -VLLM_USE_MODELSCOPE = os.environ.get("VLLM_USE_MODELSCOPE", - "False").lower() == "true" - _GB = 1 << 30 @@ -36,6 +29,8 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, and "slow" will always use the slow tokenizer. @@ -68,9 +63,16 @@ class ModelConfig: If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. When a sequence has context length larger than this, we fall back - to eager mode. + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode skip_tokenizer_init: If true, skip initialization of tokenizer and detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. """ def __init__( @@ -89,8 +91,10 @@ def __init__( quantization_param_path: Optional[str] = None, enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 5, skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, ) -> None: self.model = model self.tokenizer = tokenizer @@ -104,6 +108,11 @@ def __init__( self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_context_len_to_capture = max_context_len_to_capture + if self.max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. 
" + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = (max_seq_len_to_capture + or max_context_len_to_capture) self.max_logprobs = max_logprobs self.skip_tokenizer_init = skip_tokenizer_init @@ -113,6 +122,8 @@ def __init__( self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.max_model_len = _get_and_verify_max_len(self.hf_text_config, max_model_len) + self.served_model_name = get_served_model_name(model, + served_model_name) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() self._verify_quantization() @@ -195,10 +206,10 @@ def _verify_quantization(self) -> None: "non-quantized models.", self.quantization) def _verify_cuda_graph(self) -> None: - if self.max_context_len_to_capture is None: - self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) def verify_with_parallel_config( self, @@ -294,6 +305,11 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + return self.hf_text_config.num_attention_heads // \ + parallel_config.tensor_parallel_size + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_text_config.num_hidden_layers return total_num_hidden_layers // parallel_config.pipeline_parallel_size @@ -351,12 +367,6 @@ def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass elif self.cache_dtype == "fp8": - if not is_hip(): - nvcc_cuda_version = get_nvcc_cuda_version() - if nvcc_cuda_version < Version("11.8"): - raise ValueError( - "FP8 is not supported when cuda version is" - "lower than 11.8.") logger.info( "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " @@ -597,8 +607,9 @@ def __init__( self.max_num_batched_tokens = max_num_batched_tokens else: if enable_chunked_prefill: - # For chunked prefill, choose the well-tuned batch size. - self.max_num_batched_tokens = 768 + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + self.max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. @@ -681,6 +692,8 @@ def maybe_create_spec_config( speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], ) -> Optional["SpeculativeConfig"]: """Create a SpeculativeConfig if possible, else return None. @@ -707,6 +720,10 @@ def maybe_create_spec_config( use_v2_block_manager (bool): Whether vLLM is configured to use the v2 block manager or not. Used for raising an error since the v2 block manager is required with spec decode. + ngram_prompt_lookup_max (Optional[int]): Max size of ngram token + window, if provided. + ngram_prompt_lookup_min (Optional[int]): Min size of ngram token + window, if provided. 
Returns: Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if @@ -741,39 +758,57 @@ def maybe_create_spec_config( draft_code_revision = None draft_quantization = None - draft_model_config = ModelConfig( - model=speculative_model, - tokenizer=target_model_config.tokenizer, - tokenizer_mode=target_model_config.tokenizer_mode, - trust_remote_code=target_model_config.trust_remote_code, - dtype=target_model_config.dtype, - seed=target_model_config.seed, - revision=draft_revision, - code_revision=draft_code_revision, - tokenizer_revision=target_model_config.tokenizer_revision, - max_model_len=None, - quantization=draft_quantization, - enforce_eager=target_model_config.enforce_eager, - max_context_len_to_capture=target_model_config. - max_context_len_to_capture, - max_logprobs=target_model_config.max_logprobs, - ) - - draft_model_config.max_model_len = ( - SpeculativeConfig._maybe_override_draft_max_model_len( - speculative_max_model_len, - draft_model_config.max_model_len, - target_model_config.max_model_len, - )) + if speculative_model == "[ngram]": + assert (ngram_prompt_lookup_max is not None + and ngram_prompt_lookup_max > 0) + if ngram_prompt_lookup_min is None: + ngram_prompt_lookup_min = 0 + else: + assert ngram_prompt_lookup_max > ngram_prompt_lookup_min - draft_parallel_config = ( - SpeculativeConfig.create_draft_parallel_config( - target_parallel_config)) + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + draft_model_config = target_model_config + draft_parallel_config = target_parallel_config + else: + ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_min = 0 + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=None, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config)) return SpeculativeConfig( draft_model_config, draft_parallel_config, num_speculative_tokens, + ngram_prompt_lookup_max, + ngram_prompt_lookup_min, ) @staticmethod @@ -841,6 +876,8 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + ngram_prompt_lookup_max: int, + ngram_prompt_lookup_min: int, ): """Create a SpeculativeConfig object. 
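# Illustrative sketch (not from the patch): the rough idea behind the
# "[ngram]" draft model configured above (prompt-lookup speculation). The
# real proposer lives elsewhere in vLLM; this only shows why
# ngram_prompt_lookup_max/min bound the n-gram window that is matched against
# the existing context. Token ids are plain ints.
from typing import List, Optional


def ngram_propose(token_ids: List[int], ngram_max: int, ngram_min: int,
                  num_speculative_tokens: int) -> Optional[List[int]]:
    # Try the longest window first; fall back to shorter ones.
    for n in range(ngram_max, max(ngram_min, 1) - 1, -1):
        if len(token_ids) < n + 1:
            continue
        tail = token_ids[-n:]
        # Search backwards for an earlier occurrence of the tail n-gram.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == tail:
                follow = token_ids[start + n:start + n + num_speculative_tokens]
                if follow:
                    return follow
    return None  # no match: nothing to speculate this step


# e.g. ngram_propose([1, 2, 3, 9, 1, 2, 3], 3, 1, 2) returns [9, 1]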
@@ -853,6 +890,8 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min self._verify_args() @@ -876,7 +915,10 @@ def num_lookahead_slots(self) -> int: return self.num_speculative_tokens def __repr__(self) -> str: - draft_model = self.draft_model_config.model + if self.ngram_prompt_lookup_max > 0: + draft_model = "[ngram]" + else: + draft_model = self.draft_model_config.model num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" @@ -1108,6 +1150,22 @@ def _get_and_verify_max_len( return int(max_model_len) +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + @dataclass class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py index f1b65b2514f76..b0d9511fba521 100644 --- a/vllm/core/block/block_table.py +++ b/vllm/core/block/block_table.py @@ -40,7 +40,9 @@ def __init__( ): self._block_size = block_size self._allocator = block_allocator - self._blocks: Optional[List[Block]] = _blocks + if _blocks is None: + _blocks = [] + self._blocks: List[Block] = _blocks # Use helper method instead of directly calculating, as blocks # may not be allocated. @@ -104,7 +106,7 @@ def append_token_ids(self, token_ids (List[int]): The sequence of token IDs to be appended. """ assert self._is_allocated - assert self._blocks is not None + assert len(self._blocks) > 0 self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + num_lookahead_slots) @@ -141,6 +143,7 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None: blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 self._blocks.append( self._allocator.allocate_mutable(prev_block=self._blocks[-1], device=device)) @@ -159,6 +162,7 @@ def fork(self) -> "BlockTable": the current instance. """ assert self._is_allocated + assert len(self._blocks) > 0 forked_blocks = self._allocator.fork(self._blocks[-1]) return BlockTable( block_size=self._block_size, @@ -177,10 +181,10 @@ def free(self) -> None: assert self._is_allocated for block in self._blocks: self._allocator.free(block) - self._blocks = None + self._blocks = [] @property - def physical_block_ids(self) -> List[int]: + def physical_block_ids(self) -> List[Optional[int]]: """Returns a list of physical block indices for the blocks in the BlockTable. @@ -235,7 +239,7 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], def _get_all_token_ids(self) -> List[int]: # NOTE: This function is O(seq_len); use sparingly. 
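# Usage sketch for the get_served_model_name helper added above (model and
# served names are made up for illustration):
assert get_served_model_name("facebook/opt-125m", None) == "facebook/opt-125m"
assert get_served_model_name("facebook/opt-125m", "opt-small") == "opt-small"
assert get_served_model_name("facebook/opt-125m", ["opt-small", "opt"]) == "opt-small"
assert get_served_model_name("facebook/opt-125m", []) == "facebook/opt-125m"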
- token_ids = [] + token_ids: List[int] = [] if not self._is_allocated: return token_ids @@ -247,7 +251,7 @@ def _get_all_token_ids(self) -> List[int]: @property def _is_allocated(self) -> bool: - return self._blocks is not None + return len(self._blocks) > 0 @property def _num_empty_slots(self) -> int: diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py index f11234a0bf2dd..3f97a1210b096 100644 --- a/vllm/core/block/common.py +++ b/vllm/core/block/common.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Protocol from vllm.core.block.interfaces import Block, BlockAllocator @@ -7,7 +7,19 @@ RefCount = int -class RefCounter: +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): """A class for managing reference counts for a set of block indices. The RefCounter class maintains a dictionary that maps block indices to their @@ -54,7 +66,7 @@ def as_readonly(self) -> "ReadOnlyRefCounter": return ReadOnlyRefCounter(self) -class ReadOnlyRefCounter: +class ReadOnlyRefCounter(RefCounterProtocol): """A read-only view of the RefCounter class. The ReadOnlyRefCounter class provides a read-only interface to access the @@ -96,7 +108,7 @@ class CopyOnWriteTracker: def __init__( self, - refcounter: RefCounter, + refcounter: RefCounterProtocol, allocator: BlockAllocator, ): self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 3135e194c5937..5b25e1bcdada0 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional +from typing import Dict, FrozenSet, List, Optional -from vllm.core.block.interfaces import (Block, BlockAllocator, +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator @@ -57,15 +57,15 @@ def create( cpu_block_ids = block_ids[num_gpu_blocks:] if allocator_type == "naive": - gpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_gpu_blocks, block_size=block_size, block_ids=gpu_block_ids, ) - cpu_allocator = NaiveBlockAllocator( - create_block=NaiveBlock, + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore num_blocks=num_cpu_blocks, block_size=block_size, block_ids=cpu_block_ids, @@ -105,7 +105,7 @@ def __init__( Device.GPU: gpu_block_allocator, } - self._block_ids_to_allocator = {} + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} for _, allocator in self._allocators.items(): for block_id in allocator.all_block_ids: self._block_ids_to_allocator[block_id] = allocator @@ -149,7 +149,9 @@ def free(self, block: Block) -> None: Args: block (Block): The block to be freed. 
""" - allocator = self._block_ids_to_allocator[block.block_id] + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.free(block) def fork(self, last_block: Block) -> List[Block]: @@ -163,7 +165,9 @@ def fork(self, last_block: Block) -> List[Block]: List[Block]: A new list of blocks that shares the same memory as the original sequence. """ - allocator = self._block_ids_to_allocator[last_block.block_id] + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] return allocator.fork(last_block) def get_num_free_blocks(self, device: Device) -> int: @@ -171,13 +175,16 @@ def get_num_free_blocks(self, device: Device) -> int: Args: device (Device): The device for which to query the number of free - blocks. + blocks. AssertionError is raised if None is passed. Returns: int: The number of free blocks available on the specified device. """ return self._allocators[device].get_num_free_blocks() + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + def clear_copy_on_writes(self) -> Dict[int, List[int]]: """Clears the copy-on-write (CoW) state and returns the mapping of source to destination block IDs. @@ -190,10 +197,18 @@ def clear_copy_on_writes(self) -> Dict[int, List[int]]: device = Device.GPU return self._allocators[device].clear_copy_on_writes() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" # Prefix caching only supported on GPU. 
device = Device.GPU - return self._allocators[device].mark_blocks_as_computed() + return self._allocators[device].mark_blocks_as_computed(block_ids) def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: @@ -202,5 +217,12 @@ def get_common_computed_block_ids( return self._allocators[device].get_common_computed_block_ids( seq_block_ids) - def all_block_ids(self) -> frozenset[int]: + @property + def all_block_ids(self) -> FrozenSet[int]: return frozenset(self._block_ids_to_allocator.keys()) + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + + def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: + raise NotImplementedError diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 50ce922118124..634c4016ca19c 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -3,6 +3,8 @@ from vllm.utils import Device +BlockId = int + class Block(ABC): @@ -15,6 +17,12 @@ def append_token_ids(self, token_ids: List[int]) -> None: def block_id(self) -> Optional[int]: pass + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + @property @abstractmethod def token_ids(self) -> List[int]: @@ -35,6 +43,27 @@ def is_full(self) -> bool: def prev_block(self) -> Optional["Block"]: pass + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + class Factory(Protocol): @abstractmethod @@ -48,6 +77,17 @@ def __call__( ) -> "Block": pass + @property + @abstractmethod + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. 
+ """ + return None + class BlockAllocator(ABC): @@ -57,7 +97,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: @abstractmethod def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int], device: Device) -> Block: + token_ids: List[int]) -> Block: pass @abstractmethod @@ -69,7 +109,11 @@ def fork(self, last_block: Block) -> List[Block]: pass @abstractmethod - def get_num_free_blocks(self, device: Device) -> int: + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: pass @property @@ -82,7 +126,12 @@ def clear_copy_on_writes(self) -> Dict[int, List[int]]: pass @abstractmethod - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass @abstractmethod @@ -90,14 +139,25 @@ def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: pass + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + class NoFreeBlocksError(ValueError): pass -class DeviceAwareBlockAllocator(BlockAllocator): +class DeviceAwareBlockAllocator(ABC): @abstractmethod - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, prev_block: Optional[Block], + device: Device) -> Block: pass @abstractmethod @@ -108,3 +168,38 @@ def allocate_immutable(self, prev_block: Optional[Block], @abstractmethod def get_num_free_blocks(self, device: Device) -> int: pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> Dict[int, List[int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seq_block_ids: List[List[int]]) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index f8e9265bb2d67..a1b901bf78efc 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -1,10 +1,9 @@ -from typing import Dict, Iterable, List, Optional, Set +from typing import Dict, FrozenSet, Iterable, List, Optional, Set from vllm.core.block.common import (CopyOnWriteTracker, RefCounter, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -BlockId = int Refcount = int @@ -49,8 +48,10 @@ def __init__( allocator=self, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates a new immutable block with the given token IDs, linked to the previous block. 
@@ -63,11 +64,14 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The newly allocated immutable block. """ + assert device is None block = self.allocate_mutable(prev_block=prev_block) block.append_token_ids(token_ids) return block - def allocate_mutable(self, prev_block: Optional[Block]) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a new mutable block, linked to the previous block. Args: @@ -78,6 +82,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: Returns: Block: The newly allocated mutable block. """ + assert device is None block_id = self._allocate_new_block_id() return self._create_block( prev_block=prev_block, @@ -88,6 +93,7 @@ def allocate_mutable(self, prev_block: Optional[Block]) -> Block: ) def free(self, block: Block) -> None: + assert block.block_id is not None self._free_block_id(block.block_id) # Mark the block as having no allocation. @@ -111,6 +117,7 @@ def fork(self, last_block: Block) -> List[Block]: for block in source_blocks: # Increment refcount for each block. + assert block.block_id is not None refcount = self._refcounter.incr(block.block_id) assert refcount != 1, "can't fork free'd block" @@ -129,6 +136,9 @@ def fork(self, last_block: Block) -> List[Block]: def get_num_free_blocks(self) -> int: return len(self._free_block_indices) + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + def _allocate_new_block_id(self) -> BlockId: if not self._free_block_indices: raise BlockAllocator.NoFreeBlocksError() @@ -148,7 +158,7 @@ def refcounter(self): return self._refcounter @property - def all_block_ids(self): + def all_block_ids(self) -> FrozenSet[int]: return self._all_block_indices def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]: @@ -174,7 +184,16 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching. 
Since the naive allocator does not implement prefix caching, we do @@ -191,6 +210,9 @@ def get_common_computed_block_ids( """ return [] + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix @@ -215,13 +237,13 @@ class NaiveBlock(Block): """ def __init__(self, - prev_block: Block, + prev_block: Optional[Block], token_ids: List[int], block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, _cow_target: Optional[Block] = None): - self._token_ids = [] + self._token_ids: List[int] = [] self._block_size = block_size self._prev_block = prev_block self._block_id = block_id @@ -247,6 +269,22 @@ def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: assert self.num_empty_slots >= len(token_ids) self._token_ids.extend(token_ids) + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + @property def block_id(self) -> Optional[int]: return self._block_id @@ -267,9 +305,14 @@ def num_empty_slots(self) -> int: def token_ids(self) -> List[int]: return self._token_ids + @property def block_size(self) -> int: return self._block_size @property def prev_block(self) -> Optional["Block"]: return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 6aa75a8abb80a..4a37e8f87c379 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,15 +1,20 @@ """Token blocks.""" from itertools import takewhile from os.path import commonprefix -from typing import Dict, Iterable, List, Optional +from typing import Dict, FrozenSet, Iterable, List, Optional from vllm.core.block.common import (CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor PrefixHash = int -BlockId = int + +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. +_DEFAULT_LAST_ACCESSED_TIME = -1 class PrefixCachingBlockAllocator(BlockAllocator): @@ -27,26 +32,23 @@ class PrefixCachingBlockAllocator(BlockAllocator): from 0 to num_blocks - 1. """ - # TODO last access time / evictor integration - def __init__( self, num_blocks: int, block_size: int, block_ids: Optional[Iterable[int]] = None, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, ): # A mapping of prefix hash to block index. All blocks which have a # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} - # A mapping of prefix hash to block index. All blocks which have a - # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset - # of self._cached_blocks. 
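# Illustrative sketch (not from the patch): the content-hash chaining that
# PrefixCachingBlockAllocator keys its _cached_blocks dict on. A full block's
# hash depends on its own token ids and on the previous block's hash, so two
# sequences that share a prefix produce the same chain of hashes and can
# share physical blocks. hash((...)) stands in for whatever hash the real
# Block implementation uses.
from typing import List, Optional


def chained_block_hashes(token_ids: List[int], block_size: int) -> List[int]:
    hashes: List[int] = []
    prev_hash: Optional[int] = None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        prev_hash = hash((prev_hash is None, prev_hash, block))
        hashes.append(prev_hash)
    return hashes  # only full blocks get a content hash


shared = list(range(32))
a = chained_block_hashes(shared + list(range(100, 116)), block_size=16)
b = chained_block_hashes(shared + list(range(200, 216)), block_size=16)
assert a[:2] == b[:2] and a[2] != b[2]   # the first two blocks are reusable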
- self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {} + # A mapping of blockId to Block to track those cached blocks + self._blocks: Dict[BlockId, Block] = {} # An allocator for blocks that do not have prefix hashes. self._hashless_allocator = NaiveBlockAllocator( - create_block=self._create_block, + create_block=self._create_block, # type: ignore num_blocks=num_blocks, block_size=block_size, block_ids=block_ids, @@ -54,6 +56,10 @@ def __init__( self._block_size = block_size + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.evictor: Evictor = make_evictor(eviction_policy) + # We share the refcounter between allocators. This allows us to promote # blocks originally allocated in the hashless allocator to immutable # blocks. @@ -72,6 +78,7 @@ def _create_block( block_size: int, allocator: BlockAllocator, block_id: Optional[int] = None, + computed: bool = False, ) -> Block: # Bind block to self. allocator = self @@ -82,10 +89,13 @@ def _create_block( block_size=block_size, block_id=block_id, prefix_caching_allocator=allocator, + computed=computed, ) - def allocate_immutable(self, prev_block: Optional[Block], - token_ids: List[int]) -> Block: + def allocate_immutable(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: """Allocates an immutable block with the given token IDs, reusing cached blocks if possible. @@ -96,6 +106,7 @@ def allocate_immutable(self, prev_block: Optional[Block], Returns: Block: The allocated immutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) block = self._create_block( @@ -109,65 +120,95 @@ def allocate_immutable(self, prev_block: Optional[Block], cached_block_id = self._cached_blocks.get(block.content_hash, None) if cached_block_id is not None: block.block_id = cached_block_id - self._incr_refcount_cached_block(block.content_hash, - block.block_id) + self._incr_refcount_cached_block(block, block.block_id) return block block = self.allocate_mutable(prev_block) block.append_token_ids(token_ids) assert block.content_hash is not None - # TODO computed bit return block - def allocate_mutable(self, prev_block: Block) -> Block: + def allocate_mutable(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: """Allocates a mutable block. If there are no free blocks, this will evict unused cached blocks. Args: prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. Returns: Block: The allocated mutable block. """ + assert device is None assert_prefix_caching_block_or_none(prev_block) try: - return self._hashless_allocator.allocate_mutable( + block = self._hashless_allocator.allocate_mutable( prev_block=prev_block) + + assert block.block_id not in self._blocks + assert block.block_id is not None + self._blocks[block.block_id] = block + return block except BlockAllocator.NoFreeBlocksError: # We must check the unused cached blocks before raising OOM. pass - if self._unused_cached_blocks: - # TODO policy for selecting block to remove - content_hash_to_evict = next(iter(self._unused_cached_blocks)) + # If the evictor has blocks available for eviction, evict a block + # and return it. 
+ if self.evictor.num_blocks > 0: + block_id, content_hash_to_evict = self.evictor.evict() + + # Here we may have scenario that several blocks have + # the same content hash, but due to the latter coming block + # is coming from mutable to immutable path, their physical + # block is added into evictor. + # However in this case, we shall not pop the _cached_blocks, + # as the same content is still used by others, which means + # we need to check ref before decide to pop the list. - # Clear content hash mapping; the block will be overwritten. - del self._cached_blocks[content_hash_to_evict] + _block_id = self._cached_blocks[content_hash_to_evict] + refcount = self._refcounter.get(_block_id) + if refcount == 1: + self._cached_blocks.pop(content_hash_to_evict) + assert _block_id == block_id - block_id = self._unused_cached_blocks.pop(content_hash_to_evict) - refcount = self._refcounter.incr(block_id) - assert refcount == 1 + self._refcounter.incr(block_id) + + # the block comes from evictor already contain computed result block = self._create_block( prev_block=prev_block, token_ids=[], block_size=self._block_size, allocator=self, block_id=block_id, + computed=True, ) assert block.content_hash is None + + assert block.block_id not in self._blocks + assert block.block_id is not None + self._blocks[block.block_id] = block return block # No block available in hashless allocator, nor in unused cache blocks. raise BlockAllocator.NoFreeBlocksError() - def _incr_refcount_cached_block(self, content_hash: int, + def _incr_refcount_cached_block(self, block: Block, block_id: BlockId) -> None: + # since block is already computed, mark it + block.computed = True + refcount = self._refcounter.incr(block_id) if refcount == 1: - assert content_hash in self._unused_cached_blocks - del self._unused_cached_blocks[content_hash] + # if block get referred, then it shall not be in evictor + # and put it into _blocks for tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + self._blocks[block_id] = block def free(self, block: Block) -> None: """Decrement the refcount of the block. If the decremented refcount is @@ -180,6 +221,7 @@ def free(self, block: Block) -> None: is not None), "freeing unallocated block is undefined" self._free_block_id_for_block(block.block_id, block) + block.block_id = None def _free_block_id_for_block(self, block_id: BlockId, @@ -187,15 +229,23 @@ def _free_block_id_for_block(self, block_id: BlockId, assert isinstance(block, PrefixCachingBlock) if block.content_hash is None: + refcount = self._refcounter.get(block_id) + # We have fork case where block would get more than one ref, + # so we cannot free it from tracking if ref cnt large than 1 + if refcount <= 1: + assert block.block_id is not None + del self._blocks[block.block_id] return self._hashless_allocator.free(block) refcount = self._refcounter.decr(block_id) - # If no longer used, add the block to the unused cached blocks. + # If no longer used, add the block to the evictor. 
if refcount == 0: - assert block.content_hash not in self._unused_cached_blocks assert block.content_hash in self._cached_blocks - self._unused_cached_blocks[block.content_hash] = block_id + assert block.block_id is not None + del self._blocks[block.block_id] + self.evictor.add(block.block_id, block.content_hash, + block.num_tokens_total, block.last_accessed) def fork(self, last_block: Block) -> List[Block]: """Creates a new sequence of blocks that shares the same underlying @@ -228,18 +278,21 @@ def fork(self, last_block: Block) -> List[Block]: return forked_blocks - def get_num_free_blocks(self) -> int: + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None # The number of free blocks is the number of hashless free blocks - # plus the number of hashful blocks that are unused. - return self._hashless_allocator.get_num_free_blocks() + len( - self._unused_cached_blocks) + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks + + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() @property - def all_block_ids(self) -> frozenset[int]: + def all_block_ids(self) -> FrozenSet[int]: return self._hashless_allocator.all_block_ids - def promote_to_immutable_block(self, - block: "PrefixCachingBlock") -> BlockId: + def promote_to_immutable_block(self, block: Block) -> BlockId: """Once a mutable block is full, it can be promoted to an immutable block. This means that its content can be referenced by future blocks having the same prefix. @@ -249,7 +302,7 @@ def promote_to_immutable_block(self, block. Args: - block (PrefixCachingBlock): The mutable block to be promoted. + block: The mutable block to be promoted. Returns: BlockId: Either the original block index, or the block index of @@ -266,7 +319,7 @@ def promote_to_immutable_block(self, else: self._free_block_id_for_block(block.block_id, block) self._incr_refcount_cached_block( - block.content_hash, self._cached_blocks[block.content_hash]) + block, self._cached_blocks[block.content_hash]) return self._cached_blocks[block.content_hash] @@ -293,29 +346,63 @@ def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]: """ return self._cow_tracker.clear_cows() - def mark_blocks_as_computed(self) -> None: + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if block_id in self._blocks: + self._blocks[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """Mark blocks as computed, used in prefix caching.""" - # TODO Track computed blocks. 
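# Illustrative sketch (not from the patch): the lifecycle the code above
# implements. Freed, content-hashed blocks are not returned to a plain free
# list; they are parked in the evictor with their hash so a later allocation
# of the same content can revive them, and they are only truly recycled when
# the evictor hands their block id out for reuse. Heavily simplified: no
# refcounts and no mutable/immutable promotion.
from collections import OrderedDict
from typing import Dict, Tuple


class TinyPrefixCache:

    def __init__(self, num_blocks: int):
        self.free_ids = list(range(num_blocks))
        self.cached: Dict[int, int] = {}                          # content_hash -> block_id
        self.evictable: "OrderedDict[int, int]" = OrderedDict()   # block_id -> content_hash

    def allocate(self, content_hash: int) -> Tuple[int, bool]:
        if content_hash in self.cached:                   # cache hit, maybe revive from evictor
            block_id = self.cached[content_hash]
            self.evictable.pop(block_id, None)
            return block_id, True
        if self.free_ids:
            block_id = self.free_ids.pop()
        else:                                             # recycle the oldest parked block
            block_id, old_hash = self.evictable.popitem(last=False)
            self.cached.pop(old_hash, None)
        self.cached[content_hash] = block_id
        return block_id, False

    def free(self, block_id: int, content_hash: int) -> None:
        self.evictable[block_id] = content_hash           # parked, still reusable


cache = TinyPrefixCache(num_blocks=1)
bid, hit = cache.allocate(content_hash=111)   # miss
cache.free(bid, 111)
assert cache.allocate(111) == (bid, True)     # same content revived from the cache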
- pass + + for block_id in block_ids: + if block_id in self._blocks: + # only those full block is valid for prefix caching + if self._blocks[block_id].is_full: + self._blocks[block_id].computed = True + elif block_id not in self.evictor: + raise ValueError(f"Mark {block_id=} as computed which " + "is not belonged to GPU") + + def block_is_computed(self, block_id: int) -> bool: + if block_id in self._blocks: + return self._blocks[block_id].computed + else: + return block_id in self.evictor def get_common_computed_block_ids( self, seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. - Used in prefill (can skip prefill of some blocks). + Only those blocks that are immutable and already be marked + compyted would be taken consideration. """ - # TODO: Track computed blocks. - computed = lambda block_id: False - # NOTE We exclude the last block to avoid the case where the entire # prompt is cached. This would cause erroneous behavior in model # runner. + ids_list = [ - takewhile(lambda block_id: computed(block_id), seq[:-1]) - for seq in seq_block_ids + list( + takewhile(lambda block_id: self.block_is_computed(block_id), + seq[:-1])) for seq in seq_block_ids ] - return commonprefix([ids for ids in ids_list if ids != []]) + # It returns a list of int although type annotation says list of string. + return commonprefix([ + ids for ids in ids_list # type: ignore + if ids != [] + ]) class PrefixCachingBlock(Block): @@ -332,7 +419,7 @@ class PrefixCachingBlock(Block): token_ids (List[int]): The initial token IDs to be stored in the block. block_size (int): The maximum number of token IDs that can be stored in the block. - prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix + prefix_caching_allocator (BlockAllocator): The prefix caching block allocator associated with this block. block_id (Optional[int], optional): The physical block index of this block. Defaults to None. @@ -340,17 +427,25 @@ class PrefixCachingBlock(Block): def __init__( self, - prev_block: Optional["PrefixCachingBlock"], + prev_block: Optional[Block], token_ids: List[int], block_size: int, - prefix_caching_allocator: PrefixCachingBlockAllocator, + prefix_caching_allocator: BlockAllocator, block_id: Optional[int] = None, + computed: bool = False, ): + assert isinstance(prefix_caching_allocator, + PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator.") assert_prefix_caching_block_or_none(prev_block) self._prev_block = prev_block self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: Optional[int] = None self._prefix_caching_allocator = prefix_caching_allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed self._block = NaiveBlock( prev_block=prev_block, @@ -361,6 +456,22 @@ def __init__( _cow_target=self, ) + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + def append_token_ids(self, token_ids: List[int]) -> None: """Appends the given token IDs to the block and registers the block as immutable if the block becomes full. 
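# Usage sketch (not from the patch) for the takewhile/commonprefix pattern in
# get_common_computed_block_ids above: per sequence, keep the leading run of
# computed blocks (excluding the last block), then take the prefix common to
# all sequences. `computed` below is a stand-in for block_is_computed().
from itertools import takewhile
from os.path import commonprefix

computed = {0, 1, 2, 5}.__contains__
seq_block_ids = [[0, 1, 2, 3], [0, 1, 5, 6]]

ids_list = [list(takewhile(computed, seq[:-1])) for seq in seq_block_ids]
# ids_list == [[0, 1, 2], [0, 1, 5]]
common = commonprefix([ids for ids in ids_list if ids != []])
# common == [0, 1]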
@@ -398,6 +509,27 @@ def is_full(self) -> bool: def num_empty_slots(self) -> int: return self._block.num_empty_slots + @property + def num_tokens_total(self) -> int: + """return the total tokens so far. + + Here we iterate the block chain till to the first block, while + cache the result in local to prevent repeated computations. + """ + if self._cached_num_tokens_total is not None: + return self._cached_num_tokens_total + + _block: Optional[Block] = self + self._cached_num_tokens_total = 0 + + # TODO: current implement here take O(N^2), we expect future + # we have O(1) here + while _block is not None: + self._cached_num_tokens_total += len(_block.token_ids) + _block = _block.prev_block + + return self._cached_num_tokens_total + @property def block_size(self) -> int: return self._block.block_size @@ -428,8 +560,10 @@ def content_hash(self) -> Optional[int]: return None is_first_block = self._prev_block is None - prev_block_hash = (None if is_first_block else - self._prev_block.content_hash) + prev_block_hash = ( + None if is_first_block else + self._prev_block.content_hash # type: ignore + ) # Previous block exists but does not yet have a hash. # Return no hash in this case. diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 1fac2636e86fa..268c5c135d887 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -8,7 +8,7 @@ from typing import Set from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.sequence import Sequence, SequenceGroup, SequenceStatus @@ -47,6 +47,10 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: pass + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + @abstractmethod def contains_block(self, block_hash: int) -> bool: pass @@ -131,6 +135,9 @@ def get_num_free_blocks(self) -> int: return (self.num_blocks - self.current_num_blocks + self.evictor.num_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: return block_hash in self.cached_blocks or block_hash in self.evictor @@ -190,6 +197,9 @@ def free(self, block: PhysicalTokenBlock) -> None: def get_num_free_blocks(self) -> int: return len(self.free_blocks) + def get_num_total_blocks(self) -> int: + return self.num_blocks + def contains_block(self, block_hash: int) -> bool: raise NotImplementedError( "Invalid codepath for uncached block allocator.") @@ -391,7 +401,7 @@ def append_slots( block_table.append(block_table[len(block_table) % self.block_sliding_window]) else: - # The sequence has a new logical block. + # The sequence hash a new logical block. # Allocate a new physical block. new_block = self._allocate_last_physical_block(seq) block_table.append(new_block) @@ -444,7 +454,7 @@ def _get_physical_blocks( def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: + num_lookahead_slots: int = 0) -> AllocStatus: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" blocks = self._get_physical_blocks(seq_group) @@ -454,7 +464,12 @@ def can_swap_in(self, # at least one free block right after the swap-in. # NOTE: This should match the logic in can_append_slot(). 
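# Illustrative sketch (not from the patch): the prev_block chain walk that
# num_tokens_total performs above, with the result memoized per block. Plain
# classes instead of vLLM's Block, just to show why the uncached walk is
# linear in chain length per block (hence the O(N^2) TODO across a sequence).
from typing import List, Optional


class ChainedBlock:

    def __init__(self, prev_block: Optional["ChainedBlock"], token_ids: List[int]):
        self.prev_block = prev_block
        self.token_ids = token_ids
        self._cached_num_tokens_total: Optional[int] = None

    @property
    def num_tokens_total(self) -> int:
        if self._cached_num_tokens_total is None:
            total = 0
            block: Optional["ChainedBlock"] = self
            while block is not None:           # walk back to the first block
                total += len(block.token_ids)
                block = block.prev_block
            self._cached_num_tokens_total = total
        return self._cached_num_tokens_total


first = ChainedBlock(None, [1, 2, 3, 4])
second = ChainedBlock(first, [5, 6])
assert second.num_tokens_total == 6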
num_required_blocks = len(blocks) + num_swapped_seqs - return num_free_blocks - num_required_blocks >= self.watermark_blocks + if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: + return AllocStatus.NEVER + elif num_free_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 6339a6baf4161..ce90ce2f17278 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -72,14 +72,12 @@ def __init__( self.watermark = watermark assert watermark >= 0.0 - assert not enable_caching, "Prefix caching not yet supported" self.enable_caching = enable_caching self.watermark_blocks = int(watermark * num_gpu_blocks) self.block_allocator = CpuGpuBlockAllocator.create( - # Currently, only naive blocks are supported (no prefix caching). - allocator_type="naive", + allocator_type="prefix_caching" if enable_caching else "naive", num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=block_size, @@ -192,19 +190,30 @@ def get_block_table(self, seq: Sequence) -> List[int]: assert seq.seq_id in self.block_tables block_ids = self.block_tables[seq.seq_id].physical_block_ids assert all(b is not None for b in block_ids) - return block_ids - - def access_all_blocks_in_seq(self, seq, now): - # TODO add prefix caching support. - # Tracked here https://github.com/vllm-project/vllm/issues/3667 - pass + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + # Update the last accessed time of all the blocks accessed + # in this step. + # And the accessed time is only useful for prefix caching now, + # as it support internal evictor policy for which cached + # block could be refilled, to keep cached content could be reused + # at max extend. + if self.enable_caching: + block_table = self.block_tables[seq.seq_id] + block_ids = [] + for block_id in block_table.physical_block_ids: + block_ids.append(block_id) + self.block_allocator.mark_blocks_as_accessed( + block_ids, # type: ignore + now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): - # We ignore the sequence group as its not necessary. After the batch is - # formed by the scheduler, we do not need to mark blocks from individual - # sequence groups as computed -- all blocks in the batch can be marked - # as computed. - self.block_allocator.mark_blocks_as_computed() + # The only need for mark block as computed is for prefix caching, + # while currently we could determine whether one block is computed + # or not by check whether it has content hash. + # So this function is useless for block_v2. + pass def get_common_computed_block_ids( self, seqs: List[Sequence]) -> GenericSequence[int]: @@ -220,16 +229,17 @@ def get_common_computed_block_ids( seq_block_ids = [ self.block_tables[seq.seq_id].physical_block_ids for seq in seqs ] + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. 
return self.block_allocator.get_common_computed_block_ids( - seq_block_ids) + seq_block_ids) # type: ignore def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: src_block_table = self.block_tables[parent_seq.seq_id] self.block_tables[child_seq.seq_id] = src_block_table.fork() def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: - return False + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.LATER def swap_in(self, seq_group: SequenceGroup, num_lookahead_slots: int) -> Dict[int, int]: diff --git a/vllm/core/evictor.py b/vllm/core/evictor_v1.py similarity index 100% rename from vllm/core/evictor.py rename to vllm/core/evictor_v1.py diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py new file mode 100644 index 0000000000000..57759b29347f4 --- /dev/null +++ b/vllm/core/evictor_v2.py @@ -0,0 +1,127 @@ +import enum +from abc import ABC, abstractmethod, abstractproperty +from typing import OrderedDict, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: float): + """Update corresponding block's access time in metadata""" + pass + + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + + @abstractproperty + def num_blocks(self) -> int: + pass + + +class BlockMetaData(): + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. 
If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + def __init__(self): + self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict() + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + evicted_block = next(iter(self.free_table.values())) + evicted_block_id = next(iter(self.free_table.keys())) + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. + for _id, block in self.free_table.items(): + if evicted_block.last_accessed > block.last_accessed or ( + evicted_block.last_accessed == block.last_accessed and + evicted_block.num_hashed_tokens < block.num_hashed_tokens): + evicted_block = block + evicted_block_id = _id + + self.free_table.pop(evicted_block_id) + + return evicted_block_id, evicted_block.content_hash + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + + def update(self, block_id: int, last_accessed: float): + self.free_table[block_id].last_accessed = last_accessed + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 56c2c5995c38b..09ccaddb62615 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -63,7 +63,7 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: @abstractmethod def can_swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int) -> bool: + num_lookahead_slots: int) -> AllocStatus: pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 024b7e7013441..de3ecd24e52db 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,4 +1,6 @@ import enum +import os +import random import time from collections import deque from dataclasses import dataclass, field @@ -11,10 +13,16 @@ from vllm.lora.request import LoRARequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.utils import merge_dicts logger = init_logger(__name__) +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + class PreemptionMode(enum.Enum): """Preemption modes. @@ -113,12 +121,14 @@ class SchedulerOutputs: blocks_to_swap_in: Dict[int, int] # Blocks to swap out. Dict of GPU -> CPU block number. blocks_to_swap_out: Dict[int, int] - # Blocks to copy. Source to a list of dest blocks. - blocks_to_copy: Dict[int, List[int]] + # Blocks to copy. Source to dest block. 
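The LRUEvictor above scans the free table for the entry with the lowest last_accessed time, breaking ties in favor of the block with the larger num_hashed_tokens. A compact sketch of just that selection step, using a simplified metadata record rather than vLLM's BlockMetaData:

from collections import OrderedDict
from dataclasses import dataclass
from typing import Tuple


@dataclass
class BlockMeta:
    content_hash: int
    num_hashed_tokens: int
    last_accessed: float


def pick_victim(free_table: "OrderedDict[int, BlockMeta]") -> Tuple[int, int]:
    """Return (block_id, content_hash) of the block to evict next."""
    if not free_table:
        raise ValueError("No usable cache memory left")
    victim_id, victim = next(iter(free_table.items()))
    for block_id, meta in free_table.items():
        older = meta.last_accessed < victim.last_accessed
        tie_but_longer = (meta.last_accessed == victim.last_accessed
                          and meta.num_hashed_tokens > victim.num_hashed_tokens)
        if older or tie_but_longer:
            victim_id, victim = block_id, meta
    return victim_id, victim.content_hash


table: "OrderedDict[int, BlockMeta]" = OrderedDict()
table[7] = BlockMeta(content_hash=111, num_hashed_tokens=16, last_accessed=1.0)
table[9] = BlockMeta(content_hash=222, num_hashed_tokens=32, last_accessed=1.0)
table[3] = BlockMeta(content_hash=333, num_hashed_tokens=48, last_accessed=2.0)
assert pick_victim(table) == (9, 222)  # oldest tier, most hashed tokens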
+ blocks_to_copy: List[Tuple[int, int]] # Sequence groups that are going to be ignored. ignored_seq_groups: List[SequenceGroup] # The number of slots for lookahead decoding. num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int def __post_init__(self): # Swap in and swap out should never happen at the same time. @@ -166,7 +176,7 @@ class SchedulerRunningOutputs: # The blocks to swap out. blocks_to_swap_out: Dict[int, int] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int @@ -178,7 +188,7 @@ def create_empty(cls) -> "SchedulerRunningOutputs": preempted=[], swapped_out=[], blocks_to_swap_out={}, - blocks_to_copy={}, + blocks_to_copy=[], num_lookahead_slots=0, ) @@ -198,9 +208,11 @@ class SchedulerSwappedInOutputs: # The blocks to swap in. blocks_to_swap_in: Dict[int, int] # The blocks to copy. - blocks_to_copy: Dict[int, List[int]] + blocks_to_copy: List[Tuple[int, int]] # The number of slots for lookahead decoding. num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] @classmethod def create_empty(cls) -> "SchedulerSwappedInOutputs": @@ -208,8 +220,9 @@ def create_empty(cls) -> "SchedulerSwappedInOutputs": decode_seq_groups=[], prefill_seq_groups=[], blocks_to_swap_in={}, - blocks_to_copy={}, + blocks_to_copy=[], num_lookahead_slots=0, + infeasible_seq_groups=[], ) @@ -286,6 +299,13 @@ def __init__( # Latency of the last prompt step self.last_prompt_latency = 0.0 + # The following field is test-only. It is used to inject artificial + # preemption. + self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + @property def lora_enabled(self) -> bool: return bool(self.lora_config) @@ -373,7 +393,7 @@ def _schedule_running( """ # Blocks that need to be swapped or copied before model execution. blocks_to_swap_out: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] @@ -386,15 +406,13 @@ def _schedule_running( # groups to preempt. now = time.time() running_queue = policy.sort_by_priority(now, running_queue) - while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - # We can have up to 1 running prefill at any given time in running - # queue, which means we can guarantee chunk size is at least 1. - assert num_running_tokens != 0 + if num_running_tokens == 0: + break running_queue.popleft() while not self._can_append_slots(seq_group): @@ -449,9 +467,6 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - # Make sure all queues are updated. - assert len(running_queue) == 0 - return running_queue, SchedulerRunningOutputs( decode_seq_groups=decode_seq_groups, prefill_seq_groups=prefill_seq_groups, @@ -495,19 +510,31 @@ def _schedule_swapped( """ # Blocks that need to be swapped or copied before model execution. 
blocks_to_swap_in: Dict[int, int] = {} - blocks_to_copy: Dict[int, List[int]] = {} + blocks_to_copy: List[Tuple[int, int]] = [] decode_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = [] now = time.time() swapped_queue = policy.sort_by_priority(now, swapped_queue) + infeasible_seq_groups: List[SequenceGroup] = [] leftover_swapped: Deque[SequenceGroup] = deque() while swapped_queue: seq_group = swapped_queue[0] # If the sequence group cannot be swapped in, stop. - if not self.block_manager.can_swap_in(seq_group): + alloc_status = self.block_manager.can_swap_in(seq_group) + if alloc_status == AllocStatus.LATER: break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue lora_int_id = 0 if self.lora_enabled: @@ -545,7 +572,6 @@ def _schedule_swapped( ScheduledSequenceGroup(seq_group, token_chunk_size=num_new_tokens)) else: - assert num_new_tokens == 1 decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) @@ -559,7 +585,9 @@ def _schedule_swapped( blocks_to_swap_in=blocks_to_swap_in, blocks_to_copy=blocks_to_copy, num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=False)) + is_prefill=False), + infeasible_seq_groups=infeasible_seq_groups, + ) def _schedule_prefills( self, @@ -765,10 +793,12 @@ def _schedule_default(self) -> SchedulerOutputs: num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), - ignored_seq_groups=prefills.ignored_seq_groups, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), ) def _schedule_chunked_prefill(self): @@ -851,10 +881,11 @@ def _schedule_chunked_prefill(self): num_batched_tokens=budget.num_batched_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, - blocks_to_copy=merge_dicts(running_scheduled.blocks_to_copy, - swapped_in.blocks_to_copy), + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups, num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), ) def _schedule(self) -> SchedulerOutputs: @@ -868,6 +899,13 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: """Determine whether or not we have enough space in the KV cache to continue generation of the sequence group. """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + # Appending slots only occurs in decoding. 
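The scheduler also gains a test-only preemption hook, used above in _can_append_slots: when VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT is set, the slot check randomly reports that no space is available (with probability ARTIFICIAL_PREEMPTION_PROB, up to ARTIFICIAL_PREEMPTION_MAX_CNT times), forcing the preemption path to be exercised. A toy scheduler showing only that branch; the class and its surrounding logic are illustrative:

import os
import random

# Mirrors the constants added to vllm/core/scheduler.py above.
# Any non-empty value of the env var enables the hook.
ENABLE_ARTIFICIAL_PREEMPT = bool(
    os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False))
ARTIFICIAL_PREEMPTION_PROB = 0.5
ARTIFICIAL_PREEMPTION_MAX_CNT = 500


class TinyScheduler:
    """Toy stand-in that shows only the artificial-preemption branch."""

    def __init__(self, enable: bool = ENABLE_ARTIFICIAL_PREEMPT) -> None:
        self.enable_artificial_preemption = enable
        self.artificial_preempt_cnt = (
            ARTIFICIAL_PREEMPTION_MAX_CNT if enable else 0)

    def can_append_slots(self) -> bool:
        # Test-only branch: randomly report "no space" to force preemption.
        if (self.enable_artificial_preemption
                and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB
                and self.artificial_preempt_cnt > 0):
            self.artificial_preempt_cnt -= 1
            return False
        # A real scheduler would now consult the block manager.
        return True


forced = sum(not TinyScheduler(enable=True).can_append_slots()
             for _ in range(100))
print(f"{forced} of 100 checks were artificially preempted")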
is_prefill = False @@ -876,15 +914,6 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def _can_swap_in(self, seq_group: SequenceGroup) -> bool: - # Swapping in is considered decode. - is_prefill = False - - return self.block_manager.can_swap_in( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. # This function call changes the internal states of the scheduler @@ -981,17 +1010,18 @@ def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: def _append_slots( self, seq_group: SequenceGroup, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: List[Tuple[int, int]], ) -> None: """Appends new slots to the sequences in the given sequence group. Args: seq_group (SequenceGroup): The sequence group containing the sequences to append slots to. - blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source - block indices to lists of destination block indices. This - dictionary is updated with the new source and destination block - indices for the appended slots. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. """ num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) @@ -999,9 +1029,8 @@ def _append_slots( cows = self.block_manager.append_slots(seq, num_lookahead_slots) for src, dests in cows.items(): - if src not in blocks_to_copy: - blocks_to_copy[src] = [] - blocks_to_copy[src].extend(dests) + for dest in dests: + blocks_to_copy.append((src, dest)) def _preempt( self, @@ -1116,11 +1145,14 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. """ num_new_tokens = 0 seqs = seq_group.get_seqs(status=status) for seq in seqs: num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 # Chunk if a running request cannot fit in. # If number of seq > 1, it means it is doing beam search in a # decode phase. Do not chunk in that case. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 8b2c26c3a8afb..817bd6d812e48 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: if out is not None: return out if is_pynccl_enabled_for_all_reduce(): - # TODO: support multiple parallel groups. pynccl_utils.all_reduce(input_) else: torch.distributed.all_reduce(input_, @@ -204,6 +203,9 @@ def broadcast_tensor_dict( group=metadata_group) async_handles = [] for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue async_handles.append( torch.distributed.broadcast(tensor, src=src, @@ -225,6 +227,10 @@ def broadcast_tensor_dict( tensor = torch.empty(value.size, dtype=value.dtype, device="cuda") + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
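_append_slots above flattens the block manager's copy-on-write mapping into (src, dst) pairs, which is why the per-queue scheduler outputs can later be combined by plain list concatenation instead of merge_dicts. A minimal sketch of that conversion:

from typing import Dict, List, Tuple


def flatten_cows(cows: Dict[int, List[int]]) -> List[Tuple[int, int]]:
    """Turn {src_block: [dst_blocks]} into a flat list of (src, dst) pairs."""
    blocks_to_copy: List[Tuple[int, int]] = []
    for src, dests in cows.items():
        for dest in dests:
            blocks_to_copy.append((src, dest))
    return blocks_to_copy


assert flatten_cows({3: [7, 8], 5: [9]}) == [(3, 7), (3, 8), (5, 9)]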
+ tensor_dict[key] = tensor + continue async_handle = torch.distributed.broadcast(tensor, src=src, async_op=True, diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index ec4533326e841..cc5f8166877ce 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,10 +1,10 @@ -import os from contextlib import contextmanager from typing import Any, List, Optional import torch import torch.distributed as dist +import vllm.envs as envs from vllm.logger import init_logger try: @@ -54,9 +54,9 @@ def init_custom_ar() -> None: return # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported - if "CUDA_VISIBLE_DEVICES" in os.environ: - device_ids = list( - map(int, os.environ["CUDA_VISIBLE_DEVICES"].split(","))) + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) else: device_ids = list(range(num_dev)) # this checks hardware and driver support for NVLink diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 9434867e1b120..758994352e3de 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -200,6 +200,10 @@ def from_torch(cls, op: ReduceOp) -> int: ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p ] +# be cautious! this is a collective call, it will block until all +# processes in the communicator have called this function. +# because Python object destruction can happen in random order, +# it is better not to call it at all. # equivalent to c declaration: # ncclResult_t ncclCommDestroy(ncclComm_t comm); _c_ncclCommDestroy = nccl.ncclCommDestroy @@ -228,6 +232,7 @@ def __init__( assert dist.get_backend(group) != dist.Backend.NCCL, ( "NCCLCommunicator should be attached to a non-NCCL group.") self.group = group + # note: this rank is the rank in the group self.rank = dist.get_rank(group) self.world_size = dist.get_world_size(group) if self.rank == 0: @@ -235,7 +240,9 @@ def __init__( else: self.unique_id = NcclUniqueId() tensor = torch.ByteTensor(list(self.unique_id.internal)) - dist.broadcast(tensor, src=0, group=group) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) byte_list = tensor.tolist() for i, byte in enumerate(byte_list): self.unique_id.internal[i] = byte @@ -278,11 +285,3 @@ def all_reduce(self, ncclDataTypeEnum.from_torch(tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, ctypes.c_void_p(stream.cuda_stream))) - - def __del__(self): - # `dist` module might have been already destroyed - if hasattr(dist, 'destroy_process_group'): - dist.destroy_process_group() - # function might have been already destroyed - if _c_ncclCommDestroy is not None: - _c_ncclCommDestroy(self.comm) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6ca6fc5b5f9fe..be5bb4e857caf 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -4,17 +4,18 @@ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
"""Tensor and pipeline parallel groups.""" import contextlib -import os from typing import Optional import torch +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) # Tensor model parallel group that the current rank belongs to. -_TENSOR_MODEL_PARALLEL_GROUP = None +_TP_DEVICE_GROUP = None +_TP_CPU_GROUP = None # Pipeline model parallel group that the current rank belongs to. _PIPELINE_MODEL_PARALLEL_GROUP = None @@ -80,7 +81,7 @@ def init_distributed_environment( # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 if local_rank == -1 and distributed_init_method == "env://": - local_rank = int(os.environ['LOCAL_RANK']) + local_rank = envs.LOCAL_RANK global _LOCAL_RANK _LOCAL_RANK = local_rank @@ -132,15 +133,17 @@ def initialize_model_parallel( rank = torch.distributed.get_rank() # Build the tensor model-parallel groups. - global _TENSOR_MODEL_PARALLEL_GROUP - assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( + global _TP_DEVICE_GROUP, _TP_CPU_GROUP + assert _TP_DEVICE_GROUP is None, ( "tensor model parallel group is already initialized") for i in range(num_tensor_model_parallel_groups): ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size) group = torch.distributed.new_group(ranks, backend=backend) + cpu_group = torch.distributed.new_group(ranks, backend="gloo") if rank in ranks: - _TENSOR_MODEL_PARALLEL_GROUP = group + _TP_DEVICE_GROUP = group + _TP_CPU_GROUP = cpu_group # Build the pipeline model-parallel groups. global _PIPELINE_MODEL_PARALLEL_GROUP @@ -185,7 +188,7 @@ def ensure_model_parallel_initialized( def model_parallel_is_initialized(): """Check if tensor and pipeline parallel groups are initialized.""" - return (_TENSOR_MODEL_PARALLEL_GROUP is not None + return (_TP_DEVICE_GROUP is not None and _PIPELINE_MODEL_PARALLEL_GROUP is not None) @@ -197,9 +200,16 @@ def get_cpu_world_group(): def get_tensor_model_parallel_group(): """Get the tensor model parallel group the caller rank belongs to.""" - assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( + assert _TP_DEVICE_GROUP is not None, ( "tensor model parallel group is not initialized") - return _TENSOR_MODEL_PARALLEL_GROUP + return _TP_DEVICE_GROUP + + +def get_tensor_model_parallel_cpu_group(): + """Get the tensor model parallel cpu group the caller rank belongs to.""" + assert _TP_CPU_GROUP is not None, ( + "tensor model parallel cpu group is not initialized") + return _TP_CPU_GROUP def get_pipeline_model_parallel_group(): @@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank(): def destroy_model_parallel(): """Set the groups to none and destroy them.""" - global _TENSOR_MODEL_PARALLEL_GROUP - if _TENSOR_MODEL_PARALLEL_GROUP: - torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) - _TENSOR_MODEL_PARALLEL_GROUP = None + global _TP_DEVICE_GROUP + if _TP_DEVICE_GROUP: + torch.distributed.destroy_process_group(_TP_DEVICE_GROUP) + _TP_DEVICE_GROUP = None + global _TP_CPU_GROUP + if _TP_CPU_GROUP: + torch.distributed.destroy_process_group(_TP_CPU_GROUP) + _TP_CPU_GROUP = None global _PIPELINE_MODEL_PARALLEL_GROUP if _PIPELINE_MODEL_PARALLEL_GROUP: torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 9a13b94c3ada1..1965d4c1d3cbc 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist +import vllm.envs as envs from 
vllm.logger import init_logger from .parallel_state import get_cpu_world_group, get_local_rank @@ -102,11 +103,13 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: is_distributed = dist.is_initialized() num_dev = torch.cuda.device_count() - cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None) + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices is None: cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT path = os.path.expanduser( - f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" + ) os.makedirs(os.path.dirname(path), exist_ok=True) if (not is_distributed or get_local_rank() == 0) \ and (not os.path.exists(path)): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bd6437ee44c28..bb8245eb307f7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1,7 +1,7 @@ import argparse import dataclasses from dataclasses import dataclass -from typing import Optional +from typing import List, Optional, Union from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -11,10 +11,17 @@ from vllm.utils import str_to_int_tuple +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + @dataclass class EngineArgs: """Arguments for vLLM engine.""" model: str + served_model_name: Optional[Union[List[str]]] = None tokenizer: Optional[str] = None skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' @@ -44,7 +51,8 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: bool = False - max_context_len_to_capture: int = 8192 + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 tokenizer_pool_type: str = "ray" @@ -75,6 +83,8 @@ class EngineArgs: speculative_model: Optional[str] = None num_speculative_tokens: Optional[int] = None speculative_max_model_len: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -93,7 +103,7 @@ def add_cli_args( help='Name or path of the huggingface model to use.') parser.add_argument( '--tokenizer', - type=str, + type=nullable_str, default=EngineArgs.tokenizer, help='Name or path of the huggingface tokenizer to use.') parser.add_argument( @@ -102,21 +112,21 @@ def add_cli_args( help='Skip initialization of tokenizer and detokenizer') parser.add_argument( '--revision', - type=str, + type=nullable_str, default=None, help='The specific model version to use. It can be a branch ' 'name, a tag name, or a commit id. If unspecified, will use ' 'the default version.') parser.add_argument( '--code-revision', - type=str, + type=nullable_str, default=None, help='The specific revision to use for the model code on ' 'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'commit id. If unspecified, will use the default version.') parser.add_argument( '--tokenizer-revision', - type=str, + type=nullable_str, default=None, help='The specific tokenizer version to use. It can be a branch ' 'name, a tag name, or a commit id. 
If unspecified, will use ' @@ -133,7 +143,7 @@ def add_cli_args( action='store_true', help='Trust remote code from huggingface.') parser.add_argument('--download-dir', - type=str, + type=nullable_str, default=EngineArgs.download_dir, help='Directory to download and load the weights, ' 'default to the default cache dir of ' @@ -184,7 +194,7 @@ def add_cli_args( 'supported for common inference criteria.') parser.add_argument( '--quantization-param-path', - type=str, + type=nullable_str, default=None, help='Path to the JSON file containing the KV cache ' 'scaling factors. This should generally be supplied, when ' @@ -301,7 +311,7 @@ def add_cli_args( # Quantization settings. parser.add_argument('--quantization', '-q', - type=str, + type=nullable_str, choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.quantization, help='Method used to quantize the weights. If ' @@ -320,6 +330,14 @@ def add_cli_args( default=EngineArgs.max_context_len_to_capture, help='Maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seq_len-to-capture instead' + ')') + parser.add_argument('--max-seq_len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') parser.add_argument('--disable-custom-all-reduce', action='store_true', @@ -338,7 +356,7 @@ def add_cli_args( 'asynchronous tokenization. Ignored ' 'if tokenizer_pool_size is 0.') parser.add_argument('--tokenizer-pool-extra-config', - type=str, + type=nullable_str, default=EngineArgs.tokenizer_pool_extra_config, help='Extra config for tokenizer pool. ' 'This should be a JSON string that will be ' @@ -393,7 +411,7 @@ def add_cli_args( # Related to Vision-language models such as llava parser.add_argument( '--image-input-type', - type=str, + type=nullable_str, default=None, choices=[ t.name.lower() for t in VisionLanguageConfig.ImageInputType @@ -406,7 +424,7 @@ def add_cli_args( help=('Input id for image token.')) parser.add_argument( '--image-input-shape', - type=str, + type=nullable_str, default=None, help=('The biggest image input shape (worst for memory footprint) ' 'given an input type. Only used for vLLM\'s profile_run.')) @@ -429,7 +447,7 @@ def add_cli_args( parser.add_argument( '--speculative-model', - type=str, + type=nullable_str, default=EngineArgs.speculative_model, help= 'The name of the draft model to be used in speculative decoding.') @@ -443,14 +461,28 @@ def add_cli_args( parser.add_argument( '--speculative-max-model-len', - type=str, + type=int, default=EngineArgs.speculative_max_model_len, help='The maximum sequence length supported by the ' 'draft model. Sequences over this length will skip ' 'speculation.') + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + parser.add_argument('--model-loader-extra-config', - type=str, + type=nullable_str, default=EngineArgs.model_loader_extra_config, help='Extra config for model loader. 
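The nullable_str converter applied to the string-valued options above maps an empty string or the literal "None" back to Python None, so explicitly passing such a value behaves like omitting the flag. A short argparse usage sketch reusing one of the options converted above:

import argparse


def nullable_str(val: str):
    # Treat an empty string or the literal "None" as "not provided".
    if not val or val == "None":
        return None
    return val


parser = argparse.ArgumentParser()
parser.add_argument("--download-dir", type=nullable_str, default=None)

assert parser.parse_args(["--download-dir", "None"]).download_dir is None
assert parser.parse_args(["--download-dir", ""]).download_dir is None
assert parser.parse_args(["--download-dir", "/tmp/w"]).download_dir == "/tmp/w"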
' 'This will be passed to the model loader ' @@ -458,6 +490,21 @@ def add_cli_args( 'This should be a JSON string that will be ' 'parsed into a dictionary.') + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument. Noted that this name(s)" + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics" + "tag will take the first one.") + return parser @classmethod @@ -476,7 +523,8 @@ def create_engine_config(self, ) -> EngineConfig: self.code_revision, self.tokenizer_revision, self.max_model_len, self.quantization, self.quantization_param_path, self.enforce_eager, self.max_context_len_to_capture, - self.max_logprobs, self.skip_tokenizer_init) + self.max_seq_len_to_capture, self.max_logprobs, + self.skip_tokenizer_init, self.served_model_name) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, self.kv_cache_dtype, @@ -502,6 +550,8 @@ def create_engine_config(self, ) -> EngineConfig: speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, use_v2_block_manager=self.use_v2_block_manager, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, ) scheduler_config = SchedulerConfig( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7c1eb2ecbe550..37a2dc77a3b50 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,13 +1,14 @@ import asyncio -import os import time from functools import partial -from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, - Optional, Set, Tuple, Type, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, + Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer +import vllm.envs as envs from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -15,12 +16,11 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData +from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) -ENGINE_ITERATION_TIMEOUT_S = int( - os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S class AsyncEngineDeadError(RuntimeError): @@ -210,10 +210,16 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. 
+ execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) output = await self.model_executor.execute_model_async( - seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, - scheduler_outputs.blocks_to_swap_out, - scheduler_outputs.blocks_to_copy) + execute_model_req) else: output = [] @@ -222,8 +228,7 @@ async def step_async(self) -> List[RequestOutput]: scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. - if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs)) + self.do_log_stats(scheduler_outputs, output) return request_outputs @@ -322,7 +327,7 @@ def __init__(self, # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None + self._background_loop_unshielded: Optional[asyncio.Task] = None self.start_engine_loop = start_engine_loop self._errored_with: Optional[BaseException] = None @@ -705,9 +710,13 @@ async def get_decoding_config(self) -> DecodingConfig: else: return self.engine.get_decoding_config() - async def do_log_stats(self) -> None: + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: if self.engine_use_ray: - await self.engine.do_log_stats.remote() # type: ignore + await self.engine.do_log_stats.remote( # type: ignore + scheduler_outputs, model_output) else: self.engine.do_log_stats() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 835803fd4e75d..b9938b045ba2b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,7 +8,8 @@ LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) -from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, + SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats from vllm.engine.output_processor.interfaces import ( @@ -21,8 +22,8 @@ from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, +from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput, + Sequence, SequenceGroup, SequenceGroupMetadata, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, @@ -105,7 +106,7 @@ def __init__( "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, seed=%d)", + "decoding_config=%r, seed=%d, served_model_name=%s)", vllm.__version__, model_config.model, speculative_config, @@ -128,6 +129,7 @@ def __init__( device_config.device, decoding_config, model_config.seed, + model_config.served_model_name, ) # TODO(woosuk): Print more configs in debug mode. 
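step_async above now bundles everything the executor needs for one iteration into a single ExecuteModelRequest instead of passing the swap and copy structures as separate arguments. A rough sketch of that bundling pattern; the field names follow the call site above, but the dataclass below is only a stand-in for the real one in vllm.sequence, which may carry more fields:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class ExecuteModelRequestSketch:
    """Illustrative bundle of the inputs for one execution step."""
    seq_group_metadata_list: List[Any]
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
    num_lookahead_slots: int = 0
    running_queue_size: int = 0


def execute_model(req: ExecuteModelRequestSketch) -> list:
    # A real executor would forward these fields to its workers.
    return []


req = ExecuteModelRequestSketch(
    seq_group_metadata_list=[],
    blocks_to_copy=[(3, 7)],
    running_queue_size=2,
)
output = execute_model(req)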
@@ -218,7 +220,7 @@ def __init__( if self.log_stats: self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.model), + labels=dict(model_name=model_config.served_model_name), max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) @@ -485,7 +487,7 @@ def has_unfinished_requests(self) -> bool: def _process_model_outputs( self, output: List[SamplerOutput], - scheduled_seq_groups: List[SequenceGroup], + scheduled_seq_groups: List[ScheduledSequenceGroup], ignored_seq_groups: List[SequenceGroup], seq_group_metadata_list: List[SequenceGroupMetadata], ) -> List[RequestOutput]: @@ -582,12 +584,16 @@ def step(self) -> List[RequestOutput]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if not scheduler_outputs.is_empty(): - output = self.model_executor.execute_model( + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, blocks_to_copy=scheduler_outputs.blocks_to_copy, - num_lookahead_slots=scheduler_outputs.num_lookahead_slots) + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + ) + output = self.model_executor.execute_model( + execute_model_req=execute_model_req) else: output = [] @@ -596,16 +602,18 @@ def step(self) -> List[RequestOutput]: scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) # Log stats. - if self.log_stats: - self.stat_logger.log( - self._get_stats(scheduler_outputs, model_output=output)) + self.do_log_stats(scheduler_outputs, output) return request_outputs - def do_log_stats(self) -> None: + def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - self.stat_logger.log(self._get_stats(scheduler_outputs=None)) + self.stat_logger.log( + self._get_stats(scheduler_outputs, model_output)) def _get_stats( self, @@ -659,10 +667,10 @@ def _get_stats( # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: num_generation_tokens_from_prefill_groups = 0. - if scheduler_outputs.num_prefill_groups > 0 and len( - scheduler_outputs.scheduled_seq_groups - ) != scheduler_outputs.num_prefill_groups: - print("DETECTED CHUNKED") + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. 
for idx, scheduled_seq_group in enumerate( scheduler_outputs.scheduled_seq_groups): diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 45bfad03ec867..3c4aac91549a9 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -119,7 +119,7 @@ def __init__(self, labelnames: List[str], max_model_len: int): buckets=[1, 2, 5, 10, 20], ) self.counter_request_success = Counter( - name="vllm:request_success", + name="vllm:request_success_total", documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 9abd87a4d5a9a..5f2f433aa811f 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,3 +1,4 @@ +import functools from typing import Callable, List from transformers import PreTrainedTokenizer @@ -8,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Logprob, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -48,10 +49,14 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: # TODO(sang): Prompt logprob currently not implemented in multi step # workers. + self._log_prompt_logprob_unsupported_warning_once() + + @staticmethod + @functools.lru_cache() + def _log_prompt_logprob_unsupported_warning_once(): logger.warning( "Prompt logprob is not supported by multi step workers. " "(e.g., speculative decode uses multi step workers).") - pass def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -89,6 +94,7 @@ def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] # Truncate to max_tokens if necessary. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -113,11 +119,11 @@ def _process_seq_outputs(self, seq: Sequence, # Incrementally append tokens to the sequence, as if we had only one new # token. - for output_token_id in output_token_ids: + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): seq.append_token_id( token_id=output_token_id, - # TODO emit logprobs in multi-step decoding. - logprobs={output_token_id: Logprob(0.0)}, + logprobs=output_logprob, ) new_char_count = 0 diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a78..3ed660e183360 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -69,6 +69,9 @@ class LLM: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. 
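The multi-step output processor above turns a per-call warning into a warn-once helper by wrapping a no-argument method in functools.lru_cache: the first call logs and is cached, and every later call returns the cached None without logging. A standalone sketch of the idiom:

import functools
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("demo")


@functools.lru_cache()
def _warn_prompt_logprob_unsupported() -> None:
    logger.warning("Prompt logprob is not supported by multi step workers.")


for _ in range(3):
    _warn_prompt_logprob_unsupported()  # logs exactly once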
disable_custom_all_reduce: See ParallelConfig @@ -90,7 +93,8 @@ def __init__( gpu_memory_utilization: float = 0.9, swap_space: int = 4, enforce_eager: bool = False, - max_context_len_to_capture: int = 8192, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, **kwargs, ) -> None: @@ -112,6 +116,7 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, **kwargs, ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index af9ba7a3bc825..44a946f2e32d4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1,9 +1,10 @@ import asyncio import importlib import inspect -import os +import re from contextlib import asynccontextmanager from http import HTTPStatus +from typing import Set import fastapi import uvicorn @@ -12,8 +13,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from prometheus_client import make_asgi_app +from starlette.routing import Mount import vllm +import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -31,6 +34,8 @@ openai_serving_completion: OpenAIServingCompletion logger = init_logger(__name__) +_running_tasks: Set[asyncio.Task] = set() + @asynccontextmanager async def lifespan(app: fastapi.FastAPI): @@ -41,7 +46,9 @@ async def _force_log(): await engine.do_log_stats() if not engine_args.disable_log_stats: - asyncio.create_task(_force_log()) + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) yield @@ -55,8 +62,10 @@ def parse_args(): # Add prometheus asgi middleware to route /metrics requests -metrics_app = make_asgi_app() -app.mount("/metrics", metrics_app) +route = Mount("/metrics", make_asgi_app()) +# Workaround for 307 Redirect for /metrics +route.path_regex = re.compile('^/metrics(?P.*)$') +app.routes.append(route) @app.exception_handler(RequestValidationError) @@ -125,7 +134,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request): allow_headers=args.allowed_headers, ) - if token := os.environ.get("VLLM_API_KEY") or args.api_key: + if token := envs.VLLM_API_KEY or args.api_key: @app.middleware("http") async def authentication(request: Request, call_next): diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 16c5b6c08d37f..4c0cb1e4f3e49 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -8,7 +8,7 @@ import json import ssl -from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import LoRAModulePath @@ -25,7 +25,10 @@ def __call__(self, parser, namespace, values, option_string=None): def make_arg_parser(): parser = argparse.ArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") - parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument( 
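The api_server lifespan hook above keeps a strong reference to the background logging task and drops it once the task completes; asyncio's event loop only holds weak references to tasks, so an unreferenced create_task result can be garbage-collected before it finishes. A minimal sketch of the same pattern with a bounded demo task:

import asyncio
from typing import Set

_running_tasks: Set[asyncio.Task] = set()


async def _periodic_log() -> None:
    for _ in range(3):
        await asyncio.sleep(0.01)
        print("stats tick")


async def main() -> None:
    task = asyncio.create_task(_periodic_log())
    # Hold a strong reference so the task cannot be garbage-collected,
    # and remove it automatically once the task is done.
    _running_tasks.add(task)
    task.add_done_callback(_running_tasks.remove)
    await asyncio.sleep(0.1)


asyncio.run(main())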
"--uvicorn-log-level", @@ -49,49 +52,39 @@ def make_arg_parser(): default=["*"], help="allowed headers") parser.add_argument("--api-key", - type=str, + type=nullable_str, default=None, help="If provided, the server will require this key " "to be presented in the header.") - parser.add_argument("--served-model-name", - nargs="+", - type=str, - default=None, - help="The model name(s) used in the API. If multiple " - "names are provided, the server will respond to any " - "of the provided names. The model name in the model " - "field of a response will be the first name in this " - "list. If not specified, the model name will be the " - "same as the `--model` argument.") parser.add_argument( "--lora-modules", - type=str, + type=nullable_str, default=None, nargs='+', action=LoRAParserAction, help="LoRA module configurations in the format name=path. " "Multiple modules can be specified.") parser.add_argument("--chat-template", - type=str, + type=nullable_str, default=None, help="The file path to the chat template, " "or the template in single-line form " "for the specified model") parser.add_argument("--response-role", - type=str, + type=nullable_str, default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.") parser.add_argument("--ssl-keyfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL key file") parser.add_argument("--ssl-certfile", - type=str, + type=nullable_str, default=None, help="The file path to the SSL cert file") parser.add_argument("--ssl-ca-certs", - type=str, + type=nullable_str, default=None, help="The CA certificates file") parser.add_argument( @@ -102,12 +95,12 @@ def make_arg_parser(): ) parser.add_argument( "--root-path", - type=str, + type=nullable_str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") parser.add_argument( "--middleware", - type=str, + type=nullable_str, action="append", default=[], help="Additional ASGI middleware to apply to the app. " diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0a949f9867754..3cd9ddad3b7b7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -79,7 +79,9 @@ class ChatCompletionRequest(OpenAIBaseModel): n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 response_format: Optional[ResponseFormat] = None - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False temperature: Optional[float] = 0.7 @@ -146,6 +148,11 @@ class ChatCompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. 
If set, must be either " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-chat-completion-extra-params @@ -223,7 +230,9 @@ class CompletionRequest(OpenAIBaseModel): max_tokens: Optional[int] = 16 n: int = 1 presence_penalty: Optional[float] = 0.0 - seed: Optional[int] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False suffix: Optional[str] = None @@ -285,6 +294,11 @@ class CompletionRequest(OpenAIBaseModel): "If specified, will override the default guided decoding backend " "of the server for this specific request. If set, must be one of " "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) # doc: end-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5ed042ef386ea..c8f4a6b315db0 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,3 +1,4 @@ +import asyncio import codecs import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, @@ -40,9 +41,11 @@ def __init__(self, chat_template: Optional[str] = None): super().__init__(engine=engine, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + await_post_init=self._load_chat_template( + chat_template=chat_template)) + self.response_role = response_role - self._load_chat_template(chat_template) def _parse_chat_message_content( self, @@ -55,9 +58,16 @@ def _parse_chat_message_content( if isinstance(content, str): return [ConversationMessage(role=role, content=content)], [] - # To be implemented: https://github.com/vllm-project/vllm/pull/3467 - # To be implemented: https://github.com/vllm-project/vllm/pull/4200 - raise NotImplementedError("Complex input not supported yet") + texts: List[str] = [] + for _, part in enumerate(content): + if part["type"] == "text": + text = part["text"] + + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part['type']}") + + return [ConversationMessage(role=role, content="\n".join(texts))], [] async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request @@ -122,11 +132,12 @@ async def create_chat_completion( # Streaming response if request.stream: return self.chat_completion_stream_generator( - request, result_generator, request_id) + request, result_generator, request_id, conversation) else: try: return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id) + request, raw_request, result_generator, request_id, + conversation) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -139,8 +150,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str) -> AsyncGenerator[str, None]: + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> 
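_parse_chat_message_content above starts accepting OpenAI-style list-of-parts message content and joins the text parts with newlines instead of rejecting complex input outright. A simplified, self-contained sketch of that normalization (function and variable names are illustrative):

from typing import Dict, List, Union


def normalize_content(content: Union[str, List[Dict[str, str]]]) -> str:
    """Collapse OpenAI-style message content into a single text string."""
    if isinstance(content, str):
        return content
    texts: List[str] = []
    for part in content:
        if part["type"] == "text":
            texts.append(part["text"])
        else:
            raise NotImplementedError(f"Unknown part type: {part['type']}")
    return "\n".join(texts)


assert normalize_content("hello") == "hello"
assert normalize_content([
    {"type": "text", "text": "hello"},
    {"type": "text", "text": "world"},
]) == "hello\nworld"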
AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" @@ -179,12 +191,10 @@ async def chat_completion_stream_generator( # last message if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, - list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get( + "role") == role: + last_msg_content = conversation[-1]["content"] if last_msg_content: for i in range(request.n): @@ -279,9 +289,10 @@ async def chat_completion_stream_generator( yield "data: [DONE]\n\n" async def chat_completion_full_generator( - self, request: ChatCompletionRequest, raw_request: Request, - result_generator: AsyncIterator[RequestOutput], - request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + self, request: ChatCompletionRequest, raw_request: Request, + result_generator: AsyncIterator[RequestOutput], request_id: str, + conversation: List[ConversationMessage] + ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] created_time = int(time.time()) @@ -322,11 +333,9 @@ async def chat_completion_full_generator( if request.echo: last_msg_content = "" - if request.messages and isinstance( - request.messages, list) and request.messages[-1].get( - "content") and request.messages[-1].get( - "role") == role: - last_msg_content = request.messages[-1]["content"] + if conversation and conversation[-1].get( + "content") and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] for choice in choices: full_message = last_msg_content + choice.message.content @@ -350,7 +359,10 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template: Optional[str]): + async def _load_chat_template(self, chat_template: Optional[str]): + while self.tokenizer is None: + # Give the parent class time to load the tokenizer + await asyncio.sleep(0.1) tokenizer = self.tokenizer if chat_template is not None: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3d5ed328b9d19..21baea2e5e7f6 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -29,8 +29,11 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__(self, + engine: AsyncLLMEngine, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + await_post_init: Optional[Awaitable[Any]] = None): self.engine = engine self.served_model_names = served_model_names if lora_modules is None: @@ -56,12 +59,12 @@ def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], if event_loop is not None and event_loop.is_running(): # If the current is instanced by Ray Serve, # there is already a running event loop - event_loop.create_task(self._post_init()) + event_loop.create_task(self._post_init(await_post_init)) else: # 
When using single vLLM without engine_use_ray - asyncio.run(self._post_init()) + asyncio.run(self._post_init(await_post_init)) - async def _post_init(self): + async def _post_init(self, await_post_init): engine_model_config = await self.engine.get_model_config() self.max_model_len = engine_model_config.max_model_len @@ -73,6 +76,9 @@ async def _post_init(self): trust_remote_code=engine_model_config.trust_remote_code, truncation_side="left") + if await_post_init is not None: + await await_post_init + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000000000..91cc8f3be775c --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,217 @@ +import os +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_USE_MODELSCOPE: bool = False + VLLM_INSTANCE_ID: Optional[str] = None + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = False + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_CONFIG_ROOT: str = "" + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_BUILD_WITH_NEURON: bool = False + VLLM_USE_PRECOMPILED: bool = False + VLLM_INSTALL_PUNICA_KERNELS: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: Dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), rocm, neuron, cpu] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. 
+ "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will build with Neuron support + "VLLM_BUILD_WITH_NEURON": + lambda: bool(os.environ.get("VLLM_BUILD_WITH_NEURON", False)), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), + + # If set, vllm will install Punica kernels + "VLLM_INSTALL_PUNICA_KERNELS": + lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for VLLM configuration files + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.environ.get("VLLM_CONFIG_ROOT", None) or os.getenv( + "XDG_CONFIG_HOME", None) or os.path.expanduser("~/.config"), + + # ================== Runtime Env Vars ================== + + # used in distributed environment to determine the master address + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + + # If true, will load models from ModelScope instead of Hugging Face Hub. + # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Instance id represents an instance of the VLLM. All processes in the same + # instance should have the same instance id. + "VLLM_INSTANCE_ID": + lambda: os.environ.get("VLLM_INSTANCE_ID", None), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. 
It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in + ("true", "1")), + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + + # API key for VLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + "VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # CPU key-value cache space + # default is 4GB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)), + + # Use dedicated multiprocess context for workers. 
+ # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), +} + +# end-env-vars-definition + + +def __getattr__(name): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index e4436b2144bd3..a2212459f034e 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,13 +1,13 @@ -import os -from typing import Dict, List, Set, Tuple +from typing import List, Set, Tuple import torch +import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -72,18 +72,10 @@ def initialize_cache(self, num_gpu_blocks: int, logger.info("# CPU blocks: %d", num_gpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -104,17 +96,10 @@ def check_health(self) -> None: class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) return output async def check_health_async(self) -> None: @@ -150,8 +135,7 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: logger.warning("Prefix caching is not supported on CPU, disable it.") config.enable_prefix_caching = False - kv_cache_space_str = os.getenv("VLLM_CPU_KVCACHE_SPACE", "0") - kv_cache_space = int(kv_cache_space_str) + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: if kv_cache_space == 0: diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index c36aa18fb25bb..08aa58999b1ec 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -1,11 +1,11 @@ from abc import ABC, abstractmethod -from typing 
import Dict, List, Optional, Set, Tuple +from typing import List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VisionLanguageConfig) from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput class ExecutorBase(ABC): @@ -68,12 +68,9 @@ def initialize_cache(self, num_gpu_blocks: int, raise NotImplementedError @abstractmethod - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes at least one model step on the given sequences.""" raise NotImplementedError @@ -107,12 +104,8 @@ class ExecutorAsyncBase(ExecutorBase): @abstractmethod async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 5ac62f02b99c7..1af3bcf380843 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,11 +1,12 @@ -from typing import Dict, List, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -23,30 +24,47 @@ def _init_executor(self) -> None: else: self._init_spec_worker() - def _init_non_spec_worker(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker - - assert self.parallel_config.world_size == 1, ( - "GPUExecutor only supports single GPU.") - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, load_config=self.load_config, - local_rank=0, - rank=0, + local_rank=local_rank, + rank=rank, distributed_init_method=distributed_init_method, lora_config=self.lora_config, vision_language_config=self.vision_language_config, - is_driver_worker=True, + is_driver_worker=rank == 0, ) + + def _create_worker(self, + local_rank: int = 0, + rank: 
int = 0, + distributed_init_method: Optional[str] = None): + wrapper = WorkerWrapperBase( + worker_module_name="vllm.worker.worker", + worker_class_name="Worker", + ) + wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, + distributed_init_method)) + return wrapper.worker + + def _init_non_spec_worker(self): + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() self.driver_worker.init_device() self.driver_worker.load_model() @@ -55,46 +73,23 @@ def _init_spec_worker(self): """ assert self.speculative_config is not None - from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker - from vllm.worker.worker import Worker - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - target_worker = Worker( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, - ) + target_worker = self._create_worker() - draft_worker = MultiStepWorker( + draft_worker_kwargs = self._get_worker_kwargs() + # Override draft-model specific worker args. + draft_worker_kwargs.update( model_config=self.speculative_config.draft_model_config, parallel_config=self.speculative_config.draft_parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, # TODO allow draft-model specific load config. 
- load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=True, + #load_config=self.load_config, ) - spec_decode_worker = SpecDecodeWorker.from_workers( - proposer_worker=draft_worker, scorer_worker=target_worker) + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + ) assert self.parallel_config.world_size == 1, ( "GPUExecutor only supports single GPU.") @@ -122,20 +117,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: - output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_lookahead_slots=num_lookahead_slots, - ) + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -159,14 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy) + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) return output diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000000000..62887533f5c27 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,263 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + +mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD +mp = multiprocessing.get_context(mp_method) + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + 
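(Illustrative note, not part of the diff.) The ResultFuture being added here is a synchronous future built on threading.Event: the background ResultHandler thread defined further down calls set_result(), and the submitting thread blocks in get() until then. A minimal standalone sketch of that cross-thread handoff, with invented names:

import threading

class ToyFuture(threading.Event):
    # Blocks the calling thread until another thread delivers a value.

    def __init__(self):
        super().__init__()
        self.value = None

    def set_result(self, value):
        self.value = value
        self.set()  # wake any thread blocked in get()

    def get(self):
        self.wait()  # park the caller until set_result() runs
        return self.value

future = ToyFuture()
threading.Thread(target=lambda: future.set_result(42)).start()
assert future.get() == 42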
+ def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = mp.Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + self._task_queue = mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except BaseException as e: + del self.tasks[task_id] + 
raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = mp.current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory() + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(worker, method) + output = executor(*args, **kwargs) + except BaseException as e: + tb = traceback.format_exc() + logger.error( + "Exception in worker %s while processing method %s: %s, %s", + process_name, method, e, tb) + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index f406287f3c1d8..e7f0e887921b7 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -1,9 +1,9 @@ -from typing import Dict, List, Set, Tuple +from typing import List, Set, Tuple from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import make_async logger = init_logger(__name__) @@ -45,20 +45,18 @@ def initialize_cache(self, num_gpu_blocks: int, """ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: 
Dict[int, List[int]], - num_lookahead_slots: int) -> List[SamplerOutput]: - assert (blocks_to_swap_in == {} and blocks_to_swap_out == {} - and blocks_to_copy == {}), ( + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + assert (execute_model_req.blocks_to_swap_in == {} + and execute_model_req.blocks_to_swap_out == {} + and execute_model_req.blocks_to_copy == {}), ( "Cache operations are not supported for Neuron backend.") - assert num_lookahead_slots == 0, ( + assert execute_model_req.num_lookahead_slots == 0, ( "lookahead not supported for Neuron backend.") output = self.driver_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) return output def add_lora(self, lora_request: LoRARequest) -> bool: @@ -80,13 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): async def execute_model_async( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model)( - seq_group_metadata_list=seq_group_metadata_list, ) + output = await make_async( + self.driver_worker.execute_model + )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, ) return output async def check_health_async(self) -> None: diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index b6bcda4e6b18c..afc1c886722e6 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -5,11 +5,12 @@ from itertools import islice, repeat from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +import vllm.envs as envs from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -21,10 +22,7 @@ logger = init_logger(__name__) -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) +USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG class RayGPUExecutor(DistributedGPUExecutor): @@ -145,7 +143,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", "VLLM_INSTANCE_ID": VLLM_INSTANCE_ID, "VLLM_TRACE_FUNCTION": - os.getenv("VLLM_TRACE_FUNCTION", "0"), + str(envs.VLLM_TRACE_FUNCTION), }, ) for (node_id, _) in worker_node_and_gpu_ids] self._run_workers("update_environment_variables", all_args=all_args_to_update_environment_variables) @@ -153,29 +151,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", distributed_init_method = get_distributed_init_method( driver_ip, get_open_port()) - def collect_arg_helper_func(**kwargs): - # avoid writing `{"name": value}` manually - return kwargs - # Initialize the actual workers inside worker wrapper. 
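(Illustrative note, not part of the diff.) The common thread across these executor changes is that the separate seq_group_metadata_list / blocks_to_swap_in / blocks_to_swap_out / blocks_to_copy / num_lookahead_slots arguments are folded into a single ExecuteModelRequest imported from vllm.sequence. Its real definition is not shown in this diff; the sketch below only mirrors the fields the call sites above access, as an orientation aid:

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class ExecuteModelRequestSketch:
    # Hypothetical stand-in for vllm.sequence.ExecuteModelRequest.
    seq_group_metadata_list: List[Any]  # sequence groups scheduled for this step
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
    num_lookahead_slots: int = 0

# Executors now accept the single request object:
#     output = self.driver_worker.execute_model(execute_model_req)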
- init_worker_all_kwargs = [] - for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids): - local_rank = node_workers[node_id].index(rank) - init_worker_all_kwargs.append( - collect_arg_helper_func( - model_config=self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - device_config=self.device_config, - cache_config=self.cache_config, - load_config=self.load_config, - local_rank=local_rank, - rank=rank, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - vision_language_config=self.vision_language_config, - is_driver_worker=rank == 0, - )) + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") @@ -183,25 +166,16 @@ def collect_arg_helper_func(**kwargs): max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) - def execute_model(self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int = 0) -> List[SamplerOutput]: + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: all_outputs = self._run_workers( "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }, + driver_kwargs={"execute_model_req": execute_model_req}, use_ray_compiled_dag=USE_RAY_COMPILED_DAG) # Only the driver worker returns the sampling results. - output = all_outputs[0] - return output + return all_outputs[0] def _run_workers( self, diff --git a/vllm/logger.py b/vllm/logger.py index 3928e5367d1e6..153cdfb373bb4 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -1,73 +1,93 @@ -# Adapted from -# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py """Logging configuration for vLLM.""" import datetime +import json import logging import os import sys from functools import partial -from typing import Optional +from logging import Logger +from logging.config import dictConfig +from os import path +from typing import Dict, Optional -VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": "INFO", + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, +} + + +def _configure_vllm_root_logger() -> None: + logging_config: Optional[Dict] = None + + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. 
VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") -class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG - def __init__(self, fmt, datefmt=None): - logging.Formatter.__init__(self, fmt, datefmt) + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8", + mode="r") as file: + custom_config = json.loads(file.read()) - def format(self, record): - msg = logging.Formatter.format(self, record) - if record.message != "": - parts = msg.split(record.message) - msg = msg.replace("\n", "\r\n" + parts[0]) - return msg + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. Expected Dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + if logging_config: + dictConfig(logging_config) -_root_logger = logging.getLogger("vllm") -_default_handler: Optional[logging.Handler] = None +def init_logger(name: str) -> Logger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" -def _setup_logger(): - _root_logger.setLevel(logging.DEBUG) - global _default_handler - if _default_handler is None: - _default_handler = logging.StreamHandler(sys.stdout) - _default_handler.flush = sys.stdout.flush # type: ignore - _default_handler.setLevel(logging.INFO) - _root_logger.addHandler(_default_handler) - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) - _default_handler.setFormatter(fmt) - # Setting this will avoid the message - # being propagated to the parent logger. - _root_logger.propagate = False + return logging.getLogger(name) -# The logger is initialized when the module is imported. +# The root logger is initialized when the module is imported. # This is thread-safe as the module is only imported once, # guaranteed by the Python GIL. -if VLLM_CONFIGURE_LOGGING: - _setup_logger() - - -def init_logger(name: str): - # Use the same settings as above for root logger - logger = logging.getLogger(name) - logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) - - if VLLM_CONFIGURE_LOGGING: - if _default_handler is None: - raise ValueError( - "_default_handler is not set up. This should never happen!" 
- " Please open an issue on Github.") - logger.addHandler(_default_handler) - logger.propagate = False - return logger - +_configure_vllm_root_logger() logger = init_logger(__name__) diff --git a/vllm/logging/__init__.py b/vllm/logging/__init__.py new file mode 100644 index 0000000000000..b9aec380776f3 --- /dev/null +++ b/vllm/logging/__init__.py @@ -0,0 +1,5 @@ +from vllm.logging.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging/formatter.py b/vllm/logging/formatter.py new file mode 100644 index 0000000000000..b24b4e11d1fcb --- /dev/null +++ b/vllm/logging/formatter.py @@ -0,0 +1,15 @@ +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 53efebb604048..8403604286903 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor( result = await loop.run_in_executor(global_thread_pool, _get_cached_logits_processor, guide, - tokenizer, mode) + tokenizer, mode, + request.guided_whitespace_pattern) logits_processor = copy(result) # reset logits processor's internal state @@ -117,9 +118,10 @@ def _get_guide_and_mode( @lru_cache(maxsize=32) def _get_cached_logits_processor(guide: str, tokenizer: PreTrainedTokenizerBase, - mode: GuidedDecodingMode): + mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None]): if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, tokenizer) + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.GRAMMAR: diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 25ab5bf8b6a9c..a131c6a1b92b4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -18,7 +18,7 @@ import math from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Optional, Union +from typing import Callable, DefaultDict, Dict, List, Union import torch from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM @@ -80,10 +80,9 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, - schema: Union[str, Dict, BaseModel], + def __init__(self, schema: Union[str, Dict, BaseModel], tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Optional[str] = None): + whitespace_pattern: Union[str, None]): """Compile the FSM that drives the JSON-guided generation. 
Parameters diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index baf1d4f266181..d101aa323b0e1 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -67,6 +67,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ops.gelu_tanh_and_mul(out, x) return out + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + class NewGELU(nn.Module): diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json new file mode 100644 index 0000000000000..9287808a94d0e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json @@ -0,0 +1,140 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index b4f81527141a8..3cb0419404625 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -203,14 +203,15 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ - sorted_ids = torch.empty( - (topk_ids.numel() + num_experts * (block_size - 1), ), - dtype=torch.int32, - device=topk_ids.device) - expert_ids = torch.empty((topk_ids.numel() + num_experts, ), + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index a6619714b8aab..8de0794158986 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -64,3 +64,8 @@ def forward( self.variance_epsilon, ) return out + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 4d43ed4c5f14a..7726dcb9a5fbd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -181,6 +181,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -246,6 +252,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) param_data = param.data @@ -254,6 +264,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -271,6 +287,14 @@ def forward(self, input_): output_bias = self.bias if self.skip_bias_add else None return output, output_bias + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. @@ -317,7 +341,12 @@ def weight_loader(self, param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. 
+ fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + if loaded_shard_id is None: # Loaded weight is already packed. if output_dim is None: @@ -331,14 +360,13 @@ def weight_loader(self, current_shard_offset += output_size packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -353,15 +381,14 @@ def weight_loader(self, if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -370,11 +397,17 @@ def weight_loader(self, start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_offset = loaded_shard_id * shard_size param_data = param_data.narrow(0, shard_offset, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -455,7 +488,11 @@ def weight_loader(self, loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. is_metadata = getattr(param, "is_metadata", False) + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) if loaded_shard_id is None: # Loaded weight is already packed. @@ -473,14 +510,14 @@ def weight_loader(self, ] packed_dim = getattr(param, "packed_dim", None) for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -502,6 +539,7 @@ def weight_loader(self, shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. # If quantized, we need to adjust the offset and size to account # for the packing. 
packed_dim = getattr(param, "packed_dim", None) @@ -509,8 +547,7 @@ def weight_loader(self, shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor - # If marlin, we need to adjust the offset and size to - # account for the tiling. + # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) @@ -523,12 +560,17 @@ def weight_loader(self, start_idx = shard_id * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. elif is_metadata: # metadata indicates fixed size concatenated along dim 0 shard_size = loaded_weight.shape[0] shard_index = ["q", "k", "v"].index(loaded_shard_id) param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) else: ignore_warning = getattr(param, "ignore_warning", False) if not ignore_warning: @@ -611,6 +653,10 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + tp_rank = get_tensor_model_parallel_rank() input_dim = getattr(param, "input_dim", None) param_data = param.data @@ -619,6 +665,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer(param_data, + loaded_weight, + shard_id=0) + assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -647,3 +699,11 @@ def forward(self, input_): output = output_ output_bias = self.bias return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 22620d9fc86d9..91eb96998c3cf 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -70,6 +70,12 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, logits = logits[:, :self.org_vocab_size] return logits + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + def _prune_hidden_states( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index ba9f3149649c1..b57e1dde81a5f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,23 +1,36 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.linear import LinearBase,
LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) + QuantizationConfig) from vllm.model_executor.utils import set_weight_attrs +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + class Fp8Config(QuantizationConfig): """Config class for FP8.""" def __init__( self, + is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme @classmethod @@ -30,10 +43,7 @@ def get_supported_act_dtypes(cls) -> List[torch.dtype]: @classmethod def get_min_capability(cls) -> int: - # TODO: PyTorch 2.3.0+ is required to run FP8 on - # SM 89 (e.g. Ada) GPUs. Specifically, this PR has to - # be included: https://github.com/pytorch/pytorch/pull/118881 - return 90 + return 89 @classmethod def get_config_filenames(cls) -> List[str]: @@ -41,11 +51,14 @@ def get_config_filenames(cls) -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls(activation_scheme) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme) def get_quant_method( - self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: if isinstance(layer, LinearBase): return Fp8LinearMethod(self) return None @@ -56,8 +69,12 @@ def get_scaled_act_names(self) -> List[str]: class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. - We now support common FP16/BF16 model checkpoints ONLY. The weight - scaling factor will be initialized after the model weights are loaded. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. 
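(Illustrative note, not part of the diff.) "Per-tensor" here means a single scale for an entire weight (or activation) tensor, which is what torch._scaled_mm accepts. A small self-contained sketch of the quantize/dequantize round trip that the per_tensor_quantize / per_tensor_dequantize helpers added at the bottom of this file perform (requires a PyTorch build with float8_e4m3fn support; names and values are illustrative):

import torch

def quantize_per_tensor(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # One scale for the whole tensor; clamp to the fp8 e4m3 representable range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (t / scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)

def dequantize_per_tensor(t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Upcast and rescale; only approximately recovers the original values.
    return t.to(torch.float16) * scale

w = torch.randn(16, 16, dtype=torch.float16)
scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max  # dynamic per-tensor scale
w_fp8 = quantize_per_tensor(w, scale)        # stored weight plus one scalar scale
w_ref = dequantize_per_tensor(w_fp8, scale)  # close to w, up to fp8 rounding error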
@@ -71,6 +88,24 @@ class Fp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.float32), + requires_grad=False) + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, { + **extra_weight_attrs, + "fp8_scales_shard_indexer": + self.scales_shard_indexer, + }) + def create_weights( self, layer: torch.nn.Module, @@ -81,46 +116,150 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + del input_size, output_size output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - dtype=params_dtype), + dtype=weight_dtype), requires_grad=False) layer.register_parameter("weight", weight) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - set_weight_attrs(weight, extra_weight_attrs) + set_weight_attrs(weight, { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }) - w_scale = Parameter( - torch.empty(1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("weight_scaling_factor", w_scale) + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + # ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="act_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs) + + def scales_shard_indexer( + self, param: torch.Tensor, loaded_weight: torch.Tensor, + shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]: + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, int): + pass + elif isinstance(shard_id, str): + if shard_id not in qkv_idxs: + raise ValueError(f"Unknown shard_id: {shard_id}") + shard_id = qkv_idxs[shard_id] + else: + raise ValueError(f"Shard id must be int or str but got {type(shard_id)}") + + return param[shard_id], loaded_weight def process_weights_after_loading(self, layer: Module) -> None: - # Although the quant_method is propagated to all layers, - # only linear layers invoke "create_weights". So we check - # whether "weight_scaling_facor" is registered to determine - # whether the layer is a linear layer that requires quantization. - if not hasattr(layer, "weight_scaling_factor"): + if (not hasattr(layer, "process_after_load") + or not layer.process_after_load): + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
+ if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.act_scale = None return - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight) - # torch._scaled_mm requires column-major in the second - # input (weight), so we transpose the quantized weight. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scaling_factor.data.copy_(weight_scale) + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(layer.weight[start:end, :], + layer.weight_scale[idx]) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max()) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # ACT_SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the act_scales (since they are equal). + if self.quant_config.activation_scheme == "dynamic": + layer.act_scale = None + elif self.quant_config.activation_scheme == "static": + if not all_close_1d(layer.act_scale): + raise ValueError( + "All the act_scales for the logical weights of a layer " + f"must be equal. But got {layer.act_scale}") + layer.act_scale = Parameter(layer.act_scale.max(), + requires_grad=False) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}") def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, x_scale = ops.scaled_fp8_quant(x) + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.act_scale is None and x_scale computed from x. + # If static, layer.act_scale is scalar and x_scale set to act_scale. 
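Roughly what the activation quantization in apply() amounts to, sketched in plain PyTorch. The real ops.scaled_fp8_quant is a fused CUDA kernel; the helper name and epsilon guard below are illustrative assumptions.

import torch
from typing import Optional, Tuple

def quant_input_sketch(
        x: torch.Tensor,
        static_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    if static_scale is None:
        # Dynamic scheme: derive a per-tensor scale from this batch of
        # activations (act_scale is None on the layer).
        x_scale = (x.abs().amax().float() / finfo.max).clamp(min=1e-12)
    else:
        # Static scheme: reuse the calibrated scalar stored as layer.act_scale.
        x_scale = static_scale
    qx = (x.float() / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # qx and x_scale are what the fused GEMM below consumes as its A operand
    # and scale_a, alongside the transposed fp8 weight and weight_scale.
    return qx, x_scale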
+ qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) + + # Fused GEMM_DQ output, _ = torch._scaled_mm( qinput, layer.weight, out_dtype=x.dtype, scale_a=x_scale, - scale_b=layer.weight_scaling_factor, + scale_b=layer.weight_scale, bias=bias, ) + return output + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def per_tensor_quantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize(tensor: torch.Tensor, + inv_scale: float) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index efbffa0878c4b..e2464008a875f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional -import numpy import torch from torch.nn.parameter import Parameter @@ -17,41 +16,13 @@ GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MAX_PARALLEL = 16 -GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] GPTQ_MARLIN_SUPPORTED_SYM = [True] -# Precompute permutations for Marlin weight and scale shuffling -# -# Marlin works on [16,64] tiles. The goal of the permutations -# is to reorder the weight data so that it is compatible -# with the tensor-core format that is described here: -# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 -# -# As a result of this reordering, the vector loads inside the -# kernel will get the data as it is needed for tensor-core -# (without the need to use ldmatrix instructions) -def _get_perms(): - perm = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm) - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore - perm = torch.from_numpy(perm) +# Permutations for Marlin scale shuffling +def get_scale_perms(num_bits): scale_perm = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) @@ -59,23 +30,21 @@ def _get_perms(): for i in range(4): scale_perm_single.extend( [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm, _scale_perm, _scale_perm_single = _get_perms() + return scale_perm, scale_perm_single def get_pack_factor(num_bits): - assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( - f"Unsupported num_bits = {num_bits}") + assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" return 32 // num_bits -def marlin_permute_scales(s, size_k, size_n, group_size): +def marlin_permute_scales(s, size_k, size_n, group_size, num_bits): + scale_perm, scale_perm_single = get_scale_perms(num_bits) if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] + s 
= s.reshape((-1, len(scale_perm)))[:, scale_perm] else: - s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] s = s.reshape((-1, size_n)).contiguous() return s @@ -279,13 +248,15 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - qweight, { + qweight, + { **extra_weight_attrs, "input_dim": 0, "output_dim": 1, "packed_dim": 0, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Activation order g_idx = Parameter( @@ -296,10 +267,13 @@ def create_weights( requires_grad=False, ) # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx, { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }) + set_weight_attrs( + g_idx, + { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }, + ) g_idx_sort_indices = Parameter( torch.empty( @@ -320,29 +294,34 @@ def create_weights( requires_grad=False, ) set_weight_attrs( - scales, { + scales, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, - }) + }, + ) # Quantized zero-points qzeros = Parameter( - torch.empty(scales_and_zp_size, - output_size_per_partition // - self.quant_config.pack_factor, - dtype=torch.int32, - device="meta"), + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + device="meta", + ), requires_grad=False, ) set_weight_attrs( - qzeros, { + qzeros, + { **extra_weight_attrs, "input_dim": scales_and_zp_input_dim, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - }) + }, + ) # Allocate marlin workspace max_workspace_size = ( @@ -405,13 +384,14 @@ def replace_tensor(name, new_t): else: # Reset g_idx related tensors - layer.g_idx = Parameter(torch.empty(0, - dtype=torch.int, - device=cur_device), - requires_grad=False) - layer.g_idx_sort_indices = Parameter(torch.empty( - 0, dtype=torch.int, device=cur_device), - requires_grad=False) + layer.g_idx = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) + layer.g_idx_sort_indices = Parameter( + torch.empty(0, dtype=torch.int, device=cur_device), + requires_grad=False, + ) # Repack weights marlin_qweight = ops.gptq_marlin_repack( @@ -419,6 +399,7 @@ def replace_tensor(name, new_t): layer.g_idx_sort_indices, part_size_k, part_size_n, + self.quant_config.weight_bits, ) replace_tensor("qweight", marlin_qweight) @@ -428,15 +409,28 @@ def replace_tensor(name, new_t): if self.quant_config.desc_act: scales_size_k = full_size_k - marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, - scales_size_n, - self.quant_config.group_size) + marlin_scales = marlin_permute_scales( + layer.scales, + scales_size_k, + scales_size_n, + self.quant_config.group_size, + self.quant_config.weight_bits, + ) replace_tensor("scales", marlin_scales) - output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, - layer.g_idx, layer.g_idx_sort_indices, - layer.workspace, size_m, part_size_n, - part_size_k, layer.is_k_full) + output = ops.gptq_marlin_gemm( + reshaped_x, + layer.qweight, + layer.scales, + layer.g_idx, + layer.g_idx_sort_indices, + layer.workspace, + self.quant_config.weight_bits, + size_m, + part_size_n, + part_size_k, + layer.is_k_full, + ) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 25365a9b50a1f..857d70fadcb57 100644 --- 
a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -156,6 +156,12 @@ def forward( self.cos_sin_cache, self.is_neox_style) return query, key + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + class LinearScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with linear scaling. diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index d79c99e5d0a45..1f19d2053d996 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -103,8 +103,7 @@ def forward( if self.include_gpu_probs_tensor: assert maybe_sampled_tokens_tensor is not None - sampled_tokens_tensor = maybe_sampled_tokens_tensor - on_device_tensors = (probs, sampled_tokens_tensor) + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: on_device_tensors = None @@ -965,8 +964,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, has implications on the overall design of the sampler, e.g. how to record accurate logprobs for the user, so this improvement is deferred to later. """ - logprobs[sample_indices, :] = -float('inf') - logprobs[sample_indices, greedy_samples] = 0.0 + # NOTE: logprobs are not modified so they can be returned to the user. probs[sample_indices, :] = 0 probs[sample_indices, greedy_samples] = 1.0 @@ -976,7 +974,8 @@ def _build_sampler_output( sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], - on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], ) -> SamplerOutput: """Construct Python objects with the output of sampling. @@ -1005,14 +1004,17 @@ def _build_sampler_output( # If not specified, store None values in SamplerOutput. if on_device_tensors is not None: - sampled_token_probs, sampled_token_ids = on_device_tensors + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors else: - sampled_token_probs, sampled_token_ids = (None, None) + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) return SamplerOutput( outputs=sampler_output, sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, ) @@ -1033,8 +1035,8 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: assert seq_group.is_prompt, ( "Caller should ensure the sequence group is in a prefill stage.") seq_ids = seq_group.seq_ids - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None # prompt has only 1 seq id. assert len(seq_ids) == 1 seq_data = seq_group.seq_data[seq_ids[0]] @@ -1042,7 +1044,7 @@ def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: prompt_tokens = seq_data.prompt_token_ids # +1 because we are looking for a next prompt token. 
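A small worked example of the index arithmetic on the next two lines (numbers are illustrative, not from the patch): with chunked prefill, the prompt logprobs a chunk can report are for the tokens one position ahead of the chunk currently being computed.

prompt_tokens = [11, 12, 13, 14, 15, 16, 17, 18]
computed_len = 2   # prompt tokens already processed in earlier chunks
query_len = 4      # prompt tokens being processed in this step

next_token_index_start = computed_len + 1                      # 3
next_token_index_end = min(computed_len + query_len + 1,
                           len(prompt_tokens))                 # 7
# "Next" prompt tokens whose logprobs this chunk can report:
print(prompt_tokens[next_token_index_start:next_token_index_end])  # [14, 15, 16, 17]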
next_token_index_start = computed_len + 1 - next_token_index_end = min(computed_len + subquery_len + 1, + next_token_index_end = min(computed_len + query_len + 1, len(prompt_tokens)) next_prompt_tokens = prompt_tokens[ next_token_index_start:next_token_index_end] diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 088c0849243c0..4585b1679cb5c 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -105,6 +105,14 @@ def forward(self, input_): output = tensor_model_parallel_all_reduce(output_parallel) return output + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + class ParallelLMHead(VocabParallelEmbedding): """Parallelized LM head. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 70e64167f8698..bafa2de62e5df 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -9,9 +9,10 @@ import torch from torch import nn -from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, - LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, - SchedulerConfig, VisionLanguageConfig) +from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig, + VisionLanguageConfig) +from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 2d654b2fefb8d..af433b86e604d 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -11,6 +11,7 @@ from torch import nn from transformers import PretrainedConfig +import vllm.envs as envs from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( @@ -44,7 +45,7 @@ class TensorizerConfig: str, bytes, os.PathLike, int] vllm_tensorized: bool verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 + num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None @@ -104,7 +105,7 @@ class TensorizerArgs: str, bytes, os.PathLike, int] vllm_tensorized: bool verify_hash: Optional[bool] = False - num_readers: Optional[int] = 1 + num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None @@ -125,8 +126,9 @@ class TensorizerArgs: the hashes stored in the metadata. A `HashMismatchError` will be raised if any of the hashes do not match. num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is 1. This greatly increases - performance. + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. 
encryption_keyfile: File path to a binary file containing a binary key to use for decryption. `None` (the default) means no decryption. See the example script in @@ -141,13 +143,10 @@ class TensorizerArgs: def __post_init__(self): self.file_obj = self.tensorizer_uri - self.s3_access_key_id = (self.s3_access_key_id - or os.environ.get("S3_ACCESS_KEY_ID")) or None - self.s3_secret_access_key = ( - self.s3_secret_access_key - or os.environ.get("S3_SECRET_ACCESS_KEY")) or None - self.s3_endpoint = (self.s3_endpoint - or os.environ.get("S3_ENDPOINT_URL")) or None + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL self.stream_params = { "s3_access_key_id": self.s3_access_key_id, "s3_secret_access_key": self.s3_secret_access_key, @@ -199,10 +198,12 @@ def add_cli_args( "use for decryption. Can be a file path or S3 network URI.") group.add_argument( "--num-readers", - default=1, + default=None, type=int, help="Controls how many threads are allowed to read concurrently " - "from the source file.") + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") group.add_argument( "--s3-access-key-id", default=None, @@ -337,7 +338,7 @@ def deserialize(self): per_second = convert_bytes(deserializer.total_tensor_bytes / duration) after_mem = get_mem_usage() deserializer.close() - logger.info("Deserialized %s in %0.2fs, %f/s", total_bytes_str, + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second) logger.info("Memory usage before: %s", before_mem) logger.info("Memory usage after: %s", after_mem) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index c5dd1a63e2f7a..efa4de7516212 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -78,6 +78,8 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size // self.tp_size + self.quant_config = quant_config + # FIXME(pcmoritz): Make this more general to support different # quantization schemes self.use_fp8 = isinstance(quant_config, Fp8Config) @@ -86,59 +88,79 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + # Gate always runs at half / full precision for now. 
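The parameters allocated just below pack every expert's gate and up projections (w1, w3) into a single w13_weight tensor and the down projection (w2) into w2_weight; the weight_loader further down writes each checkpoint shard into its slice, and for fp8 checkpoints a per-expert entry of w13_scale/w2_scale (and, for static activation scaling, a13_scale/a2_scale) is loaded alongside. A toy-shape sketch of that layout, with illustrative sizes and no tensor parallelism:

import torch

num_experts, hidden, intermediate = 2, 8, 4

# Fused storage, matching the shapes allocated in __init__ below:
w13_weight = torch.empty(num_experts, 2 * intermediate, hidden)
w2_weight = torch.empty(num_experts, hidden, intermediate)

# What the weight_loader does with one expert's checkpoint tensors:
w1 = torch.randn(intermediate, hidden)   # gate projection
w3 = torch.randn(intermediate, hidden)   # up projection
w2 = torch.randn(hidden, intermediate)   # down projection

expert_id = 0
w13_weight[expert_id, :intermediate, :] = w1
w13_weight[expert_id, intermediate:2 * intermediate, :] = w3
w2_weight[expert_id, :, :] = w2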
self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, params_dtype=self.params_dtype, quant_config=None) - self.ws = nn.Parameter( + if self.use_fp8: + params_dtype = torch.float8_e4m3fn + + self.w13_weight = nn.Parameter( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( + dtype=params_dtype)) + self.w2_weight = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) + dtype=params_dtype)) - set_weight_attrs(self.ws, { + set_weight_attrs(self.w13_weight, { "weight_loader": self.weight_loader, }) - set_weight_attrs(self.w2s, { + set_weight_attrs(self.w2_weight, { "weight_loader": self.weight_loader, }) - # Scaling factors for FP8 weights - self.ws_scale = nn.Parameter( - torch.ones( - self.num_total_experts, device="cuda", dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None - self.w2s_scale = nn.Parameter( - torch.ones( - self.num_total_experts, device="cuda", dtype=torch.float32), - requires_grad=False) if self.use_fp8 else None - - # Scaling factors for FP8 activations - need_act_scales = (self.use_fp8 - and quant_config.activation_scheme == "static") - self.as_scale = nn.Parameter( - torch.zeros(1, device="cuda", dtype=torch.float32), - requires_grad=False) if need_act_scales else None - self.a2s_scale = nn.Parameter( - torch.zeros(1, device="cuda", dtype=torch.float32), - requires_grad=False) if need_act_scales else None - - if need_act_scales: - set_weight_attrs(self.as_scale, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.a2s_scale, { - "weight_loader": self.weight_loader, - }) + # Used for fp8. + self.w13_scale = None + self.w2_scale = None + self.a13_scale = None + self.a2_scale = None + + if self.use_fp8: + # WEIGHT_SCALE (for fp8) + self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts, + dtype=torch.float32), + requires_grad=False) + + # If loading fp8 checkpoint, pass the weight loaders. 
+ # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(self.w13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2_scale, { + "weight_loader": self.weight_loader, + }) + + # ACT_SCALE (for fp8) + if quant_config.activation_scheme == "static": + if not quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + self.a13_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + self.a2_scale = nn.Parameter(torch.zeros( + self.num_total_experts, dtype=torch.float32), + requires_grad=False) + + set_weight_attrs(self.a13_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2_scale, { + "weight_loader": self.weight_loader, + }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): @@ -153,20 +175,49 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name: - param_data[:] = param_data[:].max(loaded_weight) + if "act_scale" in weight_name or "weight_scale" in weight_name: + param_data[expert_id] = loaded_weight def process_weights_after_loading(self): - if self.use_fp8: - ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) - w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) + # Fp8 is the only case where we need to process after loading. + if not self.use_fp8: + return + + # If checkpoint is fp16, quantize here. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like(self.w13_weight.data, + dtype=torch.float8_e4m3fn) + w2_weight = torch.empty_like(self.w2_weight.data, + dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant( - self.ws.data[expert, :, :]) - w2s[expert, :, :], self.w2s_scale[ - expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :]) - self.ws = nn.Parameter(ws, requires_grad=False) - self.w2s = nn.Parameter(w2s, requires_grad=False) + w13_weight[expert, :, :], self.w13_scale[ + expert] = ops.scaled_fp8_quant( + self.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], self.w2_scale[ + expert] = ops.scaled_fp8_quant( + self.w2_weight.data[expert, :, :]) + self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) + + # If checkpoint is fp8 + static, cleanup act_scales. + # Since state_dict has an act_scale per expert but our kernels + # are passed one act_scale shared across all experts. + elif self.quant_config.activation_scheme == "static": + if self.a13_scale is None or self.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + + if (not all_close_1d(self.a13_scale) + or not all_close_1d(self.a2_scale)): + print_warning_once( + "Found act_scales that are not equal for fp8 MoE layer. " + "Using the maximum across experts for each layer. 
") + + self.a13_scale = nn.Parameter(self.a13_scale.max(), + requires_grad=False) + self.a2_scale = nn.Parameter(self.a2_scale.max(), + requires_grad=False) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_size = hidden_states.shape @@ -174,17 +225,17 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) final_hidden_states = fused_moe(hidden_states, - self.ws, - self.w2s, + self.w13_weight, + self.w2_weight, router_logits, self.top_k, renormalize=True, inplace=True, use_fp8=self.use_fp8, - w1_scale=self.ws_scale, - w2_scale=self.w2s_scale, - a1_scale=self.as_scale, - a2_scale=self.a2s_scale) + w1_scale=self.w13_scale, + w2_scale=self.w2_scale, + a1_scale=self.a13_scale, + a2_scale=self.a2_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -226,7 +277,9 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - if isinstance(quant_config, Fp8Config): + if isinstance( + quant_config, + Fp8Config) and not quant_config.is_checkpoint_fp8_serialized: print_warning_once( "For Mixtral FP8 quantization, we currently do not quantize " "the attention layers until their FP8 performance is improved." @@ -465,16 +518,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id) + ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", + f"experts.{expert_id}.{weight_name}.weight_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + [ # These are the weights for the experts # (param_name, weight_name, expert_id) - ("ws" if weight_name in ["w1", "w3"] else "w2s", + ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] ] + [ # These are the activation scales for the experts # (param_name, weight_name, expert_id) - ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale", + ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", f"experts.{expert_id}.{weight_name}.act_scale", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] @@ -516,3 +576,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 12156b2ba1aa2..9969c45963e9a 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -16,17 +16,26 @@ @dataclass class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + # Sequence ids for the sequence group in a previous step. seq_ids: List[int] sampling_params: SamplingParams # seq_id -> sequence data. 
seq_data: Dict[int, SequenceData] - # The length of the prompt of the sequence group. None if it is in a decode + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode # stage. - prompt_len: Optional[int] - # The length of the query tokens to compute in the current step. None if it - # is in a decode stage. The length of subquery_len <= prompt_len. - subquery_len: Optional[int] + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] # A random number generator for sampling. generator: Optional[torch.Generator] # True if the sequence group is in prefill stage. False if it is in a @@ -46,8 +55,8 @@ def __post_init__(self): if len(self.prompt_logprob_indices) > 0: assert self.sampling_params.prompt_logprobs is not None if self.is_prompt: - assert self.prompt_len is not None - assert self.subquery_len is not None + assert self.seq_len is not None + assert self.query_len is not None class SamplingMetadata: @@ -94,8 +103,8 @@ def __init__( @staticmethod def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, pin_memory: bool, ) -> "SamplingMetadata": @@ -104,8 +113,8 @@ def prepare( selected_token_indices, categorized_sample_indices, num_prompts, - ) = _prepare_seq_groups(seq_group_metadata_list, prompt_lens, - subquery_lens, device) + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device) selected_token_indices = async_tensor_h2d(selected_token_indices, dtype=torch.long, target_device=device, @@ -137,8 +146,8 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - subquery_lens: Optional[List[int]], + seq_lens: List[int], + query_lens: Optional[List[int]], device: str, ) -> Tuple[List[SequenceGroupToSample], List[int], Dict[ SamplingType, List[Tuple[int, int]]], int]: @@ -146,9 +155,9 @@ def _prepare_seq_groups( Args: seq_group_metadata_list: A list of sequence group to batch. - prompt_lens: A list of prompt lens per sequence group. + seq_lens: A list of sequence lens per sequence group. Index of prompt len should match with seq_group_metadata_list. - subquery_lens: A list of query lengths. Prompt lens include the length + query_lens: A list of query lengths. Prompt lens include the length of entire prompt tokens, and it could be shorter. device: A device to use for random number generator, `SequenceGroupToSample.generator`. @@ -189,8 +198,8 @@ def _prepare_seq_groups( is_prompt = seq_group_metadata.is_prompt generator: Optional[torch.Generator] = None # If the current seq group is in decode stage, it is None. 
- prompt_len: Optional[int] = None - subquery_len: Optional[int] = None + seq_len: Optional[int] = None + query_len: Optional[int] = None prompt_logprob_indices: List[int] = [] sample_indices: List[int] = [] do_sample = seq_group_metadata.do_sample @@ -203,12 +212,12 @@ def _prepare_seq_groups( num_prompts += 1 num_prefill_sample = len(seq_ids) assert num_prefill_sample == 1 - assert subquery_lens is not None and prompt_lens is not None - subquery_len, prompt_len = subquery_lens[i], prompt_lens[i] + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] # If we need sampling, exclude num_prefill_sample tokens from # prompt logprob. - prompt_logprob_len = (subquery_len - num_prefill_sample - if do_sample else subquery_len) + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) sample_len = num_prefill_sample if do_sample else 0 else: # Decode @@ -267,8 +276,8 @@ def sample(logits): seq_ids=seq_ids, sampling_params=sampling_params, seq_data=seq_group_metadata.seq_data, - prompt_len=prompt_len, - subquery_len=subquery_len, + seq_len=seq_len, + query_len=query_len, generator=generator, is_prompt=is_prompt, prompt_logprob_indices=list(prompt_logprob_indices), @@ -367,8 +376,8 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None prefill_len = len(seq_group.prompt_logprob_indices) temperatures += [temperature] * prefill_len top_ps += [top_p] * prefill_len @@ -397,8 +406,8 @@ def from_sampling_metadata( if is_prompt: prompt_best_of.append(sampling_params.best_of) - subquery_len = seq_group.subquery_len - assert subquery_len is not None + query_len = seq_group.query_len + assert query_len is not None for seq_id in seq_ids: seq_data = seq_group.seq_data[seq_id] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 0ed6a01a62212..5fa94eb149ffb 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -139,7 +139,10 @@ def __init__( self.top_p = top_p self.top_k = top_k self.min_p = min_p - self.seed = seed + if seed == -1: + self.seed = None + else: + self.seed = seed self.use_beam_search = use_beam_search self.length_penalty = length_penalty self.early_stopping = early_stopping @@ -275,7 +278,8 @@ def update_from_generation_config( self, generation_config: Dict[str, Any]) -> None: """Update if there are non-default values from generation_config""" # Update eos_token_id for generation - if eos_ids := generation_config.get("eos_token_id"): + if (not self.ignore_eos) and (eos_ids := + generation_config.get("eos_token_id")): # it can be either int or list of int if isinstance(eos_ids, int): eos_ids = [eos_ids] diff --git a/vllm/sequence.py b/vllm/sequence.py index 0e931ebbb6571..b486d1fedebd3 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1,8 +1,8 @@ """Sequence and its related classes.""" import copy import enum -from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from vllm.block import LogicalTokenBlock from vllm.lora.request import LoRARequest @@ -579,8 +579,10 @@ class SequenceGroupMetadata: query tokens for prefill, we don't need sampling. 
token_chunk_size: The number of tokens to be processed (per sequence). None if chunking is not required. - state: Internal state tied to this sequence group. lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. """ @@ -698,6 +700,9 @@ class SamplerOutput: # On-device tensor containing probabilities of each token. sampled_token_probs: Optional["torch.Tensor"] = None + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + # On-device tensor containing the sampled token ids. sampled_token_ids: Optional["torch.Tensor"] = None @@ -729,3 +734,33 @@ def __repr__(self) -> str: f"sampled_token_probs={sampled_token_probs_repr}, " f"sampled_token_ids={sampled_token_ids_repr}, " f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + + +@dataclass +class ExecuteModelRequest: + """The model execution request.""" + # The sequence group metadata list. + seq_group_metadata_list: List[SequenceGroupMetadata] + # Blocks to swap in. Dict of CPU -> GPU block number. + blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) + # Blocks to swap out. Dict of GPU -> CPU block number. + blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) + # The number of slots for lookahead decoding. + num_lookahead_slots: int = 0 + # The number of requests in the running queue. + running_queue_size: int = 0 + + def clone( + self, seq_group_metadata_list: List[SequenceGroupMetadata] + ) -> "ExecuteModelRequest": + """Clone the request with a new sequence group metadata list.""" + return ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=self.blocks_to_swap_in.copy(), + blocks_to_swap_out=self.blocks_to_swap_out.copy(), + blocks_to_copy=self.blocks_to_copy.copy(), + num_lookahead_slots=self.num_lookahead_slots, + running_queue_size=self.running_queue_size, + ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index c29b838f854c0..d5fd96907ddd7 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -1,9 +1,10 @@ from itertools import chain, count -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Iterator, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, @@ -40,11 +41,7 @@ def __init__(self, scorer_worker: WorkerBase, device: str, @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, ) -> SpeculativeScores: """Score the proposed tokens via the scorer model. @@ -57,11 +54,7 @@ def score_proposals( no speculation is produced for that sequence. Args: - seq_group_metadata_list: The input sequence group metadata. 
- blocks_to_swap_in: This is passed to the worker during scoring. - blocks_to_swap_out: This is passed to the worker during scoring. - blocks_to_copy: This is passed to the worker during scoring. - k: The fixed proposal length. + execute_model_req: The execution request. proposals: The speculative proposals to score. Returns: SpeculativeScores: The scores of each speculative token, along with @@ -80,33 +73,31 @@ def score_proposals( (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) = self._expand_batch( - seq_group_metadata_list=seq_group_metadata_list, + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, proposal_token_ids_list=proposal_token_ids_list_without_skips, proposal_lens_list=proposal_lens_list, ) target_sampler_output = self._scorer_worker.execute_model( - seq_group_metadata_list=target_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list, )) assert len(target_sampler_output) == 1, "expected single-step output" target_sampler_output = target_sampler_output[0] - all_tokens, all_probs = self._contract_batch( - contracted_bs=len(seq_group_metadata_list), + all_tokens, all_probs, spec_logprobs = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), target_sampler_output=target_sampler_output, proposals=proposals, num_scoring_tokens=num_scoring_tokens, non_spec_indices=non_spec_indices, spec_indices=spec_indices, - k=k, + k=execute_model_req.num_lookahead_slots, ) return SpeculativeScores( probs=all_probs, token_ids=all_tokens, + logprobs=spec_logprobs, ) def _expand_batch( @@ -148,12 +139,12 @@ def _expand_batch( return (spec_indices, non_spec_indices, target_seq_group_metadata_list, num_scoring_tokens) - def _contract_batch(self, contracted_bs: int, - target_sampler_output: List[SamplerOutput], - proposals: SpeculativeProposals, - num_scoring_tokens: int, non_spec_indices: List[int], - spec_indices: List[int], - k: int) -> Tuple[torch.Tensor, torch.Tensor]: + def _contract_batch( + self, contracted_bs: int, + target_sampler_output: List[SamplerOutput], + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], + k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Contract the expanded batch back into its original size. This maps the scores of speculative tokens back to their original sequences. @@ -161,8 +152,9 @@ def _contract_batch(self, contracted_bs: int, contracted_bs is the original batch size, and the batch size that the target_sampler_output will be contracted to. 
""" - (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) = self._split_scoring_output( + (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = self._split_scoring_output( target_sampler_output, num_scoring_tokens) # Map distinct sequences used to score each token @@ -179,6 +171,8 @@ def _contract_batch(self, contracted_bs: int, spec_expanded_bs, k + 1) target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1, self._vocab_size) + target_logprobs = target_logprobs.squeeze().reshape( + spec_expanded_bs, k + 1, self._vocab_size) all_tokens = torch.full(size=(contracted_bs, k + 1), fill_value=-1, @@ -189,16 +183,26 @@ def _contract_batch(self, contracted_bs: int, self._vocab_size, device=self._device, dtype=torch.float32) + all_logprobs = torch.full(size=( + contracted_bs, + k + 1, + self._vocab_size, + ), + fill_value=-float("inf"), + device=self._device, + dtype=torch.float32) if non_spec_indices: all_tokens[non_spec_indices, :1] = non_spec_target_token_ids all_probs[non_spec_indices, :1, :] = non_spec_target_probs + all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs if spec_indices: all_tokens[spec_indices] = target_token_ids all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs - return all_tokens, all_probs + return all_tokens, all_probs, all_logprobs def _create_scoring_model_input( self, @@ -308,7 +312,8 @@ def _create_single_target_seq_group_metadata( def _split_scoring_output( self, sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor]: """Split the target model output into speculative and non-speculative output. """ @@ -328,21 +333,29 @@ def _split_scoring_output( ) = sampler_output.sampled_token_probs.split(split_sizes) (spec_sampled_tokens, non_spec_sampled_tokens ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + ( + spec_logprobs, + non_spec_logprobs, + ) = sampler_output.logprobs.split(split_sizes) # Convert scores to tensors. sampler_output.sampled_token_probs = spec_probs sampler_output.sampled_token_ids = spec_sampled_tokens - target_token_ids, target_probs = sampler_output_to_torch( - [sampler_output]) + sampler_output.logprobs = spec_logprobs + (target_token_ids, target_probs, + target_logprobs) = sampler_output_to_torch([sampler_output], True) # Convert non-speculative output tokens to tensors. 
sampler_output.sampled_token_probs = non_spec_probs sampler_output.sampled_token_ids = non_spec_sampled_tokens - non_spec_target_token_ids, non_spec_target_probs = ( - sampler_output_to_torch([sampler_output])) - - return (target_token_ids, target_probs, non_spec_target_token_ids, - non_spec_target_probs) + sampler_output.logprobs = non_spec_logprobs + (non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) = sampler_output_to_torch([sampler_output], + True) + + return (target_token_ids, target_probs, target_logprobs, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs) def _create_target_seq_id_iterator( self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index dd040779922e9..d311bfe984cbc 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional import torch -from vllm.sequence import SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest @dataclass @@ -38,6 +37,11 @@ class SpeculativeScores: # Probabilities of the speculative tokens according to the scoring model. probs: torch.Tensor + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + # Token ids sampled from the scoring model. Used for speculative bonus # tokens and also non-speculative normal decoding. token_ids: torch.Tensor @@ -53,11 +57,7 @@ class SpeculativeProposer(ABC): @abstractmethod def get_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: raise NotImplementedError @@ -67,11 +67,7 @@ class SpeculativeScorer(ABC): @abstractmethod def score_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, + execute_model_req: ExecuteModelRequest, proposals: SpeculativeProposals, ) -> SpeculativeScores: raise NotImplementedError diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 7cf338bbae5f0..5044cc1ef85fd 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,12 +1,12 @@ import copy -from typing import Dict, List, Optional, Tuple +from typing import List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.util import sampler_output_to_torch +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.worker.worker import Worker @@ -26,50 +26,53 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Lazy initialization list. 
- self._proposer: DraftModelTop1Proposer + self._proposer: Top1Proposer def init_device(self): super().init_device() - self._proposer = DraftModelTop1Proposer( + self._proposer = Top1Proposer( self, self.device, - self.max_model_len, self.vocab_size, + max_proposal_len=self.max_model_len, ) + def set_include_gpu_probs_tensor(self): + # Need include_gpu_probs_tensor for multi_step_worker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + @torch.inference_mode() - def execute_model_multi_step( + def sampler_output( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - num_steps: int, - ) -> List[SamplerOutput]: - """Run the model forward pass num_steps times. Returns the list of - sampler output, one per model forward pass. + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. """ - self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, - blocks_to_swap_out, blocks_to_copy) + self._raise_if_unsupported(execute_model_req) # Shallow copy input data so modifications (such as appending tokens) # do not cause side-effects. copied_seq_group_metadata_list = self._shallow_copy_inputs( - seq_group_metadata_list) + execute_model_req.seq_group_metadata_list) + copied_execute_model_req = execute_model_req.clone( + copied_seq_group_metadata_list) - # Assert enough KV space for num_steps tokens per sequence. - self._assert_enough_kv_space(seq_group_metadata_list, num_steps) + # Assert enough KV space for sample_len tokens per sequence. + self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list, + sample_len) - # Run model num_steps times. + # Run model sample_len times. model_outputs = [] - for _ in range(num_steps): + for _ in range(sample_len): model_output = super().execute_model( - seq_group_metadata_list=copied_seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + execute_model_req=copied_execute_model_req) assert (len(model_output) == 1 ), "composing multistep workers not supported" model_output = model_output[0] @@ -78,27 +81,17 @@ def execute_model_multi_step( copied_seq_group_metadata_list) model_outputs.append(model_output) - return model_outputs + return model_outputs, True def get_spec_proposals( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, + execute_model_req: ExecuteModelRequest, ) -> SpeculativeProposals: """Produce speculations given an input batch of sequences. The number of speculative tokens per sequence is determined by max_proposal_len. 
""" - return self._proposer.get_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - max_proposal_len, - ) + return self._proposer.get_proposals(execute_model_req) def _append_new_tokens( self, model_output: SamplerOutput, @@ -189,188 +182,22 @@ def _assert_enough_kv_space( def _raise_if_unsupported( self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + execute_model_req: ExecuteModelRequest, ) -> None: """MultiStepWorker does not yet implement support for cache swap operations or beam search. """ - if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): raise NotImplementedError( "MultiStepWorker does not support cache operations") if any( len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in seq_group_metadata_list): + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): raise NotImplementedError( "MultiStepWorker does not support beam search.") - - -class DraftModelTop1Proposer(SpeculativeProposer): - """Helper class which separates out sequences which would exceed the max - model length when speculated upon. - - This allows combinations of models such as JackFram/llama-68m draft with - meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of - 2048 while Llama2-13b has max_position_embeddings of 4096. - - We treat the sequences which exceed the proposal draft model length as - "non-spec sequences". Essentially they skip the draft model and go through - normal decoding in the target model. - - Currently, only proposal_lens of 0 and k are supported, where k is a global - batch proposal length. In the future vLLM should support per-sequence - proposal lengths. - """ - - def __init__( - self, - draft_worker: MultiStepWorker, - device: str, - max_model_len: int, - vocab_size: int, - ): - self._draft_worker = draft_worker - self._device = device - self._max_model_len = max_model_len - self._vocab_size = vocab_size - - def get_proposals( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - max_proposal_len: int, - ) -> SpeculativeProposals: - """Get speculative proposals given the input batch. - - Sequences which would exceed the max model length are skipped during - speculation. - """ - - # Split speculative- and non-speculative- sequences. - (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) = self._split_by_max_model_len( - seq_group_metadata_list, max_proposal_len) - - if nonzero_proposal_len_seqs: - # Speculate tokens using the draft worker for the speculative - # sequences. - maybe_sampler_output = self._draft_worker.execute_model_multi_step( - seq_group_metadata_list=nonzero_proposal_len_seqs, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - num_steps=max_proposal_len, - ) - else: - # If no sequences can be speculated, set sampler output to None. - maybe_sampler_output = None - - # Combine speculative- and non-speculative sequences into the same - # representation. 
- proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), - max_proposal_len=max_proposal_len, - maybe_sampler_output=maybe_sampler_output, - proposal_lens=proposal_lens, - nonzero_proposal_len_indices=nonzero_proposal_len_indices, - ) - - proposals = SpeculativeProposals( - proposal_token_ids=proposal_tokens, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens, - ) - - return proposals - - def _split_by_max_model_len( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - max_proposal_len: int, - ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Determine which sequences would exceed the max model length. - """ - - proposal_lens: List[int] = [] - nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] - nonzero_proposal_len_indices: List[int] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_data = next(iter(seq_group_metadata.seq_data.values())) - seq_len = seq_data.get_len() - - # Currently only proposal lens of 0 or the global batch proposal len - # are supported. - if seq_len + max_proposal_len < self._max_model_len: - proposal_lens.append(max_proposal_len) - nonzero_proposal_len_seqs.append(seq_group_metadata) - nonzero_proposal_len_indices.append(i) - else: - proposal_lens.append(0) - - return (proposal_lens, nonzero_proposal_len_seqs, - nonzero_proposal_len_indices) - - def _merge_outputs( - self, - batch_size: int, - max_proposal_len: int, - maybe_sampler_output: Optional[SamplerOutput], - proposal_lens: List[int], - nonzero_proposal_len_indices: List[int], - ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: - """After speculations are produced, merge the speculation results with - the skipped sequences. - """ - if maybe_sampler_output is None: - # If no speculative tokens, the sampler output will be None. - # In this case we return empty proposals. - proposal_tokens = torch.full(size=( - batch_size, - max_proposal_len, - ), - fill_value=-1, - dtype=torch.long, - device=self._device) - proposal_probs = torch.zeros(batch_size, - max_proposal_len, - self._vocab_size, - dtype=torch.float32, - device=self._device) - proposal_lens_tensor = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens_tensor - - sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs = sampler_output_to_torch( - sampler_output) - - # Now, reformat the output GPU tensors such that each sequence has - # a proposal. the proposal can be empty, e.g. 
[-1, -1, -1] - - entire_proposal_tokens = torch.full(size=(batch_size, - *proposal_tokens.shape[1:]), - fill_value=-1, - dtype=torch.long, - device=self._device) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = torch.zeros(batch_size, - *proposal_probs.shape[1:], - dtype=torch.float32, - device=self._device) - entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs - - proposal_tokens, proposal_probs = (entire_proposal_tokens, - entire_proposal_probs) - - proposal_lens_tensor = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len - - return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000000000..fed8be42054a5 --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,176 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class NGramWorker(LoraNotSupportedWorkerBase): + """NGramWorker provides a light drafter without need for model. + + Current NGramWorker only implement prompt lookup decoding, + and in future we may also do RAG type drafter and other scenerios + which don't rely on LLM model to give proposals. + """ + + def __init__(self, *args, **kwargs): + # Get local_rank/vocab_size from kwargs attribute + self.local_rank = kwargs["local_rank"] + self.vocab_size = kwargs["model_config"].get_vocab_size() + + # Lazy initialization list. + self._proposer: Top1Proposer + + def set_ngram_window_size(self, ngram_prompt_lookup_min: int, + ngram_prompt_lookup_max: int): + # Search valid candidate window between + # ngram_prompt_lookup_min/ngram_prompt_lookup_max + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min + + def init_device(self): + self.device = torch.device(f"cuda:{self.local_rank}") + self.load_model = lambda *args, **kwargs: None + + # Current only support Top1Proposer + self._proposer = Top1Proposer( + self, + device=self.device, + vocab_size=self.vocab_size, + ) + + def set_include_gpu_probs_tensor(self): + # NGram don't need gpu sampler + pass + + def execute_model(self, execute_model_req: ExecuteModelRequest) -> None: + """NGram doesn't depend on model execution, just pass this function""" + pass + + def determine_num_available_blocks(self) -> None: + """NGram doesn't depend on model execution, no need to check blocks""" + pass + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """As there is no cache need to handle, just pass this function""" + pass + + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes.""" + return 0 + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + """NGram match algo to pick proposal candidate. Returns the list of + sampler output, one per SequenceGroupMetadata. + + For ngram worker, we already done needed transposed internal, so the + indicator pass to sampler_output_to_torch shall be False. 
+ """ + self._raise_if_unsupported(execute_model_req) + + arr = [] + has_spec_out = False + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + seq_data = next(iter(seq_group_metadata.seq_data.values())) + + input_ids = torch.as_tensor(seq_data.get_token_ids(), + dtype=torch.long, + device=self.device) + input_length = seq_data.get_len() + + for ngram_size in range( + min(self.ngram_prompt_lookup_max, input_length - 1), + self.ngram_prompt_lookup_min, + -1, + ): + ngram_tensor = input_ids[-1 * ngram_size:] + windows = input_ids.unfold(dimension=0, + size=ngram_size, + step=1) + matches = (windows == ngram_tensor).all(dim=1) + match_indices = matches.nonzero(as_tuple=True)[0] + if match_indices.size()[0] > 1: + has_spec_out = True + res = seq_data.get_token_ids() + res = res[match_indices[0] + ngram_size:match_indices[0] + + ngram_size + sample_len] + res_len = len(res) + # pad 0 towards output as sample_len tokens required + res += [0] * (sample_len - res_len) + + break + else: + # if no candidate found, fill with 0 + res = [0] * sample_len + + arr.append(res) + + if not has_spec_out: + return None, False + + outputs = [] + token_ids = torch.as_tensor(arr, dtype=torch.long, device=self.device) + indices = token_ids.unsqueeze(2) + + token_probs = torch.zeros( + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + token_probs.scatter_(2, indices, 1) + token_logprobs = torch.zeros( + (len(execute_model_req.seq_group_metadata_list), sample_len, + self.vocab_size), + dtype=torch.float32, + device=self.device, + ) + for i in range(len(execute_model_req.seq_group_metadata_list)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_probs[i], + logprobs=token_logprobs, + sampled_token_ids=token_ids[i], + )) + return outputs, False + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_proposals(execute_model_req) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. 
+ """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 4e70ea9686005..c2b119fbd5036 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,18 +1,21 @@ from functools import cached_property -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceGroupOutput, SequenceOutput) +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range, +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.util import (create_sequence_group_output, + get_all_num_logprobs, get_all_seq_ids, + get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase @@ -48,8 +51,27 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): """ @classmethod - def from_workers(cls, proposer_worker: MultiStepWorker, - scorer_worker: WorkerBase) -> "SpecDecodeWorker": + def create_worker( + cls, + scorer_worker: WorkerBase, + draft_worker_kwargs, + ) -> "SpecDecodeWorker": + + if "ngram_prompt_lookup_max" in draft_worker_kwargs: + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + else: + ngram_prompt_lookup_max = 0 + + if ngram_prompt_lookup_max > 0: + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + return SpecDecodeWorker( proposer_worker, scorer_worker, @@ -59,7 +81,7 @@ def from_workers(cls, proposer_worker: MultiStepWorker, def __init__( self, - proposer_worker: MultiStepWorker, + proposer_worker: WorkerBase, scorer_worker: WorkerBase, rejection_sampler: RejectionSampler, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -134,8 +156,7 @@ def _configure_model_sampler_for_spec_decode(self): """ (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor ) = True - (self.proposer_worker.model_runner.model.sampler. - include_gpu_probs_tensor) = True + self.proposer_worker.set_include_gpu_probs_tensor() def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of cache blocks to use. 
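Taken together, the NGramWorker.sampler_output loop above and the create_worker hunk (which selects NGramWorker whenever ngram_prompt_lookup_max > 0) implement prompt-lookup decoding. The snippet below is a minimal standalone sketch of the matching step; the helper name propose_by_ngram_lookup and the toy token ids are illustrative and not part of this diff.

    import torch

    def propose_by_ngram_lookup(token_ids, lookup_min, lookup_max, sample_len):
        """Return up to sample_len tokens that followed the longest matching
        suffix n-gram earlier in token_ids, padded with 0; None if no match."""
        input_ids = torch.as_tensor(token_ids, dtype=torch.long)
        input_length = input_ids.numel()
        # Try the largest window first, mirroring the loop in sampler_output().
        for ngram_size in range(min(lookup_max, input_length - 1), lookup_min, -1):
            suffix = input_ids[-ngram_size:]
            windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
            matches = (windows == suffix).all(dim=1)
            match_indices = matches.nonzero(as_tuple=True)[0]
            # The final window always matches itself, so >1 means a real hit.
            if match_indices.size(0) > 1:
                start = int(match_indices[0]) + ngram_size
                proposal = token_ids[start:start + sample_len]
                return proposal + [0] * (sample_len - len(proposal))
        return None  # the real worker then falls back to an all-zero proposal

    # "7 8 9" occurred earlier in the sequence, so [4, 5, 6] is proposed.
    print(propose_by_ngram_lookup([7, 8, 9, 4, 5, 6, 1, 2, 7, 8, 9],
                                  lookup_min=1, lookup_max=3, sample_len=3))

The worker then scatters the proposed ids into one-hot sampled_token_probs with scatter_, so the scoring path can consume them like ordinary sampler output.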
@@ -169,69 +190,37 @@ def initialize_cache(self, num_gpu_blocks: int, @torch.inference_mode() def execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - num_lookahead_slots: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Perform speculative decoding on the input batch. """ - assert seq_group_metadata_list is not None, ( + assert execute_model_req.seq_group_metadata_list is not None, ( "speculative decoding " "requires non-None seq_group_metadata_list") - logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d", - num_lookahead_slots) - # If no spec tokens, call the proposer and scorer workers normally. # Used for prefill. - if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: - return self._run_no_spec( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) - - return self._run_speculative_decoding_step( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - k=num_lookahead_slots, - ) + if execute_model_req.num_lookahead_slots == 0 or len( + execute_model_req.seq_group_metadata_list) == 0: + return self._run_no_spec(execute_model_req) + + return self._run_speculative_decoding_step(execute_model_req) @nvtx_range("spec_decode_worker._run_no_spec") def _run_no_spec( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Run a prefill step, without any speculation. The input is sent to the proposer and scorer model so that the KV cache is consistent between the two. """ - logger.info("run proposer worker no spec") + #logger.info("run proposer worker no spec") - self.proposer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + self.proposer_worker.execute_model(execute_model_req) - logger.info("run target worker no spec") - sampler_output = self.scorer_worker.execute_model( - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - ) + #logger.info("run target worker no spec") + sampler_output = self.scorer_worker.execute_model(execute_model_req) assert len(sampler_output) == 1 sampler_output = sampler_output[0] @@ -239,17 +228,13 @@ def _run_no_spec( # overhead when the engine runs in a different process than the workers. 
sampler_output.probs = None sampler_output.sampled_tokens = None + sampler_output.logprobs = None return [sampler_output] @nvtx_range("spec_decode_worker._run_speculative_decoding_step") def _run_speculative_decoding_step( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Optional[Dict[int, int]], - blocks_to_swap_out: Optional[Dict[int, int]], - blocks_to_copy: Optional[Dict[int, List[int]]], - k: int, - ) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Execute a single step of speculative decoding. This invokes the proposer worker to get k speculative tokens for each @@ -259,32 +244,27 @@ def _run_speculative_decoding_step( sequence. """ - logger.info("get spec proposals") + #logger.info("get spec proposals") # Generate proposals using draft worker. - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - proposals = self.proposer_worker.get_spec_proposals( - seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, - blocks_to_copy, k) - - logger.info("score proposals") + proposals = self.proposer_worker.get_spec_proposals(execute_model_req) + + #logger.info("score proposals") proposal_scores = self.scorer.score_proposals( - seq_group_metadata_list, - blocks_to_swap_in, - blocks_to_swap_out, - blocks_to_copy, - k, + execute_model_req, proposals, ) - logger.info("verify proposals") - accepted_token_ids = self._verify_tokens(seq_group_metadata_list, - proposal_scores, proposals, k) + #logger.info("verify proposals") + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) - logger.info("create output list") - return self._create_output_sampler_list(seq_group_metadata_list, - accepted_token_ids, k) + #logger.info("create output list") + return self._create_output_sampler_list( + execute_model_req.seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + k=execute_model_req.num_lookahead_slots) @nvtx_range("spec_decode_worker._verify_tokens") def _verify_tokens( @@ -293,9 +273,12 @@ def _verify_tokens( proposal_scores: SpeculativeScores, proposals: SpeculativeProposals, max_proposal_len: int, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: """Determine which speculative tokens are accepted using the probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. """ proposal_lens_list = proposals.proposal_lens.tolist() @@ -342,17 +325,19 @@ def _verify_tokens( non_spec_token_ids[:, 1:] = -1 accepted_token_ids = torch.cat( [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs # Rearrange so that results are in the order of the original seq group # metadata. accepted_token_ids[original_indices] = accepted_token_ids.clone() - return accepted_token_ids + return accepted_token_ids, logprobs def _create_output_sampler_list( self, seq_group_metadata_list: List[SequenceGroupMetadata], accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] k: int, ) -> List[SamplerOutput]: """Given the accepted token ids, create a list of SamplerOutput. 
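To make the padding and reordering in _verify_tokens above concrete, here is a toy example; the token values, batch layout, and original_indices are invented for illustration.

    import torch

    # Rows are ordered spec-first, as produced by the batch split above; each
    # row has k+1 slots (here k = 3).
    spec_accepted = torch.tensor([[11, 12, -1, -1]])   # one speculative sequence
    non_spec = torch.tensor([[21, 22, 23, 24]])        # scored without speculation
    non_spec[:, 1:] = -1                               # keep only its sampled token
    accepted_token_ids = torch.cat([spec_accepted, non_spec])

    # The spec sequence was originally at batch index 1, the non-spec one at 0.
    original_indices = torch.tensor([1, 0])
    accepted_token_ids[original_indices] = accepted_token_ids.clone()
    print(accepted_token_ids)
    # tensor([[21, -1, -1, -1],
    #         [11, 12, -1, -1]])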
@@ -360,30 +345,68 @@ def _create_output_sampler_list( The output is padded with -1 tokens such that each sequence has the same number of outputs. """ + batch_size, num_steps = accepted_token_ids.shape + + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + + # Get the top-k logprobs (which may or may not include the logprob of + # the accepted token). + (topk_logprobs_by_step, + topk_indices_by_step) = target_logprobs_by_step.topk( + k=self.scorer_worker.model_config.max_logprobs, + dim=-1, + ) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. seq_ids = get_all_seq_ids(seq_group_metadata_list) - - # shape: [k+1, batch_size] - accepted_token_ids_by_step = accepted_token_ids.transpose(0, - 1).tolist() + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize all tensors to CPU Python lists. + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step.tolist() + topk_indices_by_step = topk_indices_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. sampler_output_list = [] - for token_ids_by_step in accepted_token_ids_by_step: - if all(token_id == -1 for token_id in token_ids_by_step): + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): break step_output_token_ids = [] - for token_id, seq_id in zip(token_ids_by_step, seq_ids): + for sequence_index in range(batch_size): + # Each sequence may have a different num_logprobs; retrieve it. + num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( - SequenceGroupOutput( - samples=[ - SequenceOutput( - parent_seq_id=seq_id, - output_token=token_id, - # TODO Add verifier logprobs. 
- logprobs={token_id: Logprob(0.0)}, - ) - ], - prompt_logprobs=None, + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], )) + sampler_output_list.append( SamplerOutput(outputs=step_output_token_ids)) diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000000000..eb622a0e2e7f4 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,200 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.util import sampler_output_to_torch +from vllm.worker.worker_base import WorkerBase + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: WorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_proposals( + self, + execute_model_req: ExecuteModelRequest, + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_max_model_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. + # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + nonzero_execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=nonzero_proposal_len_seqs, + num_lookahead_slots=proposal_len, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, + sample_len=proposal_len, + ) + else: + # If no sequences can be speculated, set sampler output to None. 
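The helpers defined below (_split_by_max_model_len and _merge_outputs) implement the skip-and-merge behaviour described in the class docstring. A toy illustration of the split, with made-up sequence lengths and an assumed draft-model limit:

    proposal_len = 3               # k, the per-step lookahead
    max_proposal_len = 16          # assumed draft-model length limit
    seq_lens = [5, 15, 8]          # made-up current sequence lengths

    proposal_lens = [
        proposal_len if seq_len + proposal_len < max_proposal_len else 0
        for seq_len in seq_lens
    ]
    nonzero_proposal_len_indices = [
        i for i, plen in enumerate(proposal_lens) if plen > 0
    ]
    print(proposal_lens)                 # [3, 0, 3]: sequence 1 skips the draft
    print(nonzero_proposal_len_indices)  # [0, 2]

_merge_outputs then scatters the draft tokens of the non-zero rows into a batch-sized tensor pre-filled with -1, so a skipped sequence simply carries an empty proposal.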
+ maybe_sampler_output = None + transposed = False + + # Combine speculative- and non-speculative sequences into the same + # representation. + proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( + batch_size=len(seq_group_metadata_list), + proposal_len=proposal_len, + maybe_sampler_output=maybe_sampler_output, + proposal_lens=proposal_lens, + nonzero_proposal_len_indices=nonzero_proposal_len_indices, + sampler_transposed=transposed, + ) + + proposals = SpeculativeProposals( + proposal_token_ids=proposal_tokens, + proposal_probs=proposal_probs, + proposal_lens=proposal_lens, + ) + + return proposals + + def _split_by_max_model_len( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_len: int, + ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: + """Determine which sequences would exceed the max model length.""" + + proposal_lens: List[int] = [] + nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] + nonzero_proposal_len_indices: List[int] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_data = next(iter(seq_group_metadata.seq_data.values())) + seq_len = seq_data.get_len() + + # Currently only proposal lens of 0 or the global batch proposal len + # are supported. + # If max_proposal_len is defined, then we shall no exccess this + # quota for nonzero_proposal + if (self.max_proposal_len is None + or seq_len + proposal_len < self.max_proposal_len): + proposal_lens.append(proposal_len) + nonzero_proposal_len_seqs.append(seq_group_metadata) + nonzero_proposal_len_indices.append(i) + else: + proposal_lens.append(0) + + return ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[SamplerOutput], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.full( + size=( + batch_size, + proposal_len, + ), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + proposal_probs = torch.zeros( + batch_size, + proposal_len, + self._vocab_size, + dtype=torch.float32, + device=self._device, + ) + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs, _ = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. 
[-1, -1, -1] + + entire_proposal_tokens = torch.full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + dtype=torch.long, + device=self._device, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = torch.zeros( + batch_size, + *proposal_probs.shape[1:], + dtype=torch.float32, + device=self._device, + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index eb6d4ca1da8e6..d6f80c82b80bf 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -1,10 +1,11 @@ from contextlib import contextmanager from itertools import chain -from typing import List, Tuple +from typing import Dict, List, Tuple import torch -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import (Logprob, SamplerOutput, SequenceGroupMetadata, + SequenceGroupOutput, SequenceOutput) SeqId = int @@ -21,6 +22,89 @@ def get_all_seq_ids( ])) +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if seq_group_metadata.sampling_params.logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor >= + expanded_selected_logprobs).sum(-1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[int], + topk_logprobs: List[float], +) -> SequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[int]): The list of top-k token ids. + topk_logprobs (List[float]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). 
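As a quick sanity check on the rank convention used by get_sampled_token_logprobs above (rank 1 means the accepted token had the highest logprob), here is a minimal example over an assumed four-token vocabulary:

    import torch

    # [num_steps=1, batch_size=1, vocab_size=4] logprobs from assumed scores.
    logprob_tensor = torch.log_softmax(
        torch.tensor([[[2.0, 0.5, 1.0, -1.0]]]), dim=-1)
    sampled_token_ids = torch.tensor([[2]])  # suppose token id 2 was accepted

    num_steps, batch_size, _ = logprob_tensor.shape
    selected = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
                              torch.arange(batch_size),
                              sampled_token_ids]
    ranks = (logprob_tensor >= selected.unsqueeze(-1)).sum(-1)
    print(ranks)  # tensor([[2]]): only token 0 scores above the accepted token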
+ logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_ids[topk_logprob_index]: Logprob( + logprob=topk_logprobs[topk_logprob_index], + rank=topk_logprob_index + 1, + ) + for topk_logprob_index, _ in enumerate(topk_token_ids) + }) + + return SequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + # TODO add prompt logprobs support. + prompt_logprobs=None, + ) + + def split_batch_by_proposal_len( seq_group_metadata_list: List[SequenceGroupMetadata], proposal_lens: List[int], select_proposal_len_zero: bool @@ -49,10 +133,13 @@ def split_batch_by_proposal_len( def sampler_output_to_torch( - sampler_output_list: List[SamplerOutput], -) -> Tuple[torch.Tensor, torch.Tensor]: + sampler_output_list: List[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Utility function which converts a list of SamplerOutput to tensors. + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. + Returns: sampled_token_ids: torch.Tensor shape: [batch_size, len(sampler_output_list)] @@ -68,7 +155,19 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + if sampler_transposed: + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) # shape: [batch_size, num_sampler_output] sampled_token_ids = torch.stack( @@ -77,9 +176,11 @@ def sampler_output_to_torch( for sampler_output in sampler_output_list ], dim=0, - ).transpose(0, 1) + ) + if sampler_transposed: + sampled_token_ids = sampled_token_ids.transpose(0, 1) - return sampled_token_ids, sampled_token_probs + return sampled_token_ids, sampled_token_probs, sampled_token_logprobs def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index fa4693cb7dac1..f5684dbf1271c 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -5,7 +5,7 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) -from vllm.config import VLLM_USE_MODELSCOPE +from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import BaichuanTokenizer @@ -79,7 +79,7 @@ def get_tokenizer( revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. 
- ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) tokenizer_name = tokenizer_path if tokenizer_mode == "slow": diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index b2672f7f1da61..9029a5b16af72 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -15,20 +15,22 @@ import requests import torch -_config_home = os.getenv("XDG_CONFIG_HOME", os.path.expanduser("~/.config")) +import vllm.envs as envs + +_config_home = envs.VLLM_CONFIG_ROOT _USAGE_STATS_JSON_PATH = os.path.join(_config_home, "vllm/usage_stats.json") _USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "vllm/do_not_track") _USAGE_STATS_ENABLED = None -_USAGE_STATS_SERVER = os.environ.get("VLLM_USAGE_STATS_SERVER", - "https://stats.vllm.ai") +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER def is_usage_stats_enabled(): """Determine whether or not we can send usage stats to the server. The logic is as follows: - By default, it should be enabled. - - Two environment variables can disable it: + - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 - DO_NOT_TRACK=1 - VLLM_NO_USAGE_STATS=1 - A file in the home directory can disable it if it exists: @@ -36,8 +38,8 @@ def is_usage_stats_enabled(): """ global _USAGE_STATS_ENABLED if _USAGE_STATS_ENABLED is None: - do_not_track = os.environ.get("DO_NOT_TRACK", "0") == "1" - no_usage_stats = os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1" + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats @@ -167,7 +169,7 @@ def _report_usage_once(self, model_architecture: str, # Metadata self.log_time = _get_current_timestamp_ns() - self.source = os.environ.get("VLLM_USAGE_SOURCE", "production") + self.source = envs.VLLM_USAGE_SOURCE data = vars(self) if extra_kvs: diff --git a/vllm/utils.py b/vllm/utils.py index 88447878f1706..6479a8dab320a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -19,8 +19,8 @@ import psutil import torch -from packaging.version import Version, parse +import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") @@ -174,7 +174,7 @@ def get_vllm_instance_id(): Instance id represents an instance of the VLLM. All processes in the same instance should have the same instance id. """ - return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}") + return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" @lru_cache(maxsize=None) @@ -243,7 +243,7 @@ async def consumer(): def get_ip() -> str: - host_ip = os.environ.get("HOST_IP") + host_ip = envs.VLLM_HOST_IP if host_ip: return host_ip @@ -269,7 +269,8 @@ def get_ip() -> str: warnings.warn( "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable HOST_IP.", + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", stacklevel=2) return "0.0.0.0" @@ -312,27 +313,6 @@ def cdiv(a: int, b: int) -> int: return -(a // -b) -@lru_cache(maxsize=None) -def get_nvcc_cuda_version() -> Optional[Version]: - cuda_home = os.environ.get('CUDA_HOME') - if not cuda_home: - cuda_home = '/usr/local/cuda' - if os.path.isfile(cuda_home + '/bin/nvcc'): - logger.info( - 'CUDA_HOME is not found in the environment. ' - 'Using %s as CUDA_HOME.', cuda_home) - else: - logger.warning('Not found nvcc in %s. 
Skip cuda version check!', - cuda_home) - return None - nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], - universal_newlines=True) - output = nvcc_output.split() - release_idx = output.index("release") + 1 - nvcc_cuda_version = parse(output[release_idx].split(",")[0]) - return nvcc_cuda_version - - def _generate_random_fp8( tensor: torch.tensor, low: float, @@ -353,21 +333,9 @@ def _generate_random_fp8( del tensor_tmp -def create_kv_caches_with_random( - num_blocks: int, - block_size: int, - num_layers: int, - num_heads: int, - head_size: int, - cache_dtype: Optional[Union[str, torch.dtype]], - model_dtype: Optional[Union[str, torch.dtype]] = None, - seed: int = 0, - device: Optional[str] = "cuda", -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": if isinstance(model_dtype, str): @@ -386,6 +354,55 @@ def create_kv_caches_with_random( torch_dtype = cache_dtype else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + assert cache_dtype != "fp8" + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + key_caches, value_caches = [], [] + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + key_value_cache.uniform_(-scale, scale) + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) scale = head_size**-0.5 x = 16 // torch.tensor([], dtype=torch_dtype).element_size() @@ -521,7 +538,7 @@ def maybe_expand_dim(tensor: torch.Tensor, def merge_dicts(dict1: Dict[Any, List[Any]], dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]: """Merge 2 dicts that have key -> List of items. - + When a key conflicts, the values in dict1 is prioritized. 
""" merged_dict = defaultdict(list) @@ -581,7 +598,7 @@ def find_library(lib_name: str) -> str: # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] # `LD_LIBRARY_PATH` searches the library in the user-defined paths - env_ld_library_path = os.getenv("LD_LIBRARY_PATH") + env_ld_library_path = envs.LD_LIBRARY_PATH if not locs and env_ld_library_path: locs = [ os.path.join(dir, lib_name) @@ -594,14 +611,15 @@ def find_library(lib_name: str) -> str: def find_nccl_library(): - so_file = os.environ.get("VLLM_NCCL_SO_PATH", "") + so_file = envs.VLLM_NCCL_SO_PATH + VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT # check if we have vllm-managed nccl vllm_nccl_path = None if torch.version.cuda is not None: cuda_major = torch.version.cuda.split(".")[0] path = os.path.expanduser( - f"~/.config/vllm/nccl/cu{cuda_major}/libnccl.so.*") + f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*") files = glob.glob(path) vllm_nccl_path = files[0] if files else None @@ -626,7 +644,7 @@ def enable_trace_function_call_for_thread() -> None: if enabled via the VLLM_TRACE_FUNCTION environment variable """ - if int(os.getenv("VLLM_TRACE_FUNCTION", "0")): + if envs.VLLM_TRACE_FUNCTION: tmp_dir = tempfile.gettempdir() filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" f"_thread_{threading.get_ident()}_" diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index c34ee0648626b..26a60c652b6f4 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -77,7 +77,7 @@ def swap_out(self, src_to_dst: Dict[int, int]) -> None: self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) - def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + def copy(self, src_to_dsts: torch.Tensor) -> None: self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) @staticmethod diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 34d7d3dffea18..193b021b7a11e 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -80,7 +80,7 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] multi_modal_input_list: List[torch.Tensor] = [] for seq_group_metadata in seq_group_metadata_list: @@ -92,15 +92,15 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() computed_len = seq_data.get_num_computed_tokens() - prompt_len = len(prompt_tokens) + seq_len = len(prompt_tokens) - prompt_lens.append(prompt_len) # Prompt token num + seq_lens.append(seq_len) # Prompt token num input_tokens.extend(prompt_tokens) # Token ids # Token position ids # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prompt_len))) + input_positions.extend(list(range(computed_len, seq_len))) if seq_group_metadata.multi_modal_data: multi_modal_input_list.append( @@ -109,15 +109,15 @@ def _prepare_prompt( # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). 
# For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prompt_len): + for i in range(computed_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -151,19 +151,19 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - prompt_lens=prompt_lens, - num_prefills=len(prompt_lens), + seq_lens=seq_lens, + seq_lens_tensor=None, + max_seq_len=None, + num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, prefill_metadata=None, decode_metadata=None, - max_context_len=None, - context_lens=None, block_tables=torch.tensor([]), slot_mapping=slot_mapping, kv_cache_dtype=self.kv_cache_dtype, ) - return (input_tokens, input_positions, attn_metadata, prompt_lens, + return (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input) def _prepare_decode( @@ -174,7 +174,7 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: @@ -192,9 +192,9 @@ def _prepare_decode( position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -208,7 +208,7 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) - max_context_len = max(context_lens) + max_seq_len = max(seq_lens) input_tokens = torch.tensor(input_tokens, dtype=torch.long, @@ -219,9 +219,9 @@ def _prepare_decode( slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=self.device) - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -236,14 +236,14 @@ def _prepare_decode( attn_metadata = self.attn_backend.make_metadata( is_prompt=False, slot_mapping=slot_mapping, - prompt_lens=None, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_seq_len=max_seq_len, num_prefill_tokens=0, num_decode_tokens=len(input_tokens), - max_context_len=max_context_len, num_prefills=0, prefill_metadata=None, decode_metadata=None, - context_lens=context_lens, block_tables=block_tables, kv_cache_dtype=self.kv_cache_dtype, ) @@ -265,20 +265,20 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, attn_metadata, prompt_lens, + (input_tokens, input_positions, attn_metadata, seq_lens, multi_modal_input ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, attn_metadata) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since CPU worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. + seq_lens, self.device, pin_memory=False) # Broadcast the metadata. @@ -300,7 +300,7 @@ def prepare_input_tensors( sampling_metadata = SamplingMetadata( seq_groups=None, seq_data=None, - prompt_lens=None, + seq_lens=None, selected_token_indices=selected_token_indices, categorized_sample_indices=None, generators=None, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 83ededd742533..e1ef500ac07b8 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -13,7 +13,7 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -248,30 +248,35 @@ def _init_cache_engine(self) -> None: def cache_copy( self, - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: + + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + if self.is_driver_worker: assert seq_group_metadata_list is not None num_seq_groups: int = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None - assert len(blocks_to_swap_in) == 0 - assert len(blocks_to_swap_out) == 0 + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, - "blocks_to_copy": blocks_to_copy, + "blocks_to_copy": execute_model_req.blocks_to_copy, } broadcast_tensor_dict(data, src=0) else: @@ -279,7 +284,6 @@ def execute_model( num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] - assert blocks_to_copy is not None self.cache_copy(blocks_to_copy) # If there is no input, we don't need to execute the model. 
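The sliding-window comment in the prompt preparation above is easy to verify by hand. The sketch below recomputes the quoted example; the block table [0, 1, 0] is an assumption chosen so that block 0 is reused once the window slides:

    _PAD_SLOT_ID = -1
    block_size = 4
    sliding_window = 8
    seq_len = 10                 # prompt length from the comment above
    block_table = [0, 1, 0]      # assumed physical blocks for this sequence

    start_idx = max(0, seq_len - sliding_window)
    slot_mapping = []
    for i in range(seq_len):
        if i < start_idx:
            slot_mapping.append(_PAD_SLOT_ID)  # token has slid out of the window
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slot_mapping.append(block_number * block_size + block_offset)

    print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]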
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0704f5fec54d0..ab248596490f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,6 +9,7 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) +from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -23,8 +24,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, SequenceData, SequenceGroupMetadata) -from vllm.utils import (CudaMemoryProfiler, is_hip, is_pin_memory_available, - make_tensor_with_pad) +from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, + is_pin_memory_available, make_tensor_with_pad) logger = init_logger(__name__) @@ -42,8 +43,8 @@ class PreparePromptMetadata(NamedTuple): input_tokens: List[int] input_positions: List[int] attn_metadata: Optional[AttentionMetadataPerStage] - prompt_lens: List[int] - subquery_lens: List[int] + seq_lens: List[int] + query_lens: List[int] lora_index_mapping: List[int] lora_prompt_mapping: List[int] lora_requests: Set[LoRARequest] @@ -56,8 +57,8 @@ def empty(cls): input_tokens=[], input_positions=[], attn_metadata=None, - prompt_lens=[], - subquery_lens=[], + seq_lens=[], + query_lens=[], lora_index_mapping=[], lora_prompt_mapping=[], lora_requests=set(), @@ -134,9 +135,8 @@ def __init__( self.graph_memory_pool: Optional[Tuple[ int, int]] = None # Set during graph capture. - self.max_context_len_to_capture = ( - self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture + if self.model_config is not None else 0) self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype @@ -149,13 +149,16 @@ def __init__( self.model: torch.nn.Module # Set after load_model self.block_size: int # Set after initial profiling. # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in + # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). self.graph_block_tables: torch.Tensor # Set after initial profiling. + # Set if the backend is flashinfer. 
+ self.flashinfer_workspace_buffer: torch.Tensor + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -218,7 +221,7 @@ def set_block_size(self, block_size: int) -> None: def get_max_block_per_batch(self) -> int: block_size = self.block_size - return (self.max_context_len_to_capture + block_size - 1) // block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size def _prepare_prompt( self, @@ -231,9 +234,9 @@ def _prepare_prompt( lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() - prompt_lens: List[int] = [] + seq_lens: List[int] = [] context_lens: List[int] = [] - subquery_lens: List[int] = [] + query_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] multi_modal_input_list: List[torch.Tensor] = [] @@ -257,21 +260,19 @@ def _prepare_prompt( token_chunk_size = seq_group_metadata.token_chunk_size seq_data = seq_group_metadata.seq_data[seq_id] - computed_len = seq_data.get_num_computed_tokens() + context_len = seq_data.get_num_computed_tokens() # We should use get_len here because in case of preemption # it contains output tokens. - prefill_end = min(seq_data.get_len(), - computed_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[computed_len:prefill_end] - prompt_len = prefill_end - prompt_lens.append(prompt_len) + seq_len = min(seq_data.get_len(), context_len + token_chunk_size) + prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) # NOTE: This only works for oooooooxxx style attention. if computed_block_nums is not None and len( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window - computed_len = len(computed_block_nums) * self.block_size - prompt_tokens = prompt_tokens[computed_len:] + context_len = len(computed_block_nums) * self.block_size + prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: if seq_group_metadata.block_tables is not None: @@ -285,25 +286,25 @@ def _prepare_prompt( prefix_block_tables.append([]) # Right now, prefill start is always 0. However, this # assumption can be changed once chunked prefill is introduced. - assert computed_len == 0 + assert context_len == 0 # actual prompt lens - context_lens.append(computed_len) - subquery_lens.append(prompt_len - computed_len) + context_lens.append(context_len) + query_lens.append(seq_len - context_len) input_tokens.extend(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(list(range(computed_len, prefill_end))) + input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (prompt_len - computed_len) + lora_index_mapping += [lora_id] * (seq_len - context_len) lora_prompt_mapping.extend( [lora_id] * - (prompt_len - computed_len + (seq_len - context_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) if seq_group_metadata.multi_modal_data: @@ -313,24 +314,25 @@ def _prepare_prompt( if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * prompt_len) + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) continue # Compute the slot mapping. 
block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, prompt_len - sliding_window). + # where start_idx is max(0, seq_len - sliding_window). # For example, if the prompt len is 10, sliding window is 8, and # block size is 4, the first two tokens are masked and the slot # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. start_idx = 0 if self.sliding_window is not None: - assert computed_len == 0, ( + assert context_len == 0, ( "Prefix caching is currently not supported with " "sliding window attention") - start_idx = max(0, prompt_len - self.sliding_window) + start_idx = max(0, seq_len - self.sliding_window) - for i in range(computed_len, prefill_end): + for i in range(context_len, seq_len): if i < start_idx: slot_mapping.append(_PAD_SLOT_ID) continue @@ -340,9 +342,9 @@ def _prepare_prompt( slot = block_number * self.block_size + block_offset slot_mapping.append(slot) - max_subquery_len = max(subquery_lens) - max_prompt_len = max(prompt_lens) - assert max_subquery_len > 0 + max_query_len = max(query_lens) + max_seq_len = max(seq_lens) + assert max_query_len > 0 context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, @@ -369,50 +371,57 @@ def _prepare_prompt( # Query length can be shorter than key (i.e., prompt) when prefill # is chunked or prefix cached. - subquery_lens_tensor = torch.tensor(subquery_lens, - dtype=torch.long, - device=self.device) - subquery_start_loc = torch.zeros(subquery_lens_tensor.shape[0] + 1, + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=self.device) + subquery_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device=self.device) - seq_start_loc = torch.zeros(prompt_lens_tensor.shape[0] + 1, + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, dtype=torch.int32, device=self.device) - torch.cumsum(subquery_lens_tensor, + torch.cumsum(query_lens_tensor, dim=0, dtype=subquery_start_loc.dtype, out=subquery_start_loc[1:]) - torch.cumsum(prompt_lens_tensor, + torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - prompt_lens=prompt_lens, - prompt_lens_tensor=prompt_lens_tensor, - max_subquery_len=max_subquery_len, - max_context_len=None, - max_prompt_len=max_prompt_len, - subquery_start_loc=subquery_start_loc, - seq_start_loc=seq_start_loc, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=False, - ) + if self.attn_backend is FlashInferBackend: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + use_cuda_graph=False, + seq_start_loc=seq_start_loc, + max_seq_len=max_seq_len, + block_tables=block_tables) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + subquery_start_loc=subquery_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + ) return PreparePromptMetadata( input_tokens=input_tokens, input_positions=input_positions, attn_metadata=attn_metadata, - prompt_lens=prompt_lens, - subquery_lens=subquery_lens, + seq_lens=seq_lens, + query_lens=query_lens, 
lora_index_mapping=lora_index_mapping, lora_prompt_mapping=lora_prompt_mapping, lora_requests=lora_requests, @@ -427,12 +436,30 @@ def _prepare_decode( input_tokens: List[int] = [] input_positions: List[int] = [] slot_mapping: List[int] = [] - context_lens: List[int] = [] + seq_lens: List[int] = [] block_tables: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + # The following fields are only for flashinfer + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + paged_kv_last_page_len: List[int] = [] + if len(seq_group_metadata_list) == 0: return PrepareDecodeMetadata.empty() @@ -455,9 +482,9 @@ def _prepare_decode( position = seq_len - 1 input_positions.append(position) - context_len = seq_len if self.sliding_window is None else min( + seq_len = seq_len if self.sliding_window is None else min( seq_len, self.sliding_window) - context_lens.append(context_len) + seq_lens.append(seq_len) block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] @@ -473,15 +500,21 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + paged_kv_indices.extend(block_table) + paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table)) + last_page_len = seq_data.get_len() % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + paged_kv_last_page_len.append(last_page_len) + # vLLM uses cuda graph only for decoding requests. # See `capture_model` API for more details. # For decoding requests, batch_size == input_tokens. batch_size = len(input_tokens) - max_context_len = max(context_lens) - use_captured_graph = ( - not self.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + max_seq_len = max(seq_lens) + use_captured_graph = (not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_seq_len <= self.max_seq_len_to_capture) if use_captured_graph: graph_batch_size = _get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size @@ -489,21 +522,21 @@ def _prepare_decode( input_tokens.append(0) input_positions.append(0) slot_mapping.append(_PAD_SLOT_ID) - context_lens.append(1) + seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. 
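The paged_kv_* comments above describe FlashInfer's page layout; the sketch below reproduces that example end to end. The sequence lengths and the block size of 16 are assumptions made only for this illustration:

    block_size = 16                                  # assumed page size
    block_tables = [[0, 5, 8], [1, 6, 7], [3, 4]]    # example from the comments
    seq_lens = [40, 48, 17]                          # assumed tokens per request

    paged_kv_indices = []
    paged_kv_indptr = [0]            # 0 marks the start of request 1's pages
    paged_kv_last_page_len = []
    for block_table, seq_len in zip(block_tables, seq_lens):
        paged_kv_indices.extend(block_table)
        paged_kv_indptr.append(paged_kv_indptr[-1] + len(block_table))
        last_page_len = seq_len % block_size or block_size
        paged_kv_last_page_len.append(last_page_len)

    print(paged_kv_indices)        # [0, 5, 8, 1, 6, 7, 3, 4]
    print(paged_kv_indptr)         # [0, 3, 6, 8]
    print(paged_kv_last_page_len)  # [8, 16, 1]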
- assert context_lens_tensor.shape[0] == len(input_tokens) - assert context_lens_tensor.shape[0] == len(input_positions) - assert context_lens_tensor.shape[0] == len(slot_mapping) + assert seq_lens_tensor.shape[0] == len(input_tokens) + assert seq_lens_tensor.shape[0] == len(input_positions) + assert seq_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -523,19 +556,51 @@ def _prepare_decode( device=self.device, ) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=max_context_len, - max_prompt_len=None, - subquery_start_loc=None, - seq_start_loc=None, - context_lens=context_lens_tensor, - block_tables=block_tables, - use_cuda_graph=use_captured_graph, - ) + if self.attn_backend is FlashInferBackend: + if not hasattr(self, "flashinfer_workspace_buffer"): + # Allocate 16MB workspace buffer + # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html + self.flashinfer_workspace_buffer = torch.empty( + 16 * 1024 * 1024, dtype=torch.uint8, device=self.device) + paged_kv_indptr = torch.tensor(paged_kv_indptr, + dtype=torch.int, + device=self.device) + paged_kv_indices = torch.tensor(paged_kv_indices, + dtype=torch.int, + device=self.device) + paged_kv_last_page_len = torch.tensor(paged_kv_last_page_len, + dtype=torch.int, + device=self.device) + kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, + self.model_config.dtype) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + use_cuda_graph=False, + workspace_buffer=self.flashinfer_workspace_buffer, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_kv_last_page_len, + num_qo_heads=self.model_config.get_num_attention_heads( + self.parallel_config), + num_kv_heads=self.model_config.get_num_kv_heads( + self.parallel_config), + head_dim=self.model_config.get_head_size(), + page_size=self.block_size, + data_type=kv_cache_dtype) + else: + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_query_len=None, + max_seq_len=max_seq_len, + subquery_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) return PrepareDecodeMetadata( input_tokens=input_tokens, input_positions=input_positions, @@ -565,8 +630,8 @@ def prepare_input_tensors( input_tokens, input_positions, prefill_attn_metadata, - prompt_lens, - subquery_lens, + seq_lens, + query_lens, lora_index_mapping, lora_prompt_mapping, lora_requests, @@ -583,13 +648,13 @@ def prepare_input_tensors( decode_slot_mapping, ) = self._prepare_decode(decode_reqs) sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, prompt_lens, subquery_lens, - self.device, self.pin_memory) + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.pin_memory) if not self.scheduler_config.chunked_prefill_enabled: assert (len(prefill_reqs) and len(decode_reqs)) == 0 - num_prefills = len(prompt_lens) + num_prefills = len(seq_lens) num_prefill_tokens = len(input_tokens) num_decode_tokens = len(decode_input_tokens) @@ -886,7 +951,7 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda() 
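The shape asserts above rely on the decode batch having been padded up to a captured CUDA-graph batch size. _get_graph_batch_size is defined elsewhere in model_runner.py; the reconstruction below (round up to 1, 2, 4, then multiples of 8) is an assumption included only to illustrate the padding those asserts expect.

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


def _get_graph_batch_size(batch_size: int) -> int:
    # Round the real batch size up to the nearest captured size.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8


assert _get_graph_batch_size(3) == 4
assert _get_graph_batch_size(13) == 16
assert _get_graph_batch_size(13) in _BATCH_SIZES_TO_CAPTURE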
slot_mapping.fill_(_PAD_SLOT_ID) - context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() graph_batch_size = _get_graph_batch_size( @@ -908,14 +973,13 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: # Create dummy attn_metadata. decode_metadata = self.attn_backend.make_metadata( is_prompt=False, - prompt_lens=None, - prompt_lens_tensor=None, - max_subquery_len=None, - max_context_len=self.max_context_len_to_capture, - max_prompt_len=None, + seq_lens=None, + seq_lens_tensor=seq_lens[:batch_size], + max_query_len=None, + max_seq_len=self.max_seq_len_to_capture, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens[:batch_size], + context_lens_tensor=None, block_tables=block_tables[:batch_size], use_cuda_graph=True, ) @@ -1025,7 +1089,7 @@ def capture( "positions": positions, "kv_caches": kv_caches, "slot_mapping": attn_metadata.slot_mapping, - "context_lens": attn_metadata.decode_metadata.context_lens, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} @@ -1047,8 +1111,8 @@ def forward( self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping, non_blocking=True) - self.input_buffers["context_lens"].copy_( - attn_metadata.decode_metadata.context_lens, non_blocking=True) + self.input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) self.input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) # Run the graph. 
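The input_buffers rename above is part of the standard CUDA-graph pattern: the graph is captured against fixed tensors, and every later step copies fresh values into those same tensors before replay. A minimal, self-contained sketch of that pattern (not the CUDAGraphRunner code itself):

import torch

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# Later steps refresh the captured buffers in place, then replay.
static_in.copy_(torch.arange(8, device="cuda", dtype=static_in.dtype))
g.replay()
# static_out now holds [0, 2, ..., 14]. Rebinding static_in to a new tensor
# would leave the graph reading stale memory, which is why the runner uses
# copy_ on seq_lens_tensor and the other input_buffers.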
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index a974e85c22f45..a336be04e124f 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -52,7 +52,7 @@ def _prepare_prompt( input_positions: List[List[int]] = [] input_block_ids: List[int] = [] - prompt_lens: List[int] = [] + seq_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -61,26 +61,26 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) input_tokens.append(prompt_tokens) - input_positions.append(list(range(prompt_len))) + input_positions.append(list(range(seq_len))) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] assert len(block_table) == 1 input_block_ids.append(block_table[0]) - max_prompt_len = max(prompt_lens) - assert max_prompt_len > 0 + max_seq_len = max(seq_lens) + assert max_seq_len > 0 input_tokens = make_tensor_with_pad(input_tokens, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) input_positions = make_tensor_with_pad(input_positions, - max_prompt_len, + max_seq_len, pad=0, dtype=torch.long, device=self.device) @@ -88,7 +88,7 @@ def _prepare_prompt( dtype=torch.long, device=self.device) - return input_tokens, input_positions, input_block_ids, prompt_lens + return input_tokens, input_positions, input_block_ids, seq_lens def _prepare_decode( self, @@ -149,18 +149,18 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_block_ids, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + seq_lens) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_block_ids) = self._prepare_decode(seq_group_metadata_list) - prompt_lens = [] + seq_lens = [] sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, - prompt_lens, - # subquery_lens is not needed if chunked prefill is not + seq_lens, + # query_lens is not needed if chunked prefill is not # supported. Since neuron worker doesn't support chunked prefill - # just use prompt_lens instead. - prompt_lens, + # just use seq_lens instead. 
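The neuron runner keeps the same padding behaviour under the new names. make_tensor_with_pad is a vLLM utility; the simplified stand-in below only illustrates what it produces for this call site (right-padding each prompt to max_seq_len).

import torch
from typing import List


def pad_to_max(seqs: List[List[int]], max_len: int, pad: int) -> torch.Tensor:
    # Right-pad each ragged token list so the batch stacks into one tensor.
    return torch.tensor([s + [pad] * (max_len - len(s)) for s in seqs],
                        dtype=torch.long)


input_tokens = [[5, 7, 9], [3, 1]]  # hypothetical prompts
seq_lens = [len(t) for t in input_tokens]  # [3, 2]
max_seq_len = max(seq_lens)
print(pad_to_max(input_tokens, max_seq_len, pad=0))
# tensor([[5, 7, 9],
#         [3, 1, 0]])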
+ seq_lens, self.device, self.pin_memory) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 39ad428f16fe3..538332ad003f1 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -11,13 +11,14 @@ VisionLanguageConfig) from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, + get_tensor_model_parallel_cpu_group, init_distributed_environment) from vllm.distributed.device_communicators import pynccl_utils from vllm.distributed.device_communicators.custom_all_reduce import ( init_custom_ar) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner from vllm.worker.worker_base import WorkerBase @@ -196,7 +197,7 @@ def cache_swap( self, blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], + blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. # TODO(woosuk): Profile swapping overhead and optimize if needed. @@ -204,25 +205,29 @@ def cache_swap( self.cache_engine.swap_in(blocks_to_swap_in) if blocks_to_swap_out: self.cache_engine.swap_out(blocks_to_swap_out) - if blocks_to_copy: + if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, - blocks_to_swap_in: Optional[Dict[int, int]] = None, - blocks_to_swap_out: Optional[Dict[int, int]] = None, - blocks_to_copy: Optional[Dict[int, List[int]]] = None, - num_lookahead_slots: int = 0, + execute_model_req: Optional[ExecuteModelRequest] = None ) -> List[SamplerOutput]: + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + if self.is_driver_worker: assert seq_group_metadata_list is not None + assert execute_model_req is not None num_seq_groups = len(seq_group_metadata_list) - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None + blocks_to_swap_in = execute_model_req.blocks_to_swap_in + blocks_to_swap_out = execute_model_req.blocks_to_swap_out + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, @@ -237,9 +242,6 @@ def execute_model( blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] - assert blocks_to_swap_in is not None - assert blocks_to_swap_out is not None - assert blocks_to_copy is not None self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. 
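With ExecuteModelRequest, the copy operations reach the worker as plain Python data and are reshaped into an [N, 2] tensor before CacheEngine.copy. The flat src/dst layout below is an assumption used only for illustration; the important parts are the view(-1, 2) and the numel() check that replaces the old empty-dict test.

import torch

blocks_to_copy_list = [3, 10, 7, 12]  # hypothetical: copy block 3 -> 10, 7 -> 12
blocks_to_copy = torch.tensor(blocks_to_copy_list,
                              dtype=torch.int64).view(-1, 2)
assert blocks_to_copy.tolist() == [[3, 10], [7, 12]]

if blocks_to_copy.numel() > 0:
    # cache_engine.copy(blocks_to_copy) would be issued here; an empty
    # tensor (numel() == 0) means there is nothing to copy this step.
    pass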
@@ -288,6 +290,9 @@ def init_worker_distributed_environment( init_distributed_environment(parallel_config.world_size, rank, distributed_init_method, local_rank) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + if pynccl_utils.is_initialized(): pynccl_world_size = pynccl_utils.get_world_size() if pynccl_world_size != parallel_config.world_size: @@ -298,12 +303,9 @@ def init_worker_distributed_environment( elif parallel_config.world_size > 1: # NOTE(woosuk): We don't initialize pynccl process group when world size # is 1. - # NOTE(kaichao): By default, pynccl will use information inside - # `parallel_state` for initialization. - pynccl_utils.init_process_group() - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + # NOTE(kaichao): By default, pynccl is initialized for tp group. + pynccl_utils.init_process_group( + group=get_tensor_model_parallel_cpu_group()) # Initialize a custom fast all-reduce implementation. if not parallel_config.disable_custom_all_reduce: diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 0a89e3a79769f..fb32feaca0c94 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -5,7 +5,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) @@ -48,10 +48,8 @@ def initialize_cache(self, num_gpu_blocks: int, @abstractmethod def execute_model( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, - int], - blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" raise NotImplementedError
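Both Worker.execute_model and the WorkerBase abstract method now take a single request object instead of four loose arguments. The toy implementation below shows the new contract; the dataclass is an illustrative stand-in whose fields are limited to the attributes actually read in this diff.

from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class ExecuteModelRequestSketch:  # stand-in, not vllm.sequence.ExecuteModelRequest
    seq_group_metadata_list: List[object]
    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
    blocks_to_copy: List[int] = field(default_factory=list)


class ToyWorker:

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequestSketch] = None,
    ) -> List[object]:
        # Non-driver workers are called with None and pick up the driver's
        # broadcasted metadata instead.
        if execute_model_req is None:
            return []
        # Driver path: unpack the request, issue swaps/copies, run the model.
        _ = execute_model_req.seq_group_metadata_list
        return []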