
Commit

fix
Signed-off-by: rickyx <[email protected]>
rickyyx committed Nov 15, 2024
1 parent b3fa9d6 commit b232a45
Showing 3 changed files with 131 additions and 8 deletions.
31 changes: 29 additions & 2 deletions tests/core/utils.py
@@ -1,12 +1,15 @@
 import time
-from typing import List, Optional
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from vllm import SamplingParams
+from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
-from vllm.sequence import Logprob, Sequence, SequenceGroup
+from vllm.sequence import (Logprob, Sequence, SequenceGroup,
+                           SequenceGroupMetadata)
 
 
 def create_dummy_prompt(
@@ -217,3 +220,27 @@ def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int):
     seq_group.update_num_computed_tokens(token_chunk_size)
     for seq in seq_group.get_seqs():
         seq.append_token_id(token_id, {token_id: Logprob(token_id)})
+
+
+class SchedulerProxy:
+    """
+    A proxy class to forward calls to the scheduler.
+    """
+
+    def __init__(self, scheduler: Scheduler):
+        self.scheduler_ = scheduler
+        self.call_history: Dict[str, List[Any]] = defaultdict(list)
+
+    def __getattr__(self, name: str) -> Any:
+
+        def wrapper(*args, **kwargs):
+            result = getattr(self.scheduler_, name)(*args, **kwargs)
+            self.call_history[name].append((args, kwargs, result))
+            return result
+
+        return wrapper
+
+    def last_schedule_ret(
+            self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]:
+        _, _, ret = self.call_history["schedule"][-1]
+        return ret
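
For reference, a minimal sketch of how SchedulerProxy is meant to be used (it mirrors the new test below; the LLMEngine construction is assumed and not part of this commit):

proxy = SchedulerProxy(engine.scheduler[0])  # engine: an already constructed LLMEngine
engine.scheduler[0] = proxy  # type: ignore

engine.step()

# Every proxied call is recorded; inspect the result of the latest schedule().
sched_metas, sched_out, _ = proxy.last_schedule_ret()
print(len(sched_out.scheduled_seq_groups))
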
90 changes: 90 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
@@ -7,8 +7,12 @@
 
 import pytest
 
+from tests.conftest import VllmRunner
+from tests.core.utils import SchedulerProxy, create_dummy_prompt
 from tests.kernels.utils import override_backend_env_variable
 from vllm import SamplingParams, TokensPrompt
+from vllm.core.scheduler import Scheduler
+from vllm.engine.llm_engine import LLMEngine
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 from ..models.utils import check_outputs_equal
@@ -188,3 +192,89 @@ def test_unstable_prompt_sequence(
     for prompt in UNSTABLE_PROMPT_SEQUENCE:
         vllm_model.generate(TokensPrompt(prompt_token_ids=prompt),
                             SamplingParams(max_tokens=1))
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_fully_cached_prefill_needs_uncached_token(model):
+    block_size = 16
+    max_num_batched_tokens = 16
+    num_output_tokens = 5
+    # Make a vLLM engine.
+    runner = VllmRunner(
+        model_name=model,
+        gpu_memory_utilization=0.7,
+        enable_chunked_prefill=True,
+        enforce_eager=True,
+        enable_prefix_caching=True,
+        block_size=block_size,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_batched_tokens,
+    )
+    engine: LLMEngine = runner.model.llm_engine
+
+    scheduler: Scheduler = SchedulerProxy(engine.scheduler[0])  # type: ignore
+    engine.scheduler[0] = scheduler
+
+    # seqA
+    seqA_tokens = list(range(2 * block_size))
+    seqA, seq_groupA = create_dummy_prompt(
+        request_id="0",
+        prompt_tokens=seqA_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    scheduler.add_seq_group(seq_groupA)
+
+    assert seqA.data.get_num_computed_tokens() == 0
+
+    # Prefill seqA.
+    while not seqA.is_finished():
+        engine.step()
+
+    # seqB
+    seqB_tokens = [t + 1 for t in seqA_tokens]  # shift by 1
+    seqB, seq_groupB = create_dummy_prompt(
+        request_id="1",
+        prompt_tokens=seqB_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    # seqC is the same as seqA.
+    seqC, seq_groupC = create_dummy_prompt(
+        request_id="2",
+        prompt_tokens=seqA_tokens,
+        max_tokens=num_output_tokens,
+        block_size=block_size,
+    )
+
+    scheduler.add_seq_group(seq_groupB)
+    scheduler.add_seq_group(seq_groupC)
+
+    # Even though seqC is fully cached, it should not be prefilled since
+    # we require at least 1 uncached token.
+    engine.step()
+
+    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+    assert len(sched_out.scheduled_seq_groups) == 1
+    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+            seq_groupB.request_id)
+    assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
+            max_num_batched_tokens)
+
+    # Once seqB is finished, seqC can be prefilled.
+    while not seqB.is_finished():
+        engine.step()
+        sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+        assert len(sched_out.scheduled_seq_groups) == 1
+        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+                seq_groupB.request_id)
+
+    engine.step()
+    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
+    assert len(sched_out.scheduled_seq_groups) == 1
+    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
+            seq_groupC.request_id)
+    assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
+        seqA_tokens)
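
A quick check of the numbers behind the assertions above (a reading of the test, not part of the diff): seqB shares no prefix with seqA, so its whole prompt is uncached and chunked prefill can only schedule a budget-sized chunk, while seqC is identical to seqA, so everything but the single forced-uncached token is a cache hit and does not count against the token budget.

block_size = 16
max_num_batched_tokens = 16
prompt_len = 2 * block_size                           # 32 tokens for seqA/seqB/seqC

# seqB: no prefix-cache hit, so its chunk is capped by the token budget.
seqB_chunk = min(prompt_len, max_num_batched_tokens)  # 16 == max_num_batched_tokens

# seqC: fully cached, but the scheduler now forces one uncached token,
# and only uncached tokens consume the budget, so the whole prompt fits.
seqC_uncached, seqC_cached = 1, prompt_len - 1        # 1 and 31
seqC_chunk = seqC_uncached + seqC_cached              # 32 == len(seqA_tokens)
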
18 changes: 12 additions & 6 deletions vllm/core/scheduler.py
@@ -745,11 +745,10 @@ def _schedule_swapped(
                     seq_group, SequenceStatus.SWAPPED, enable_chunking,
                     budget))
 
-            if (num_new_tokens_uncached + num_new_tokens_cached == 0
-                    or not budget.can_schedule(
-                        num_new_tokens=num_new_tokens_uncached,
-                        num_new_seqs=num_new_seqs,
-                    )):
+            if num_new_tokens_uncached == 0 or not budget.can_schedule(
+                    num_new_tokens=num_new_tokens_uncached,
+                    num_new_seqs=num_new_seqs,
+            ):
                 break
 
             if lora_int_id > 0 and curr_loras is not None:
@@ -985,7 +984,7 @@ def _schedule_prefills(
                 break
 
             num_new_seqs = seq_group.get_max_num_running_seqs()
-            if num_new_tokens == 0 or not budget.can_schedule(
+            if num_new_tokens_uncached == 0 or not budget.can_schedule(
                     num_new_tokens=num_new_tokens_uncached,
                     num_new_seqs=num_new_seqs,
             ):
@@ -1728,6 +1727,13 @@ def _get_num_new_uncached_and_cached_tokens(
             num_uncached_new_tokens += num_uncached_new_tokens_seq
             num_cached_new_tokens += num_cached_new_tokens_seq
 
+        if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
+            # For a fully cached hit sequence, we actually need to recompute the
+            # last token. So we need at least 1 uncached token to schedule.
+            # See ModelRunner._compute_for_prefix_cache_hit for more details.
+            num_uncached_new_tokens = 1
+            num_cached_new_tokens -= 1
+
         if enable_chunking and len(seqs) == 1:
             # Chunk if a running request cannot fit in the given budget.
             # If number of seq > 1, it means it is doing beam search
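
The scheduler change above boils down to a small adjustment of the uncached/cached token split. A standalone sketch of that adjustment (the helper name is made up for illustration; this is not the vLLM function itself):

from typing import Tuple


def force_one_uncached_token(num_uncached_new_tokens: int,
                             num_cached_new_tokens: int) -> Tuple[int, int]:
    # A fully cached prefill still has to recompute its last token, so it
    # must be scheduled with at least one uncached token.
    if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0:
        return 1, num_cached_new_tokens - 1
    return num_uncached_new_tokens, num_cached_new_tokens


assert force_one_uncached_token(0, 32) == (1, 31)    # full prefix-cache hit
assert force_one_uncached_token(16, 16) == (16, 16)  # partial hit: unchanged
assert force_one_uncached_token(0, 0) == (0, 0)      # nothing new to schedule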
