diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index bd184ee22682e..c3902f4c2a163 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -95,6 +95,16 @@ __global__ void advance_step_flashinfer_kernel( long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int64_t const block_tables_stride, int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x < num_query_blocks) { diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index cc1fd19252019..6fe5e6f76653b 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -5,6 +5,8 @@ import pytest +from tests.kernels.utils import override_backend_env_variable + from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ @@ -19,10 +21,11 @@ @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm( hf_runner, vllm_runner, @@ -36,6 +39,8 @@ def test_multi_step_llm( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling via sync LLM Engine. @@ -63,6 +68,7 @@ def test_multi_step_llm( num_logprobs: corresponds to the `logprobs` argument to the OpenAI completions endpoint; `None` -> 1 logprob returned. """ + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -110,10 +116,11 @@ def test_multi_step_llm( @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("tp_size", [1]) @pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN", "FLASHINFER"]) def test_multi_step_llm_w_prompt_logprobs( vllm_runner, example_prompts, @@ -126,6 +133,8 @@ def test_multi_step_llm_w_prompt_logprobs( num_prompts: int, num_logprobs: Optional[int], num_prompt_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test prompt logprobs with multi-step scheduling via sync LLM Engine. @@ -155,6 +164,7 @@ def test_multi_step_llm_w_prompt_logprobs( note that this argument is not supported by the OpenAI completions endpoint. """ + override_backend_env_variable(monkeypatch, attention_backend) prompts = example_prompts if len(prompts) < num_prompts: @@ -205,6 +215,7 @@ def test_multi_step_llm_w_prompt_logprobs( @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) @pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("attention_backend", ["FLASH_ATTN"]) def test_multi_step_llm_chunked_prefill_prefix_cache( vllm_runner, example_prompts, @@ -216,6 +227,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( num_scheduler_steps: int, num_prompts: int, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. @@ -278,6 +291,8 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( # # The Incorrect scheduling behavior - if it occurs - will cause an exception # in the model runner resulting from `do_sample=False`. + override_backend_env_variable(monkeypatch, attention_backend) + assert len(example_prompts) >= 2 challenge_prompts = copy.deepcopy(example_prompts) challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..5ee69ae71ca31 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -11,6 +11,7 @@ from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None @@ -22,20 +23,32 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionState, AttentionType) -from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, - compute_slot_mapping_start_idx, - is_block_tables_empty) +from vllm.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, + AttentionType, +) +from vllm.attention.backends.utils import ( + PAD_SLOT_ID, + compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty, +) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) +from vllm.utils import ( + async_tensor_h2d, + get_kv_cache_torch_dtype, + make_tensor_with_pad, +) if TYPE_CHECKING: - from vllm.worker.model_runner import (ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata) + from vllm.worker.model_runner import ( + ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata, + ) class FlashInferBackend(AttentionBackend): @@ -112,7 +125,8 @@ def _get_workspace_buffer(self): self._workspace_buffer = torch.empty( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, - device=self.runner.device) + device=self.runner.device, + ) return self._workspace_buffer def _get_prefill_wrapper(self): @@ -123,8 +137,8 @@ def _get_prefill_wrapper(self): def _get_decode_wrapper(self): if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) + num_qo_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( @@ -132,17 +146,20 @@ def _get_decode_wrapper(self): self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), "NHD", - use_tensor_cores=use_tensor_cores) + use_tensor_cores=use_tensor_cores, + ) return self._decode_wrapper @contextmanager def graph_capture(self, max_batch_size: int): self._is_graph_capturing = True self._graph_decode_wrapper = None - self._graph_slot_mapping = torch.full((max_batch_size, ), - PAD_SLOT_ID, - dtype=torch.long, - device=self.runner.device) + self._graph_slot_mapping = torch.full( + (max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device, + ) self._graph_seq_lens = torch.ones(max_batch_size, dtype=torch.int32, device=self.runner.device) @@ -152,7 +169,8 @@ def graph_capture(self, max_batch_size: int): self._graph_indices_buffer = torch.empty( max_batch_size * self.runner.cache_config.num_gpu_blocks, dtype=torch.int32, - device=self.runner.device) + device=self.runner.device, + ) self._graph_indptr_buffer = torch.empty(max_batch_size + 1, dtype=torch.int32, device=self.runner.device) @@ -183,17 +201,21 @@ def graph_capture_get_metadata_for_batch( _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) + num_qo_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) - self._graph_decode_wrapper = \ + self._graph_decode_wrapper = ( CUDAGraphBatchDecodeWithPagedKVCacheWrapper( - self._graph_decode_workspace_buffer, _indptr_buffer, - self._graph_indices_buffer, _last_page_len_buffer, "NHD", - use_tensor_cores) + self._graph_decode_workspace_buffer, + _indptr_buffer, + self._graph_indices_buffer, + _last_page_len_buffer, + "NHD", + use_tensor_cores, + )) if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.runner.kv_cache_dtype) @@ -236,7 +258,8 @@ def graph_capture_get_metadata_for_batch( q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, decode_wrapper=self._graph_decode_wrapper, - prefill_wrapper=None) + prefill_wrapper=None, + ) attn_metadata.begin_forward() return attn_metadata @@ -247,19 +270,23 @@ def get_graph_input_buffers(self, "slot_mapping": attn_metadata.slot_mapping, } - def prepare_graph_input_buffers(self, - input_buffers, - attn_metadata, - is_encoder_decoder_model: bool = False): + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False, + ): return def begin_forward(self, model_input): assert not self._is_graph_capturing state = self - if model_input.attn_metadata.use_cuda_graph: + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + is_decode = model_input.attn_metadata.num_prefills == 0 + if use_cuda_graph and is_decode: batch_size = model_input.input_tokens.shape[0] - state = (self.runner.graph_runners[model_input.virtual_engine] - [batch_size].attn_state) + state = self.runner.graph_runners[ + model_input.virtual_engine][batch_size].attn_state model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( ) model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() @@ -325,11 +352,12 @@ def __post_init__(self): # Refer to # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 supported_head_sizes = FlashInferBackend.get_supported_head_sizes() - if self.head_dim is not None and self.head_dim \ - not in supported_head_sizes: + if (self.head_dim is not None + and self.head_dim not in supported_head_sizes): raise ValueError( f"Only {supported_head_sizes} are supported for head_dim,", - f"received {self.head_dim}.") + f"received {self.head_dim}.", + ) def begin_forward(self): if self.num_prefill_tokens > 0: @@ -362,8 +390,11 @@ def begin_forward(self): self.paged_kv_indptr[:self.num_prefills + 1], self.paged_kv_indices, self.paged_kv_last_page_len[:self.num_prefills], - self.num_qo_heads, self.num_kv_heads, self.head_dim, - self.page_size) + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + ) if self.num_decode_tokens > 0: assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None @@ -393,7 +424,8 @@ def begin_forward(self): # kv-cache data type. data_type=self.data_type, # query data type. - q_data_type=self.q_data_type) + q_data_type=self.q_data_type, + ) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -402,8 +434,8 @@ def asdict_zerocopy(self, skip_fields = set() # We need to skip the prefill/decode_wrapper field since it cannot be # broadcasted with nccl when TP is enabled. - skip_fields.add('prefill_wrapper') - skip_fields.add('decode_wrapper') + skip_fields.add("prefill_wrapper") + skip_fields.add("decode_wrapper") return super().asdict_zerocopy(skip_fields) @property @@ -418,21 +450,37 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return None return self - def advance_step(self, - model_input: "ModelInputForGPUWithSamplingMetadata", - sampled_token_ids: Optional[torch.Tensor], - block_size: int, - num_seqs: int, - num_queries: int, - turn_prefills_into_decodes: bool = False): + def advance_step( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False, + ): """ Update metadata in-place to advance one decode step. """ - assert not turn_prefills_into_decodes, \ - ("Chunked prefill is not supported with flashinfer yet." - "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " - "specific parameter.") + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. + assert self.num_decode_tokens + self.num_prefills == num_seqs + # Flashinfer doesn't support speculative decoding + chunked-prefill + # + multi-step scheduling yet. + assert self.decode_query_len == 1 + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens_tensor is not None assert num_seqs > 0 assert num_queries > 0 @@ -462,7 +510,8 @@ def advance_step(self, paged_kv_indices=self.paged_kv_indices, paged_kv_indptr=self.paged_kv_indptr, paged_kv_last_page_len=self.paged_kv_last_page_len, - block_table_bound=self.block_table_bound) + block_table_bound=self.block_table_bound, + ) class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -506,8 +555,10 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.is_profile_run: bool = False def _add_seq_group( - self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", - chunked_prefill_enabled: bool): + self, + inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, + ): """Add a sequence group to the metadata. Specifically update/append 1. context length. 2. block table. @@ -517,12 +568,23 @@ def _add_seq_group( block_tables = inter_data.block_tables computed_block_nums = inter_data.computed_block_nums - for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block) in zip( - inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], - inter_data.orig_seq_lens, inter_data.seq_lens, - inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks): + for ( + seq_id, + token_len, + seq_len, + curr_seq_len, + query_len, + context_len, + curr_sliding_window_block, + ) in zip( + inter_data.seq_ids, + [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, + inter_data.seq_lens, + inter_data.query_lens, + inter_data.context_lens, + inter_data.curr_sliding_window_blocks, + ): self.context_lens.append(context_len) if is_prompt: mm_maps = inter_data.multi_modal_placeholder_maps @@ -534,9 +596,10 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) + assert ( + query_len == 1 + ), "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -547,8 +610,8 @@ def _add_seq_group( block_table = [] if inter_data.prefix_cache_hit: block_table = computed_block_nums - elif ((chunked_prefill_enabled or not is_prompt) - and block_tables is not None): + elif (chunked_prefill_enabled + or not is_prompt) and block_tables is not None: block_table = block_tables[seq_id][-curr_sliding_window_block:] self.block_tables.append(block_table) @@ -558,9 +621,16 @@ def _add_seq_group( start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, context_len, self.sliding_window) - compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, - seq_len, context_len, start_idx, - self.block_size, inter_data.block_tables) + compute_slot_mapping( + is_profile_run, + self.slot_mapping, + seq_id, + seq_len, + context_len, + start_idx, + self.block_size, + inter_data.block_tables, + ) # It is not necessary to add paged_kv_indices, paged_kv_indptr, # and paged_kv_last_page_len for profile run because we will @@ -579,9 +649,9 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): # If seq_len = 15, block_size = 16, # block_table_bound is 0 + 1 with 1 valid block. self.total_blocks += len(block_table) - block_table_bound = seq_len // self.block_size + 1 \ - if seq_len % self.block_size != 0 \ - else seq_len // self.block_size + block_table_bound = (seq_len // self.block_size + 1 + if seq_len % self.block_size != 0 else seq_len // + self.block_size) self.paged_kv_indices.extend(block_table[:block_table_bound]) self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + block_table_bound) @@ -591,8 +661,36 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): last_page_len = self.block_size self.paged_kv_last_page_len.append(last_page_len) - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int): + def _get_graph_runner_block_tables(self, num_seqs: int, + block_tables: List[List[int]]): + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build( + self, + seq_lens: List[int], + query_lens: List[int], + cuda_graph_pad_size: int, + batch_size: int, + ): """Build attention metadata with on-device tensors. Args: @@ -612,31 +710,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens decode_query_len = max(query_lens[self.num_prefills:], default=1) + num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) num_decode_tokens = batch_size - self.num_prefill_tokens - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. - input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device, non_blocking=True) - + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) last_paged_kv_indptr = self.paged_kv_indptr[-1] self.paged_kv_indptr.extend([last_paged_kv_indptr] * cuda_graph_pad_size) @@ -667,14 +749,18 @@ def build(self, seq_lens: List[int], query_lens: List[int], for modality, placeholder_map in self.multimodal_placeholder_maps.items() } - torch.cumsum(seq_lens_tensor, - dim=0, - dtype=seq_start_loc.dtype, - out=seq_start_loc[1:]) - torch.cumsum(query_lens_tensor, - dim=0, - dtype=query_start_loc.dtype, - out=query_start_loc[1:]) + torch.cumsum( + seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:], + ) + torch.cumsum( + query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:], + ) if len(self.paged_kv_indptr) > 0: # extend to the maximum number of blocks as returned by the @@ -732,7 +818,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], data_type=kv_cache_dtype, q_data_type=self.runner.model_config.dtype, use_cuda_graph=use_captured_graph, - is_profile_run=self.is_profile_run) + is_profile_run=self.is_profile_run, + ) class FlashInferImpl(AttentionImpl): @@ -776,7 +863,6 @@ def forward( attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO: directly write to output tensor if attn_type != AttentionType.DECODER: @@ -820,10 +906,12 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + assert ( + key.shape[0] == num_prefill_tokens + num_decode_tokens + ), f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert ( + value.shape[0] == num_prefill_tokens + num_decode_tokens + ), f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa query = query.contiguous( ) # Flashinfer requires query to be contiguous # Query for decode. KV is not needed because it is already cached. @@ -870,7 +958,8 @@ def forward( causal=True, k_scale=k_scale, v_scale=v_scale, - window_left=window_left) + window_left=window_left, + ) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None assert decode_meta.decode_wrapper is not None @@ -881,7 +970,8 @@ def forward( logits_soft_cap=logits_soft_cap, k_scale=k_scale, v_scale=v_scale, - window_left=window_left) + window_left=window_left, + ) if prefill_output is None and decode_output is not None: # Decode only batch. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2b545d1b28bd2..907234a9f4b16 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -791,7 +791,7 @@ def _get_cuda_graph_pad_size(self, is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ self.runner.scheduler_config.chunked_prefill_enabled decode_only = self.decode_only or is_mscp - if not decode_only: + if not decode_only or self.runner.is_profile_run: # Early exit so we can treat num_seqs as the batch_size below. return -1 @@ -1028,6 +1028,8 @@ def __init__( self.has_inner_state = model_config.has_inner_state + self.is_profile_run = False + # When using CUDA graph, the input block tables must be padded to # max_seq_len_to_capture. However, creating the block table in # Python can be expensive. To optimize this, we cache the block table @@ -1229,6 +1231,7 @@ def _prepare_model_input_tensors( @torch.inference_mode() def profile_run(self) -> None: + self.is_profile_run = True # Enable top-k sampling to reflect the accurate memory usage. sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens @@ -1330,6 +1333,7 @@ def profile_run(self) -> None: self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() + self.is_profile_run = False return def remove_all_loras(self): diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 18b03bf1bfb56..ad70162209a69 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -32,7 +32,7 @@ MULTI_STEP_ATTENTION_BACKENDS = [ "FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION" ] -MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN", "FLASHINFER"] def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ -> List[str]: