From f092153fbe349a9a1742940e3703bfcff6aa0a6d Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 11 Dec 2024 23:14:20 -0800
Subject: [PATCH] [V1] Use more persistent buffers to optimize input
 preparation overheads (#11111)

Signed-off-by: Woosuk Kwon
---
 vllm/v1/worker/gpu_input_batch.py  |  19 +++--
 vllm/v1/worker/gpu_model_runner.py | 119 ++++++++++++++++-------
 2 files changed, 79 insertions(+), 59 deletions(-)

diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 25d95ac6e26af..9046b37f60005 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -53,14 +53,23 @@ def __init__(
         self.req_ids: List[Optional[str]] = [None] * max_num_reqs
         self.req_id_to_index: Dict[str, int] = {}
 
-        self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
-                                      dtype=np.int32)
+        # TODO(woosuk): This buffer could be too large if max_model_len is big.
+        # Find a way to reduce the CPU memory usage.
+        self.token_ids_cpu_tensor = torch.zeros(
+            (max_num_reqs, max_model_len),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=pin_memory,
+        )
+        self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
         self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
 
         # Attention-related.
-        self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
-                                       device=self.device,
-                                       dtype=torch.int32)
+        self.block_table = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device=self.device,
+            dtype=torch.int32,
+        )
         self.block_table_cpu_tensor = torch.zeros(
             (max_num_reqs, max_num_blocks_per_req),
             device="cpu",
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e75be21ef2d91..aa91255e68d48 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -67,6 +67,7 @@ def __init__(
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        self.max_num_reqs = scheduler_config.max_num_seqs
 
         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -88,7 +89,7 @@ def __init__(
 
         self.requests: Dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
-            max_num_reqs=self.scheduler_config.max_num_seqs,
+            max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
@@ -117,6 +118,32 @@ def __init__(
             dtype=self.dtype,
             device=self.device)
 
+        self.input_ids_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int32,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.input_ids_np = self.input_ids_cpu.numpy()
+        self.positions_cpu = torch.zeros(self.max_num_tokens,
+                                         dtype=torch.int64,
+                                         device="cpu",
+                                         pin_memory=self.pin_memory)
+        self.positions_np = self.positions_cpu.numpy()
+        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
+                                            dtype=torch.int32,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device="cpu",
+                                               pin_memory=self.pin_memory)
+        self.query_start_loc_np = self.query_start_loc_cpu.numpy()
+        self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
+                                             dtype=torch.int32,
+                                             device="cpu",
+                                             pin_memory=self.pin_memory)
+        self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Remove stopped requests from the cached states.
         # Keep the states of the pre-empted requests.
@@ -241,22 +268,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
 
         # Get request indices.
         # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
-        indices = np.arange(num_reqs)
-        req_indices = np.repeat(indices, num_scheduled_tokens)
+        req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)
 
         # Get batched arange.
         # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
-        arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
-                                (num_reqs, 1))
-        mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
-        arange = arange_matrix[mask]
+        arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
 
         # Get positions.
-        positions = torch.empty((total_num_scheduled_tokens, ),
-                                dtype=torch.int32,
-                                device="cpu",
-                                pin_memory=self.pin_memory)
-        positions_np = positions.numpy()
+        positions_np = self.positions_np[:total_num_scheduled_tokens]
         np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
                arange,
                out=positions_np)
@@ -267,16 +286,13 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # where M is the max_model_len.
         token_indices = (positions_np +
                          req_indices * self.input_batch.token_ids_cpu.shape[1])
-        token_indices = torch.from_numpy(token_indices)
-        input_ids = torch.empty((total_num_scheduled_tokens, ),
-                                dtype=torch.int32,
-                                device="cpu",
-                                pin_memory=self.pin_memory)
-        torch.index_select(torch.from_numpy(
-            self.input_batch.token_ids_cpu).flatten(),
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
                            0,
-                           token_indices,
-                           out=input_ids)
+                           torch.from_numpy(token_indices),
+                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
 
         # Calculate the slot mapping.
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
@@ -284,45 +300,40 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
         # where K is the max_num_blocks_per_req and the block size is 2.
         # NOTE(woosuk): We can't simply use `token_indices // block_size` here
         # because M (max_model_len) is not necessarily divisible by block_size.
-        block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
-            req_indices * self.max_num_blocks_per_req +
-            positions_np // self.block_size]
-        block_offsets = torch.from_numpy(positions_np % self.block_size)
-        slot_mapping = torch.empty((total_num_scheduled_tokens, ),
-                                   dtype=torch.int32,
-                                   device="cpu",
-                                   pin_memory=self.pin_memory)
-        torch.add(block_numbers * self.block_size,
-                  block_offsets,
-                  out=slot_mapping)
+        block_table_indices = (req_indices * self.max_num_blocks_per_req +
+                               positions_np // self.block_size)
+        # NOTE(woosuk): We use torch.index_select instead of np.take here
+        # because torch.index_select is much faster than np.take for large
+        # tensors.
+        block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
+                         [block_table_indices].numpy())
+        block_offsets = positions_np % self.block_size
+        np.add(block_numbers * self.block_size,
+               block_offsets,
+               out=self.slot_mapping_np[:total_num_scheduled_tokens])
 
         # Prepare the attention metadata.
-        query_start_loc = torch.empty((num_reqs + 1, ),
-                                      dtype=torch.int32,
-                                      device="cpu",
-                                      pin_memory=self.pin_memory)
-        query_start_loc_np = query_start_loc.numpy()
-        query_start_loc_np[0] = 0
-        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
+        self.query_start_loc_np[0] = 0
+        np.cumsum(num_scheduled_tokens,
+                  out=self.query_start_loc_np[1:num_reqs + 1])
 
         seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
                     num_scheduled_tokens)
         max_seq_len = seq_lens.max()
-        seq_start_loc = torch.empty((num_reqs + 1, ),
-                                    dtype=torch.int32,
-                                    device="cpu",
-                                    pin_memory=self.pin_memory)
-        seq_start_loc_np = seq_start_loc.numpy()
-        seq_start_loc_np[0] = 0
-        np.cumsum(seq_lens, out=seq_start_loc_np[1:])
-
-        self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
-                                                          non_blocking=True)
-        self.positions[:total_num_scheduled_tokens].copy_(positions,
-                                                          non_blocking=True)
-        query_start_loc = query_start_loc.to(self.device, non_blocking=True)
-        seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
-        slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
+        self.seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])
+
+        # Copy the tensors to the GPU.
+        self.input_ids[:total_num_scheduled_tokens].copy_(
+            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self.positions[:total_num_scheduled_tokens].copy_(
+            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
+            self.device, non_blocking=True).long()
 
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=total_num_scheduled_tokens,
             max_query_len=max_num_scheduled_tokens,
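
The pattern the patch adopts is to allocate pinned CPU staging buffers once, fill only the used prefix through NumPy views each step, and then issue non-blocking copies of that prefix to the GPU, instead of allocating fresh CPU tensors on every call. Below is a minimal standalone sketch of that idea; the buffer name, capacity, and the `prepare_step` helper are illustrative assumptions for this sketch, not vLLM's actual fields or API.

```python
import numpy as np
import torch

# One-time allocation: a pinned CPU buffer plus a NumPy view sharing its memory.
# Pinned (page-locked) memory lets copy_(..., non_blocking=True) overlap with compute.
MAX_TOKENS = 8192  # illustrative capacity, not vLLM's max_num_batched_tokens
use_cuda = torch.cuda.is_available()
positions_cpu = torch.zeros(MAX_TOKENS, dtype=torch.int64, pin_memory=use_cuda)
positions_np = positions_cpu.numpy()  # writes through to the same memory
positions_gpu = torch.zeros(MAX_TOKENS, dtype=torch.int64,
                            device="cuda" if use_cuda else "cpu")


def prepare_step(num_computed: np.ndarray, num_scheduled: np.ndarray) -> torch.Tensor:
    """Fill the persistent CPU buffer with NumPy ops, then copy only the used slice."""
    total = int(num_scheduled.sum())
    # E.g., num_scheduled [2, 5, 3] -> arange [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
    arange = np.concatenate([np.arange(n) for n in num_scheduled])
    # E.g., [2, 5, 3] -> req_indices [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
    req_indices = np.repeat(np.arange(len(num_scheduled)), num_scheduled)
    # Write directly into the persistent buffer; no per-step CPU allocation.
    np.add(num_computed[req_indices], arange, out=positions_np[:total])
    # Async host-to-device copy of just the filled prefix.
    positions_gpu[:total].copy_(positions_cpu[:total], non_blocking=True)
    return positions_gpu[:total]


print(prepare_step(np.array([10, 0, 7]), np.array([2, 5, 3])))
```

The same structure is what the diff applies to `input_ids`, `positions`, `slot_mapping`, `query_start_loc`, and `seq_start_loc`: the `*_cpu` tensor and `*_np` view are created once in `__init__`, and `_prepare_inputs` only slices and copies them.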