From 6d646d08a2e0e73e83e313a5ae470c1f9e4f200e Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Date: Tue, 3 Sep 2024 14:50:29 -0400
Subject: [PATCH] [Core] Optimize Async + Multi-step (#8050)

---
 .../multi_step/test_correctness_async_llm.py |   4 +-
 vllm/engine/async_llm_engine.py              | 109 +++++----
 vllm/engine/llm_engine.py                    | 222 ++++++++----------
 vllm/engine/output_processor/multi_step.py   |  62 +++--
 vllm/sequence.py                             |   4 +-
 vllm/worker/model_runner.py                  |   4 +-
 vllm/worker/multi_step_model_runner.py       | 165 ++++++++---
 vllm/worker/multi_step_worker.py             |   4 +-
 8 files changed, 326 insertions(+), 248 deletions(-)

diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py
index d054ca341694a..0cbe8371e235a 100644
--- a/tests/multi_step/test_correctness_async_llm.py
+++ b/tests/multi_step/test_correctness_async_llm.py
@@ -103,13 +103,13 @@ async def test_multi_step(
         model,
         server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)
     test_completions = await completions_with_server_args(
         prompts,
         model,
         ms_server_args + distributed_args,
         num_logprobs,
-        max_wait_seconds=3 * 240)
+        max_wait_seconds=5 * 240)
 
     # Assert multi-step scheduling produces identical tokens
     # to single-step scheduling.
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 159281dabde4a..7fe8053fffb7b 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -280,40 +280,27 @@ async def step_async(
         scheduler_outputs = cached_outputs.scheduler_outputs
         allow_async_output_proc = cached_outputs.allow_async_output_proc
 
-        # Detect async + multi-step
-        use_async_and_multi_step = (self.scheduler_config.is_multi_step
-                                    and allow_async_output_proc)
-
         ctx = self.scheduler_contexts[virtual_engine]
 
+        # Clear outputs for each new scheduler iteration
+        ctx.request_outputs.clear()
+
         # skip the scheduler if there are any remaining steps in the seq groups.
         # This ensures that the scheduler is only called again when the current
         # batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list): - # Clear outputs on scheduler iteration start - ctx.request_outputs.clear() - # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # Detect async + multi-step - use_async_and_multi_step = (self.scheduler_config.is_multi_step - and allow_async_output_proc) + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) - - # For async + multi-step, init the queue - if use_async_and_multi_step: - assert len(ctx.output_queue) == 0 - assert seq_group_metadata_list is not None - ctx.output_queue.append( - (None, seq_group_metadata_list, scheduler_outputs)) + self._process_model_outputs(ctx=ctx) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -351,26 +338,20 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - async_callback = self.async_callback_multi_step[ - virtual_engine] if use_async_and_multi_step \ - else self.async_callback[virtual_engine] - - execute_model_req.async_callback = async_callback - execute_model_req.use_async_and_multi_step = \ - use_async_and_multi_step + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: - if not use_async_and_multi_step and len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) output = [] # Finish the current step for all the sequence groups. @@ -384,24 +365,22 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - if use_async_and_multi_step: - # For async + multi-step, clear the queue - ctx.output_queue.clear() - else: - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + is_async = allow_async_output_proc + is_last_step = True + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step)) - if output and allow_async_output_proc: - assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Async postprocessor expects only a single output set" + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=False) + self._process_model_outputs(ctx=ctx) # Log stats. 
self.do_log_stats(scheduler_outputs, output) @@ -411,17 +390,12 @@ async def step_async( else: # Multi-step case - if use_async_and_multi_step: - return [] - else: - ctx.request_outputs = [] + return ctx.request_outputs if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) assert len(ctx.output_queue) == 0 return ctx.request_outputs @@ -640,6 +614,17 @@ def __init__(self, self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) + # This ensures quick processing of request outputs + # so the append to asyncio queues is not delayed, + # especially for multi-step. + # + # TODO: Currently, disabled for engine_use_ray, ask + # Cody/Will/Woosuk about this case. + self.use_process_request_outputs_callback = not self.engine_use_ray + if self.use_process_request_outputs_callback: + self.engine.process_request_outputs_callback = \ + self.process_request_outputs + if self.engine_use_ray: print_warning_once( "DEPRECATED. `--engine-use-ray` is deprecated and will " @@ -883,13 +868,27 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. - finished = True + # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. + all_finished = True for request_output in request_outputs: self._request_tracker.process_request_output( request_output, verbose=self.log_requests) - finished = finished and request_output.finished + all_finished = all_finished and request_output.finished - return not finished + return all_finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1eab83f3b9889..8c5ca81fb1905 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -93,13 +93,14 @@ class SchedulerOutputState: @dataclass class SchedulerContext: output_queue: Deque[Tuple[Optional[List[SamplerOutput]], - List[SequenceGroupMetadata], - SchedulerOutputs]] = field( - default_factory=lambda: deque()) - + List[SequenceGroupMetadata], SchedulerOutputs, + bool, + bool]] = field(default_factory=lambda: deque()) request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = field( default_factory=lambda: []) + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None class LLMEngine: @@ -357,6 +358,26 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # different process. 
self.tokenizer.ping() + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.async_callbacks = [ + functools.partial(self._process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback = None + # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. @@ -364,9 +385,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: Scheduler( scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size, - functools.partial(self._process_model_outputs, - virtual_engine=v_id, - is_async=True) + self.async_callbacks[v_id] if model_config.use_async_output_proc else None) for v_id in range(parallel_config.pipeline_parallel_size) ] @@ -417,30 +436,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.async_callback = [ - functools.partial(self._process_model_outputs, - virtual_engine=v_id, - is_async=True) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - - self.async_callback_multi_step = [ - functools.partial(self._process_model_outputs, - virtual_engine=v_id, - is_async=False) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1249,11 +1244,7 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, - virtual_engine: int, - is_async: bool, - sampler_output: Optional[SamplerOutput] = None, - is_last_output: bool = False) -> None: + def _process_model_outputs(self, ctx: SchedulerContext) -> None: """Apply the model output to the sequences in the scheduled seq groups. 
virtual_engine: The engine id to operate on @@ -1273,24 +1264,12 @@ def _process_model_outputs(self, """ now = time.time() - is_multi_step = sampler_output is not None - - ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] - if len(ctx.output_queue) == 0: return None - if is_multi_step: - # Async + multi-step case - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue[0] - assert outputs is None - outputs = [sampler_output] - else: - # Async standard case - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() - + # Get pending async postprocessor + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step) = ctx.output_queue.popleft() assert outputs is not None # Sanity check @@ -1306,6 +1285,7 @@ def _process_model_outputs(self, outputs_by_sequence_group = outputs finished_before: List[int] = [] + finished_now: List[int] = [] for i, seq_group_meta in enumerate(seq_group_metadata_list): scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] @@ -1343,26 +1323,44 @@ def _process_model_outputs(self, if self.model_config.embedding_mode: self._process_sequence_group_outputs(seq_group, output) - continue + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, output, - is_async) + if seq_group.is_finished(): + finished_now.append(i) - # For async + multi-step, free finished seqs and create outputs - # only on the final step. - if is_multi_step and not is_last_output: - return + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create(seq_group) + ctx.request_outputs.append(request_output) - # Create the outputs. - for i, _ in enumerate(seq_group_metadata_list): - scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step, do not create outputs each iteration + if not is_last_step: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + return + + # Create the outputs + # Note: scheduled_seq_groups and seq_group_metadata_list + # must match with the indices + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): - if not is_multi_step and i in finished_before: + if i in finished_before or i in finished_now: continue # Avoids double processing seq_group = scheduled_seq_group.seq_group @@ -1376,11 +1374,15 @@ def _process_model_outputs(self, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) - # For async + multi-step, do stats only on the last output. 
- # Otherwise, do stats if the execution is async - do_stats = is_multi_step or is_async + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) - if do_stats: + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1485,40 +1487,26 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: scheduler_outputs = cached_outputs.scheduler_outputs allow_async_output_proc = cached_outputs.allow_async_output_proc - # Detect async + multi-step - use_async_and_multi_step = (self.scheduler_config.is_multi_step - and allow_async_output_proc) - ctx = self.scheduler_contexts[virtual_engine] + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - - # Clear outputs on scheduler iteration start - ctx.request_outputs.clear() - # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # Detect async + multi-step - use_async_and_multi_step = (self.scheduler_config.is_multi_step - and allow_async_output_proc) + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) - - # For async + multi-step, init the queue - if use_async_and_multi_step: - assert len(ctx.output_queue) == 0 - assert seq_group_metadata_list is not None - ctx.output_queue.append( - (None, seq_group_metadata_list, scheduler_outputs)) + self._process_model_outputs(ctx=ctx) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -1555,13 +1543,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - async_callback = self.async_callback_multi_step[ - virtual_engine] if use_async_and_multi_step \ - else self.async_callback[virtual_engine] - - execute_model_req.async_callback = async_callback - execute_model_req.use_async_and_multi_step = \ - use_async_and_multi_step + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1573,10 +1556,8 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. 
- if not use_async_and_multi_step and len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) # No outputs in this case output = [] @@ -1590,28 +1571,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - if use_async_and_multi_step: - # For async + multi-step, clear the queue - ctx.output_queue.clear() - else: - # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + # Add results to the output_queue + is_async = allow_async_output_proc + is_last_step = True + ctx.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step)) - if output and allow_async_output_proc: - assert len(output) == 1, ( - "Multi step decoding does not work " - "with async output processing.") + if output and allow_async_output_proc: + assert len(output) == 1, ( + "Async postprocessor expects only a single output set") - self._advance_to_next_step( - output[0], seq_group_metadata_list, - scheduler_outputs.scheduled_seq_groups) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=False) + self._process_model_outputs(ctx=ctx) # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -1620,17 +1597,12 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.do_tracing(scheduler_outputs) else: # Multi-step case - if use_async_and_multi_step: - return [] - else: - ctx.request_outputs = [] + return ctx.request_outputs if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) assert len(ctx.output_queue) == 0 # Stop the execute model loop in parallel workers until there are diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index e182cee8ba18e..c73db765fc3b5 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -85,9 +85,6 @@ def process_outputs(self, no tokens need to be appended since it is already done externally (before the next schedule() call) """ - # TODO: Add support for async if necessary - assert not is_async - # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED # if a client disconnects from the api server. @@ -101,19 +98,41 @@ def process_outputs(self, "Beam search not supported in multi-step decoding.") seq = seqs[0] - # Since there's only one sequence per sequence group, we can take the - # first sample. - samples = [output.samples[0] for output in outputs] - - # -1 means the output token is not valid (eg. due to spec decode - # rejecting tokens). 
- valid_samples = [ - sample for sample in samples if sample.output_token != -1 - ] - assert valid_samples - - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + if is_async: + # Async case: We process tokens one by one. Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in outputs] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], @@ -151,16 +170,7 @@ def _process_seq_outputs(self, seq: Sequence, logprobs=output_logprob, ) - new_char_count = 0 - if sampling_params.detokenize: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) + self._process_decode_and_stop(seq, sampling_params) - # TODO(sang): Support lora. - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count=new_char_count, - sampling_params=sampling_params, - ) if seq.is_finished(): break diff --git a/vllm/sequence.py b/vllm/sequence.py index 87b3d21fa7ae3..a5ebf152ce776 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1225,7 +1225,6 @@ class ExecuteModelRequest( last_sampled_token_ids: Optional[torch.Tensor] = None # Async callback async_callback: Optional[Callable] = None - use_async_and_multi_step: bool = False @property def is_first_multi_step(self) -> bool: @@ -1272,5 +1271,4 @@ def clone( finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() if self.last_sampled_token_ids is not None else None, - async_callback=self.async_callback, - use_async_and_multi_step=self.use_async_and_multi_step) + async_callback=self.async_callback) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8a3c99a45b149..74f7d4e0860d3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,6 +21,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -96,7 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None - use_async_and_multi_step: bool = False + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: 
tensor_dict = { diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index be0c75bc00dbd..b52f2a07e344e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -22,6 +22,7 @@ get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import PyObjectCache from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -37,6 +38,29 @@ logger = init_logger(__name__) +def seq_output_builder(): + return SequenceOutput( + 0, 0, + {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) + + +def completion_seq_group_output_builder(): + return CompletionSequenceGroupOutput([], None) + + +# Used by pythonization to reduce python object allocations +class PythonizationCache: + + def __init__(self): + self.cached_seq_output = PyObjectCache(seq_output_builder) + self.cached_completion_seq_group_output = PyObjectCache( + completion_seq_group_output_builder) + + def reset(self): + self.cached_seq_output.reset() + self.cached_completion_seq_group_output.reset() + + @dataclass class ModelOutput: """The output of a single model forward pass. @@ -59,6 +83,7 @@ class ModelOutput: pythonized: bool = False # On-device tensor containing the logprobs of each token. logprobs: Optional["torch.Tensor"] = None + pythonization_cache: Optional[PythonizationCache] = None def pythonize(self, input_metadata: "StatefulModelInput", copy_stream: torch.cuda.Stream, @@ -97,7 +122,8 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", with torch.cuda.stream(copy_stream): _pythonize_sampler_output(input_metadata, self.sampler_output, pinned_sampled_token_buffer, - self.sampled_token_ids, self.logprobs) + self.sampled_token_ids, self.logprobs, + self.pythonization_cache) # Erase the logprobs GPU-side tensor. # Note that although _pythonize_sampler_output() runs in its @@ -209,6 +235,8 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): self._copy_stream = torch.cuda.Stream() self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + self.pythonization_cache = PythonizationCache() + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: model_input = (StatefulModelInput.from_broadcasted_tensor_dict( @@ -237,14 +265,22 @@ def _async_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Callable): # Proceed with pythonization and output_proc in order. 
# Stop on the first one that fails to pythonize + output_proc_callback() + cont = True for model_output in model_input.cached_outputs: if not model_output.pythonized: model_output.maybe_pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) if model_output.pythonized: - output_proc_callback( - sampler_output=model_output.sampler_output) + ctx = output_proc_callback.keywords["ctx"] + is_async = False + is_last_step = False + ctx.output_queue.append( + ([model_output.sampler_output + ], ctx.seq_group_metadata_list, + ctx.scheduler_outputs, is_async, is_last_step)) + output_proc_callback() else: cont = False @@ -255,21 +291,46 @@ def _final_process_outputs(self, model_input: StatefulModelInput, output_proc_callback: Optional[Callable]): assert model_input.frozen_model_input is not None + has_async_callback = output_proc_callback is not None + outputs = [] for output_id in range(len(model_input.cached_outputs)): - is_last_output = output_id == len(model_input.cached_outputs) - 1 - output = model_input.cached_outputs[output_id] - if not output.pythonized: + is_last_step = output_id == len(model_input.cached_outputs) - 1 + + # For non-async case: + # -- We simply add the outputs + # For async case: + # -- Invoke callback, pythonize, add to callback queue and repeat + # -- For last output, just add to callback queue + if has_async_callback: + assert output_proc_callback is not None + + # Invoke callback before pythonize (to overlap with GPU) + output_proc_callback() + + # Pythonize + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # For non last step, add to callback queue to chain + # callbacks=>pythonize pairs (for GPU overlap) + if not is_last_step: + ctx = output_proc_callback.keywords[ # type: ignore + "ctx"] # type: ignore + is_async = False + is_last_step = False + ctx.output_queue.append( + ([output.sampler_output + ], ctx.seq_group_metadata_list, + ctx.scheduler_outputs, is_async, is_last_step)) + else: + outputs.append(output.sampler_output) + else: output.pythonize(model_input, self._copy_stream, self.pinned_sampled_token_ids) - - if model_input.frozen_model_input.use_async_and_multi_step: - assert output_proc_callback is not None - output_proc_callback(sampler_output=output.sampler_output, - is_last_output=is_last_output) - - outputs.append(output.sampler_output) + outputs.append(output.sampler_output) return outputs @@ -330,7 +391,7 @@ def execute_model( model_input, model_input.cached_outputs[-1].sampler_output) output_proc_callback = None - if frozen_model_input.use_async_and_multi_step: + if frozen_model_input.async_callback is not None: output_proc_callback = frozen_model_input.async_callback assert output_proc_callback is not None async_callback = functools.partial( @@ -367,7 +428,7 @@ def execute_model( model_input.cached_outputs.append( ModelOutput(output[0], output_ready_event, output[0].sampled_token_ids, False, - output[0].logprobs)) + output[0].logprobs, self.pythonization_cache)) # These GPU tensors are not required by multi-step; # erase them to ensure they are not pythonized or @@ -378,7 +439,7 @@ def execute_model( # Pythonize the output if CPU is ahead and the previous step is # ready. 
- if not frozen_model_input.use_async_and_multi_step: + if frozen_model_input.async_callback is None: for model_output in model_input.cached_outputs: model_output.maybe_pythonize(model_input, self._copy_stream, @@ -397,6 +458,7 @@ def execute_model( if model_input.is_last_step: outputs = self._final_process_outputs(model_input, output_proc_callback) + self.pythonization_cache.reset() return outputs # should be [SamplerOutput] @@ -537,6 +599,7 @@ def _pythonize_sampler_output( pinned_sampled_token_buffer: torch.Tensor, sampled_token_ids: torch.Tensor, logprobs_tensor: Optional[torch.Tensor], + cache: Optional[PythonizationCache], ) -> None: """ This function is only called when the output tensors are ready. See :class:`ModelOutput`. @@ -597,6 +660,9 @@ def _pythonize_sampler_output( for sgdx, (seq_group, sample_result) in enumerate(zip(seq_groups, samples_list)): + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") if do_pythonize_logprobs: assert prompt_logprobs is not None @@ -621,23 +687,56 @@ def _pythonize_sampler_output( seq_ids = seq_group.seq_ids next_token_ids = sample_result parent_ids = [0] - seq_outputs: List[SequenceOutput] = [] - if seq_group.sampling_params.logits_processors: - assert len(seq_group.sampling_params.logits_processors) == 0, ( - "Logits Processors are not supported in multi-step decoding") + + if cache is not None: + completion_seq_group_output: CompletionSequenceGroupOutput = \ + cache.cached_completion_seq_group_output.get_object() + completion_seq_group_output.samples.clear() + seq_outputs: List[ + SequenceOutput] = completion_seq_group_output.samples + else: + seq_outputs = [] + for tdx, (parent_id, next_token_id) in enumerate(zip(parent_ids, next_token_ids)): - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - (group_sample_logprobs[tdx] - if logprobs_are_requested else { - next_token_id: - Logprob(logprob=float('inf'), - rank=None, - decoded_token=None) - }))) - output.outputs.append( - CompletionSequenceGroupOutput( - seq_outputs, - (group_prompt_logprobs if logprobs_are_requested else None))) + if cache is not None: + seq_output: SequenceOutput = cache.cached_seq_output.get_object( + ) + seq_output.parent_seq_id = seq_ids[parent_id] + seq_output.output_token = next_token_id + + if logprobs_are_requested: + seq_output.logprobs = group_sample_logprobs[tdx] + else: + logprobs = next(iter(seq_output.logprobs.values())) + seq_output.logprobs.clear() + + logprobs.logprob = float('inf') + logprobs.rank = None + logprobs.decoded_token = None + + seq_output.logprobs[next_token_id] = logprobs + + seq_outputs.append(seq_output) + + else: + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + (group_sample_logprobs[tdx] + if logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + if cache is not None: + completion_seq_group_output.prompt_logprobs = \ + group_prompt_logprobs if logprobs_are_requested else None + output.outputs.append(completion_seq_group_output) + else: + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, (group_prompt_logprobs + if logprobs_are_requested else None))) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 517b0ab78c460..562285f828cc7 100644 --- a/vllm/worker/multi_step_worker.py +++ 
b/vllm/worker/multi_step_worker.py
@@ -67,9 +67,7 @@ def _get_driver_input_and_broadcast(
             if execute_model_req.async_callback:
                 model_input.frozen_model_input = dataclasses.replace(  # type: ignore
                     model_input.frozen_model_input,
-                    async_callback=execute_model_req.async_callback,
-                    use_async_and_multi_step=execute_model_req.
-                    use_async_and_multi_step)
+                    async_callback=execute_model_req.async_callback)
         else:
             # on subsequent steps we reuse the worker input and model input
             multi_step_state = self.multi_step_states[virtual_engine]
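
Note on the allocation optimization in this patch: besides routing async output processing through a single _process_model_outputs(ctx=...) callback and a per-virtual-engine output_queue, the change recycles SequenceOutput and CompletionSequenceGroupOutput objects across steps via the new PythonizationCache built on vllm.utils.PyObjectCache. The sketch below is a minimal, self-contained illustration of that reuse pattern only; SeqOut, SimpleObjectCache, and pythonize_step are hypothetical stand-ins and not vLLM APIs.

# Illustrative sketch only (not vLLM code): object-reuse pattern behind
# PythonizationCache. SeqOut and SimpleObjectCache are hypothetical
# stand-ins for SequenceOutput and vllm.utils.PyObjectCache.
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class SeqOut:
    parent_seq_id: int = 0
    output_token: int = 0
    logprobs: Dict[int, float] = field(default_factory=dict)


class SimpleObjectCache:
    # Grow-only pool: get_object() hands out reusable instances;
    # reset() rewinds the pointer so the same objects are recycled
    # on the next multi-step window.
    def __init__(self, builder: Callable[[], SeqOut]):
        self._builder = builder
        self._pool: List[SeqOut] = []
        self._index = 0

    def get_object(self) -> SeqOut:
        if self._index == len(self._pool):
            self._pool.append(self._builder())
        obj = self._pool[self._index]
        self._index += 1
        return obj

    def reset(self) -> None:
        self._index = 0


cache = SimpleObjectCache(SeqOut)


def pythonize_step(sampled_token_ids: List[int]) -> List[SeqOut]:
    # Reuse cached objects instead of allocating a new SeqOut per token.
    outs = []
    for token_id in sampled_token_ids:
        seq_out = cache.get_object()
        seq_out.output_token = token_id
        seq_out.logprobs.clear()
        seq_out.logprobs[token_id] = float("inf")
        outs.append(seq_out)
    return outs


# One multi-step window: pythonize each step's tokens, then recycle the
# pool, mirroring pythonization_cache.reset() after the last step.
for step_tokens in ([11, 12], [13, 14]):
    print(pythonize_step(step_tokens))
cache.reset()

Reuse of this kind is only safe because a multi-step window's outputs are consumed before the pool is rewound, which is why the patch calls self.pythonization_cache.reset() in execute_model only after _final_process_outputs on the last step.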