Skip to content

Commit

Permalink
[V1][Core] Remove should_shutdown to simplify core process termination (
Browse files Browse the repository at this point in the history
vllm-project#11113)

Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
tlrmchlsmth authored and elfiegg committed Dec 12, 2024
1 parent d1e21a9 commit c4a05d7
Show file tree
Hide file tree
Showing 6 changed files with 987 additions and 562 deletions.
19 changes: 17 additions & 2 deletions tests/multi_step/test_correctness_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest

from tests.kernels.utils import override_backend_env_variable

from ..models.utils import check_logprobs_close, check_outputs_equal

MODELS = [
Expand All @@ -19,10 +21,11 @@
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm(
hf_runner,
vllm_runner,
Expand All @@ -36,6 +39,8 @@ def test_multi_step_llm(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step scheduling via sync LLM Engine.
Expand Down Expand Up @@ -63,6 +68,7 @@ def test_multi_step_llm(
num_logprobs: corresponds to the `logprobs` argument to the OpenAI
completions endpoint; `None` -> 1 logprob returned.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
Expand Down Expand Up @@ -110,10 +116,11 @@ def test_multi_step_llm(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [True])
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"])
def test_multi_step_llm_w_prompt_logprobs(
vllm_runner,
example_prompts,
Expand All @@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs(
num_prompts: int,
num_logprobs: Optional[int],
num_prompt_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test prompt logprobs with multi-step scheduling via sync LLM Engine.
Expand Down Expand Up @@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs(
note that this argument is not supported by the
OpenAI completions endpoint.
"""
override_backend_env_variable(monkeypatch, attention_backend)

prompts = example_prompts
if len(prompts) < num_prompts:
Expand Down Expand Up @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs(
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.parametrize("num_logprobs", [None, 5])
@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"])
def test_multi_step_llm_chunked_prefill_prefix_cache(
vllm_runner,
example_prompts,
Expand All @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
num_scheduler_steps: int,
num_prompts: int,
num_logprobs: Optional[int],
attention_backend: str,
monkeypatch,
) -> None:
"""Test vLLM engine with multi-step+"single-step chunked prefill"+APC.
Expand Down Expand Up @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache(
#
# The Incorrect scheduling behavior - if it occurs - will cause an exception
# in the model runner resulting from `do_sample=False`.
override_backend_env_variable(monkeypatch, attention_backend)

assert len(example_prompts) >= 2
challenge_prompts = copy.deepcopy(example_prompts)
challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient '
Expand Down
29 changes: 24 additions & 5 deletions vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ def prepare_graph_input_buffers(self,
def begin_forward(self, model_input):
assert not self._is_graph_capturing
state = self
if model_input.attn_metadata.use_cuda_graph:
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
is_decode = model_input.attn_metadata.num_prefills == 0
if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
Expand Down Expand Up @@ -332,6 +334,8 @@ def __post_init__(self):
f"received {self.head_dim}.")

def begin_forward(self):
print("num_prefill_tokens:" + str(self.num_prefill_tokens) + "num_decode_tokens:" + str(self.num_decode_tokens))
print("batch size:" + str(self.paged_kv_indptr.shape))
if self.num_prefill_tokens > 0:
if self.paged_kv_indices is None:
return
Expand Down Expand Up @@ -429,10 +433,24 @@ def advance_step(self,
Update metadata in-place to advance one decode step.
"""

assert not turn_prefills_into_decodes, \
("Chunked prefill is not supported with flashinfer yet."
"turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
"specific parameter.")
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
# Flashinfer doesn't support speculative decoding + chunked-prefill
# + multi-step scheduling yet.
assert self.decode_query_len == 1
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1

self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens_tensor is not None

assert num_seqs > 0
assert num_queries > 0
Expand Down Expand Up @@ -824,6 +842,7 @@ def forward(
f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa
assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa
# print("prefill tokens: " + str(num_prefill_tokens) + ", decode tokens:" + str(num_decode_tokens))
query = query.contiguous(
) # Flashinfer requires query to be contiguous
# Query for decode. KV is not needed because it is already cached.
Expand Down
13 changes: 2 additions & 11 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import threading
import time
from multiprocessing.process import BaseProcess
from multiprocessing.sharedctypes import Synchronized
from typing import List, Tuple, Type, Union

import zmq
Expand Down Expand Up @@ -133,13 +132,9 @@ def __init__(
input_path: str,
output_path: str,
ready_path: str,
should_shutdown: Synchronized,
):
super().__init__(vllm_config, executor_class, usage_context)

# Signal from main process to shutdown (multiprocessing.Value).
self.should_shutdown = should_shutdown

# Background Threads and Queues for IO. These enable us to
# overlap ZMQ socket IO with GPU since they release the GIL,
# and to overlap some serialization/deserialization with the
Expand Down Expand Up @@ -195,7 +190,6 @@ def make_engine_core_process(
input_path: str,
output_path: str,
ready_path: str,
should_shutdown: Synchronized,
) -> BaseProcess:
# The current process might have CUDA context,
# so we need to spawn a new process.
Expand All @@ -210,7 +204,6 @@ def make_engine_core_process(
"vllm_config": vllm_config,
"executor_class": executor_class,
"usage_context": usage_context,
"should_shutdown": should_shutdown
}
# Run EngineCore busy loop in background process.
proc = context.Process(target=EngineCoreProc.run_engine_core,
Expand Down Expand Up @@ -260,8 +253,8 @@ def signal_handler(signum, frame):
def run_busy_loop(self):
"""Core busy loop of the EngineCore."""

# Loop until we get a shutdown signal.
while not self.should_shutdown:
# Loop until process is sent a SIGINT or SIGTERM
while True:
# 1) Poll the input queue until there is work to do.
if not self.scheduler.has_unfinished_requests():
while True:
Expand All @@ -272,8 +265,6 @@ def run_busy_loop(self):
except queue.Empty:
self._log_stats()
logger.debug("EngineCore busy loop waiting.")
if self.should_shutdown:
return
except BaseException:
raise

Expand Down
6 changes: 0 additions & 6 deletions vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import atexit
import multiprocessing
from typing import List, Union

import msgspec
Expand Down Expand Up @@ -149,21 +148,16 @@ def __init__(
self.input_socket.bind(input_path)

# Start EngineCore in background process.
self.should_shutdown = multiprocessing.Value('b', False, lock=False)
self.proc = EngineCoreProc.make_engine_core_process(
*args,
input_path=input_path,
output_path=output_path,
ready_path=ready_path,
should_shutdown=self.should_shutdown,
**kwargs,
)
atexit.register(self.shutdown)

def shutdown(self):
# Send shutdown signal to background process.
self.should_shutdown = True

# Shut down the zmq context.
self.ctx.destroy(linger=0)

Expand Down
Loading

0 comments on commit c4a05d7

Please sign in to comment.