[Misc] Add multi-step chunked-prefill support for FlashInfer #10467

Open
elfiegg wants to merge 2 commits into main from chunked_multistep
Conversation

elfiegg
Contributor

@elfiegg elfiegg commented Nov 20, 2024

Support multi-step scheduling for chunked-prefill on FlashInfer, where prefill tokens are turned into decode tokens after the first step.
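For readers unfamiliar with the scheme, here is a minimal sketch of the prefill-to-decode transition described above; the class and field names are illustrative assumptions, not vLLM's actual implementation.

```python
# Hedged sketch: after the first of the N multi-step iterations, a request that
# just finished its prefill chunk has exactly one new token per step, so it is
# treated as a decode request for the remaining N-1 steps.
from dataclasses import dataclass

@dataclass
class RequestState:               # illustrative stand-in for vLLM's sequence metadata
    is_prefill: bool
    num_scheduled_tokens: int

def turn_prefills_into_decodes(batch: list[RequestState]) -> None:
    for req in batch:
        if req.is_prefill:
            req.is_prefill = False        # prefill tokens become decode tokens
            req.num_scheduled_tokens = 1  # decodes process one token per step

batch = [RequestState(is_prefill=True, num_scheduled_tokens=512),
         RequestState(is_prefill=False, num_scheduled_tokens=1)]
turn_prefills_into_decodes(batch)
assert all(not r.is_prefill and r.num_scheduled_tokens == 1 for r in batch)
```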

cc @comaniac @yzh199 @WoosukKwon @youkaichao


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

"specific parameter.")

if turn_prefills_into_decodes:
# When mutli-Step is enabled with chunked-Prefill, prefills and
Member

Suggested change
# When mutli-Step is enabled with chunked-Prefill, prefills and
# When Multi-Step is enabled with Chunked-Prefill, prefills and

@comaniac
Collaborator

Please fix the linting.

@comaniac comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Nov 20, 2024
@taegeonum

I've got the following error:

(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229] Traceback (most recent call last):
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner_base.py", line 116, in _wrapper
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner.py", line 1624, in execute_model
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]     self.attn_state.begin_forward(model_input)
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 262, in begin_forward
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229]     state = (self.runner.graph_runners[model_input.virtual_engine]
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229] KeyError: 2052
(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229] 

@elfiegg
Contributor Author

elfiegg commented Nov 21, 2024

I've got the following error:

(VllmWorkerProcess pid=505659) ERROR 11-21 10:00:59 multiproc_worker_utils.py:229] KeyError: 2052

Could you please share the repro command? Thank you! @taegeonum

@taegeonum

@elfiegg VLLM_ATTENTION_BACKEND=FLASHINFER python3 -m vllm.entrypoints.openai.api_server --model <any_model> --quantization fp8 --kv-cache-dtype fp8 --gpu-memory-utilization 0.95 --num-scheduler-steps 10 --enable-chunked-prefill True --max-num-batched-tokens 512

@elfiegg elfiegg force-pushed the chunked_multistep branch 3 times, most recently from ad48534 to 97d6859 on November 27, 2024 23:35
@elfiegg
Contributor Author

elfiegg commented Nov 27, 2024

@taegeonum apologies for the delay - the bug was in CUDA graph mode and has been fixed. Tests are configured.

@taegeonum

@elfiegg Thanks! But I got another error:


ERROR 11-28 15:40:26 engine.py:366] Error in model execution: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 0 vs 1
ERROR 11-28 15:40:26 engine.py:366] Traceback (most recent call last):
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner_base.py", line 116, in _wrapper
ERROR 11-28 15:40:26 engine.py:366]     return func(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner.py", line 1652, in execute_model
ERROR 11-28 15:40:26 engine.py:366]     self.attn_state.begin_forward(model_input)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 268, in begin_forward
ERROR 11-28 15:40:26 engine.py:366]     model_input.attn_metadata.begin_forward()
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 385, in begin_forward
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/flashinfer/decode.py", line 530, in plan
ERROR 11-28 15:40:26 engine.py:366]     self._wrapper.plan(
ERROR 11-28 15:40:26 engine.py:366] RuntimeError: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 0 vs 1
ERROR 11-28 15:40:26 engine.py:366]
ERROR 11-28 15:40:26 engine.py:366] The above exception was the direct cause of the following exception:
ERROR 11-28 15:40:26 engine.py:366]
ERROR 11-28 15:40:26 engine.py:366] Traceback (most recent call last):
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 11-28 15:40:26 engine.py:366]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 11-28 15:40:26 engine.py:366]     return cls(ipc_path=ipc_path,
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/engine/multiprocessing/engine.py", line 71, in __init__
ERROR 11-28 15:40:26 engine.py:366]     self.engine = LLMEngine(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/engine/llm_engine.py", line 338, in __init__
ERROR 11-28 15:40:26 engine.py:366]     self._initialize_kv_caches()
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/engine/llm_engine.py", line 476, in _initialize_kv_caches
ERROR 11-28 15:40:26 engine.py:366]     self.model_executor.determine_num_available_blocks())
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/executor/distributed_gpu_executor.py", line 39, in determine_num_available_blocks
ERROR 11-28 15:40:26 engine.py:366]     num_blocks = self._run_workers("determine_num_available_blocks", )
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/executor/multiproc_gpu_executor.py", line 195, in _run_workers
ERROR 11-28 15:40:26 engine.py:366]     driver_worker_output = driver_worker_method(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 11-28 15:40:26 engine.py:366]     return func(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/worker.py", line 198, in determine_num_available_blocks
ERROR 11-28 15:40:26 engine.py:366]     self.model_runner.profile_run()
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/multi_step_model_runner.py", line 662, in profile_run
ERROR 11-28 15:40:26 engine.py:366]     return self._base_model_runner.profile_run()
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 11-28 15:40:26 engine.py:366]     return func(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner.py", line 1343, in profile_run
ERROR 11-28 15:40:26 engine.py:366]     self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 11-28 15:40:26 engine.py:366]     return func(*args, **kwargs)
ERROR 11-28 15:40:26 engine.py:366]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner_base.py", line 146, in _wrapper
ERROR 11-28 15:40:26 engine.py:366]     raise type(err)(f"Error in model execution: "
ERROR 11-28 15:40:26 engine.py:366] RuntimeError: Error in model execution: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 0 vs 1
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] Traceback (most recent call last):
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner_base.py", line 116, in _wrapper
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner.py", line 1652, in execute_model
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     self.attn_state.begin_forward(model_input)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 268, in begin_forward
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     model_input.attn_metadata.begin_forward()
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 385, in begin_forward
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     self.decode_wrapper.begin_forward(
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/flashinfer/decode.py", line 530, in plan
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     self._wrapper.plan(
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] RuntimeError: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 0 vs 1
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] The above exception was the direct cause of the following exception:
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] Traceback (most recent call last):
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/worker.py", line 198, in determine_num_available_blocks
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     self.model_runner.profile_run()
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/multi_step_model_runner.py", line 662, in profile_run
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     return self._base_model_runner.profile_run()
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner.py", line 1343, in profile_run
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     return func(*args, **kwargs)
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]   File "/group-volume/taegeon/vllm-public/vllm/worker/model_runner_base.py", line 146, in _wrapper
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229]     raise type(err)(f"Error in model execution: "
(VllmWorkerProcess pid=311147) ERROR 11-28 15:40:26 multiproc_worker_utils.py:229] RuntimeError: Error in model execution: CHECK_EQ(paged_kv_indptr.size(0), batch_size + 1) failed. 0 vs 1

@elfiegg
Contributor Author

elfiegg commented Nov 28, 2024

Could you please share the reproduce command? Thanks! @taegeonum

@taegeonum

VLLM_ATTENTION_BACKEND=FLASHINFER python3 -m vllm.entrypoints.openai.api_server --model model_path --tensor-parallel-size 8 --quantization fp8 --kv-cache-dtype fp8 --max-num-seqs 500 --max-model-len 32768 --max-num-batched-tokens 4096 --gpu-memory-utilization 0.95 --trust-remote-code --enable-chunked-prefill true --num-scheduler-steps 10

@taegeonum

@elfiegg Hello, any progress? It would be good if we could use multi-step + chunked prefill on FlashInfer.

@elfiegg
Contributor Author

elfiegg commented Dec 3, 2024

@taegeonum sure - I'm just back from Thanksgiving vacation and will update this tomorrow

@elfiegg
Contributor Author

elfiegg commented Dec 11, 2024

Hello @JaheimLee, there seems to be a bug in multi-step + chunked-prefill CUDA graph mode: it schedules batched tokens during profiling, which it shouldn't do. The issue with the model config above seems unrelated to multi-step + chunked prefill on FlashInfer; I observed a similar failure with FlashAttn.

Also, if you turn off CUDA graph mode via

VLLM_ATTENTION_BACKEND=FLASHINFER python3 -m vllm.entrypoints.openai.api_server --model model_path --tensor-parallel-size 8 --quantization fp8 --kv-cache-dtype fp8 --max-num-seqs 500 --max-model-len 32768 --max-num-batched-tokens 4096 --gpu-memory-utilization 0.95 --trust-remote-code --enable-chunked-prefill true --num-scheduler-steps 10 --enforce-eager

This will work. I'm trying to narrow down the issue, and it seems it might relate to this PR: https://github.com/vllm-project/vllm/pull/8645/files. cc @varun-sundar-rabindranath for more context if there is any.

@elfiegg
Contributor Author

elfiegg commented Dec 12, 2024

Hello @JaheimLee, can you pull the latest changes and confirm if they fix the issue? Thanks!

@varun-sundar-rabindranath
Contributor

@elfiegg on main with the FLASH_ATTN backend,
command: VLLM_ATTENTION_BACKEND='FLASH_ATTN' python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-8B --port 9000 --max-num-seqs 500 --gpu-memory-utilization 0.90 --trust-remote-code --num-scheduler-steps 10 --enable-chunked-prefill --max-num-batched-tokens 4096

I see the following error,

ERROR 12-12 10:48:21 engine.py:366] local variable 'key_cache' referenced before assignment
ERROR 12-12 10:48:21 engine.py:366] Traceback (most recent call last):
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
ERROR 12-12 10:48:21 engine.py:366]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
ERROR 12-12 10:48:21 engine.py:366]     return cls(ipc_path=ipc_path,
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/engine/multiprocessing/engine.py", line 71, in __init__
ERROR 12-12 10:48:21 engine.py:366]     self.engine = LLMEngine(*args, **kwargs)
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/engine/llm_engine.py", line 292, in __init__
ERROR 12-12 10:48:21 engine.py:366]     self._initialize_kv_caches()
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/engine/llm_engine.py", line 432, in _initialize_kv_caches
ERROR 12-12 10:48:21 engine.py:366]     self.model_executor.determine_num_available_blocks())
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/executor/gpu_executor.py", line 68, in determine_num_available_blocks
ERROR 12-12 10:48:21 engine.py:366]     return self.driver_worker.determine_num_available_blocks()
...
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/attention/layer.py", line 287, in unified_attention_with_output
ERROR 12-12 10:48:21 engine.py:366]     self.impl.forward(query,
ERROR 12-12 10:48:21 engine.py:366]   File "/home/varun/code/vllm/vllm/attention/backends/flash_attn.py", line 810, in forward
ERROR 12-12 10:48:21 engine.py:366]     k_cache=key_cache,
ERROR 12-12 10:48:21 engine.py:366] UnboundLocalError: local variable 'key_cache' referenced before assignment

Is this what you were seeing when you said "The issue with the model config above seems unrelated to multi-step + chunked prefill on FlashInfer; I observed a similar failure with FlashAttn."?

When I change max_num_seqs to 512, however, the error goes away. I'll take a look.

@elfiegg
Contributor Author

elfiegg commented Dec 12, 2024

@varun-sundar-rabindranath Yes, that's the error message. It seems to be a symptom of scheduling decode requests during profiling, since the KV cache shape is not yet determined at that point. Preventing the CUDA graph padding during profiling seems to resolve the issue.
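A minimal sketch of the guard being discussed, assuming a helper shaped roughly like vLLM's _get_cuda_graph_pad_size; the signature and return convention below are illustrative assumptions, not the PR's exact code.

```python
# Hedged sketch: compute how many dummy "decode" slots to pad the batch with so
# it matches a captured CUDA-graph batch size, but skip padding entirely during
# the profile run because no KV cache has been allocated yet.
def get_cuda_graph_pad_size(batch_size: int, max_captured_size: int,
                            in_profile_run: bool) -> int:
    if in_profile_run:
        # Padded decode slots would read uninitialized KV-cache tensors.
        return 0
    if batch_size > max_captured_size:
        return 0  # too large to be captured; run eagerly without padding
    return max_captured_size - batch_size

# 5 prefills padded up to a captured size of 8 -> 3 dummy decodes,
# but no padding at all while profiling.
assert get_cuda_graph_pad_size(5, 8, in_profile_run=False) == 3
assert get_cuda_graph_pad_size(5, 8, in_profile_run=True) == 0
```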

@varun-sundar-rabindranath
Contributor

@elfiegg Thanks for catching this. I arrived at the same conclusion from my debugging. Your self.is_profile_run change should fix the issue.

fwiw, I am adding some of my analysis below:

With FLASH_ATTN, I believe this is what is happening:
When the scheduler schedules all prefill sequences in multi-step chunked-prefill, we add CUDA-graph padding so that when the prefills turn into decodes, we can use CUDA graphs.
For example, let the scheduler schedule 5 prefill sequences, and let the CUDA graph padding be 3.

  • first step: flash_attn.py interprets this as 5 prefill and 3 decode sequences being scheduled, and runs both the prefill attention and the decode attention.
  • second step onward: flash_attn.py interprets the state as 8 decode sequences.

This is fine when the engine is fully initialized, but during the profile_run it is not: the KV caches have not been initialized yet. When we add CUDA graph padding during the profile_run, flash_attn.py runs both the prefill and decode attentions. This blows up due to the

if kv_cache.numel() > 0:

check evaluating to false and key_cache never being initialized.

This manifests only when max_num_seqs is not a value in _BATCH_SIZES_TO_CAPTURE. When max_num_seqs is in _BATCH_SIZES_TO_CAPTURE, the profile_run doesn't add graph padding.
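For illustration, a stripped-down, self-contained reproduction of that failure mode, assuming the control flow implied by the traceback (the real flash_attn.py forward is of course more involved):

```python
import torch

def forward_sketch(kv_cache: torch.Tensor, num_decode_tokens: int):
    # key_cache is only bound when a real (non-empty) KV cache exists.
    if kv_cache.numel() > 0:
        key_cache, value_cache = kv_cache.unbind(0)
    if num_decode_tokens > 0:
        # During profile_run the cache is empty, so key_cache was never bound
        # and this raises UnboundLocalError, as in the traceback above.
        return key_cache
    return None

try:
    forward_sketch(torch.empty(0), num_decode_tokens=3)  # padded dummy decodes
except UnboundLocalError as err:
    print("reproduced:", err)
```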

@@ -1226,6 +1228,7 @@ def _prepare_model_input_tensors(

@torch.inference_mode()
def profile_run(self) -> None:
self.is_profile_run = True
Contributor

Suggestion: pass is_profile_run (default False) as a parameter to the _get_cuda_graph_pad_size function, to prevent accidentally accessing or setting/disabling the profile-run flag in other places.

Contributor Author

Yep, thanks for the suggestion. The default value is False and is set in the runner: https://github.com/vllm-project/vllm/pull/10467/files#diff-d3df23c3e3bcfe97ee8507061c6de54f0eff23a8c75d7f5999062c42245290f8R1028

Whenever the value is accessed, it defaults to False, if that's what you're suggesting. :)

Collaborator

Two minor suggestions:

  1. Rename it to in_profile_run, because this is a temporary status.
  2. Use a context manager to make sure this value is always reset after this function, as sketched below.
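A minimal sketch of suggestion 2, with assumed names (in_profile_run and the context-manager helper below are illustrative, not the PR's exact code):

```python
from contextlib import contextmanager

class ModelRunnerSketch:
    def __init__(self) -> None:
        self.in_profile_run = False

    @contextmanager
    def _profile_run_scope(self):
        self.in_profile_run = True
        try:
            yield
        finally:
            # Guarantees the temporary status is reset even if profiling raises.
            self.in_profile_run = False

    def profile_run(self) -> None:
        with self._profile_run_scope():
            pass  # run the dummy profiling batch here

runner = ModelRunnerSketch()
runner.profile_run()
assert runner.in_profile_run is False
```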

@taegeonum

@elfiegg No exception during server boot-up. However, I got an error when processing requests:

(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     return func(*args, **kwargs)
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/worker/multi_step_model_runner.py", line 510, in execute_model
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     model_input = self._advance_step(
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/worker/multi_step_model_runner.py", line 637, in _advance_step
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     attn_metadata.advance_step(
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/attention/backends/flashinfer.py", line 468, in advance_step
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     ops.advance_step_flashinfer(
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/_custom_ops.py", line 45, in wrapper
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     return fn(*args, **kwargs)
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/_custom_ops.py", line 268, in advance_step_flashinfer
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     return torch.ops._C.advance_step_flashinfer(
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/venv-vllm-0.6.4-latest/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     return self._op(*args, **(kwargs or {}))
(VllmWorkerProcess pid=1921076) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236] RuntimeError: tensor: name = block_tables, shape = [504, 2048] is_cont = 1, type = int is not as expected: shape = [1, -1], type = Int
(VllmWorkerProcess pid=1921079) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236] Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop.
(VllmWorkerProcess pid=1921079) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236] Traceback (most recent call last):
(VllmWorkerProcess pid=1921079) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]   File "/group-volume/taegeon/vllm-public/vllm/executor/multiproc_worker_utils.py", line 230, in _run_worker_process
(VllmWorkerProcess pid=1921079) ERROR 12-13 12:10:24 multiproc_worker_utils.py:236]     output = executor(*args, **kwargs)

@elfiegg elfiegg force-pushed the chunked_multistep branch 2 times, most recently from 94c6296 to 419b010 on December 20, 2024 23:12
@elfiegg
Contributor Author

elfiegg commented Dec 20, 2024

@taegeonum - apologies for the delay, and thanks so much for your patience. Can you pull the latest and try again? The issue you ran into was a CUDA graph padding issue that has been fixed.

@elfiegg elfiegg force-pushed the chunked_multistep branch 2 times, most recently from 32f925b to 5b3feed on December 20, 2024 23:33
@elfiegg elfiegg requested review from comaniac and mgoin on December 21, 2024 00:13
@simon-mo
Collaborator

@comaniac do you mind taking another look?

@taegeonum

@elfiegg Now it works without exceptions. Thanks!!

Collaborator

@comaniac comaniac left a comment

Sorry for the delayed review. The change LGTM.
Also, is this PR only compatible with certain FlashInfer versions? If so, where do we document it?

Collaborator

Can you comment on the CI time for this file after the PR? Since you pretty much quadruple the number of tests in this file, I'm a bit worried about the impact on CI time.

Contributor Author

Not sure about CI time, but from the time command, the diff is:

After:
real 7m39.973s
user 9m35.828s
sys 0m17.893s

Before:
real 1m59.525s
user 2m42.044s
sys 0m9.140s

so it is about 4x.

if model_input.attn_metadata.use_cuda_graph:
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
is_decode = model_input.attn_metadata.num_prefills == 0
if use_cuda_graph and is_decode:
Collaborator

Can you add a comment about why we look at use_cuda_graph and is_decode to update the state?

Contributor Author

Done.
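For context, the condition being documented can be read roughly like this; a hedged, self-contained sketch in which only use_cuda_graph and num_prefills are taken from the snippet above, and everything else is assumed:

```python
from dataclasses import dataclass

@dataclass
class AttnMetadataSketch:       # illustrative stand-in for FlashInfer metadata
    use_cuda_graph: bool
    num_prefills: int

def uses_captured_graph_state(attn_metadata: AttnMetadataSketch) -> bool:
    use_cuda_graph = attn_metadata.use_cuda_graph
    # CUDA graphs are captured only for pure-decode batches, so the graph
    # runner's state should be looked up only when the batch runs under a
    # captured graph AND contains no prefill sequences; mixed batches go
    # through the eager FlashInfer wrappers instead.
    is_decode = attn_metadata.num_prefills == 0
    return use_cuda_graph and is_decode

assert uses_captured_graph_state(AttnMetadataSketch(True, 0)) is True
assert uses_captured_graph_state(AttnMetadataSketch(True, 2)) is False
```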

@JaheimLee

Can both v0 and v1 use this feature?

@comaniac
Collaborator

comaniac commented Jan 4, 2025

Can both v0 and v1 use this feature?

v1 doesn't have multi-step scheduling and we don't plan to add it.

@elfiegg elfiegg force-pushed the chunked_multistep branch 4 times, most recently from 5d6d263 to 93e2cda on January 8, 2025 01:21
@elfiegg elfiegg force-pushed the chunked_multistep branch from 93e2cda to c10e056 on January 8, 2025 01:25