[V1] Use more persistent buffers to optimize input preparation overheads (#11111)

Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Dec 12, 2024
1 parent 1da8f0e commit f092153
Showing 2 changed files with 79 additions and 59 deletions.
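
The change replaces per-step tensor allocations in the input-preparation path with buffers that are created once at startup, pinned in CPU memory, and refilled in place on every scheduling step; only the populated prefix is then copied to the GPU. A minimal sketch of that pattern, with illustrative names and sizes rather than the exact vLLM code:

    import torch

    # Allocated once, e.g. in __init__ (hypothetical size).
    MAX_NUM_TOKENS = 8192
    input_ids_cpu = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int32,
                                device="cpu", pin_memory=True)  # page-locked
    input_ids_gpu = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int32,
                                device="cuda")

    # Every scheduling step: fill only the needed prefix in place, then issue
    # an asynchronous host-to-device copy (non_blocking is only effective
    # because the CPU source is pinned).
    num_tokens = 3
    input_ids_cpu[:num_tokens] = torch.tensor([11, 42, 7], dtype=torch.int32)
    input_ids_gpu[:num_tokens].copy_(input_ids_cpu[:num_tokens],
                                     non_blocking=True)
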
19 changes: 14 additions & 5 deletions vllm/v1/worker/gpu_input_batch.py
@@ -53,14 +53,23 @@ def __init__(
self.req_ids: List[Optional[str]] = [None] * max_num_reqs
self.req_id_to_index: Dict[str, int] = {}

self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
dtype=np.int32)
# TODO(woosuk): This buffer could be too large if max_model_len is big.
# Find a way to reduce the CPU memory usage.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

# Attention-related.
self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32)
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
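
The new token_ids bookkeeping above keeps one pinned torch tensor (token_ids_cpu_tensor) and exposes it to numpy through .numpy(), which returns a zero-copy view over the same CPU storage: host-side writes go through the cheap numpy interface, while the torch tensor remains available for torch.index_select and device copies. A small illustration of that sharing (pin_memory omitted so the snippet runs without a GPU):

    import torch

    buf = torch.zeros((4, 8), dtype=torch.int32, device="cpu")
    buf_np = buf.numpy()        # zero-copy view, not a copy

    buf_np[0, :3] = [1, 2, 3]   # a write through the numpy view...
    print(buf[0, :5])           # ...is visible in the torch tensor:
                                # tensor([1, 2, 3, 0, 0], dtype=torch.int32)
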
119 changes: 65 additions & 54 deletions vllm/v1/worker/gpu_model_runner.py
@@ -67,6 +67,7 @@ def __init__(
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs

# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
@@ -88,7 +89,7 @@ def __init__(
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.scheduler_config.max_num_seqs,
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
@@ -117,6 +118,32 @@ def __init__(
dtype=self.dtype,
device=self.device)

self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.input_ids_np = self.input_ids_cpu.numpy()
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
@@ -241,22 +268,14 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
indices = np.arange(num_reqs)
req_indices = np.repeat(indices, num_scheduled_tokens)
req_indices = np.repeat(np.arange(num_reqs), num_scheduled_tokens)

# Get batched arange.
# E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
(num_reqs, 1))
mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
arange = arange_matrix[mask]
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])

# Get positions.
positions = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
positions_np = positions.numpy()
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
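
The rewritten index math builds req_indices with np.repeat and the per-request arange with np.concatenate, then adds the already-computed token counts directly into a slice of the persistent positions buffer via out=. Reproducing the comments' example for num_scheduled_tokens = [2, 5, 3] (the computed-token counts are invented for illustration):

    import numpy as np

    num_scheduled_tokens = np.array([2, 5, 3])
    num_computed_tokens = np.array([10, 0, 7])  # invented prefix lengths

    req_indices = np.repeat(np.arange(3), num_scheduled_tokens)
    # -> [0 0 1 1 1 1 1 2 2 2]
    arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
    # -> [0 1 0 1 2 3 4 0 1 2]
    positions = num_computed_tokens[req_indices] + arange
    # -> [10 11 0 1 2 3 4 7 8 9]
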
@@ -267,62 +286,54 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
token_indices = torch.from_numpy(token_indices)
input_ids = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.index_select(torch.from_numpy(
self.input_batch.token_ids_cpu).flatten(),
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
token_indices,
out=input_ids)
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
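
Because token_ids_cpu is laid out as a (max_num_reqs, max_model_len) matrix, a (request, position) pair maps to the flat index position + request * M, so gathering all scheduled token IDs becomes a single torch.index_select over the flattened pinned buffer. A toy version of that lookup (M is kept small for readability):

    import numpy as np
    import torch

    M = 6                                    # stands in for max_model_len
    token_ids = torch.arange(3 * M, dtype=torch.int32).reshape(3, M)

    req_indices = np.array([0, 0, 1, 2, 2])
    positions = np.array([4, 5, 0, 2, 3])
    flat_idx = torch.from_numpy(positions + req_indices * M)

    out_buf = torch.zeros(16, dtype=torch.int32)   # persistent output buffer
    torch.index_select(token_ids.flatten(), 0, flat_idx, out=out_buf[:5])
    # out_buf[:5] now equals token_ids[req_indices, positions]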

# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size]
block_offsets = torch.from_numpy(positions_np % self.block_size)
slot_mapping = torch.empty((total_num_scheduled_tokens, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
torch.add(block_numbers * self.block_size,
block_offsets,
out=slot_mapping)
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
[block_table_indices].numpy())
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
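
The slot mapping converts each (request, position) pair into a physical KV-cache slot: the request's block-table row gives the block number holding position // block_size, and the slot is block_number * block_size + position % block_size. A worked toy example (the block-table contents are invented):

    import numpy as np

    block_size = 2
    K = 4                                   # stands in for max_num_blocks_per_req
    block_table = np.array([[7, 3, 0, 0],   # physical blocks of request 0
                            [5, 9, 2, 0],   # request 1
                            [1, 8, 0, 0]])  # request 2

    req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
    positions = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])

    block_table_indices = req_indices * K + positions // block_size
    block_numbers = block_table.flatten()[block_table_indices]
    slot_mapping = block_numbers * block_size + positions % block_size
    # e.g. request 1, position 4 -> third entry of row 1 (block 2)
    #      -> slot 2 * 2 + 0 = 4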

# Prepare the attention metadata.
query_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
query_start_loc_np = query_start_loc.numpy()
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
self.query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens,
out=self.query_start_loc_np[1:num_reqs + 1])

seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
max_seq_len = seq_lens.max()
seq_start_loc = torch.empty((num_reqs + 1, ),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
seq_start_loc_np = seq_start_loc.numpy()
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])

self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
self.seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1])
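
query_start_loc and seq_start_loc are the usual start-offset arrays for a ragged batch: element 0 stays 0, and the remaining entries are a cumulative sum written straight into a slice of the persistent buffer, so no temporary tensor is created. A short sketch with invented sizes:

    import numpy as np

    max_num_reqs = 8
    query_start_loc = np.zeros(max_num_reqs + 1, dtype=np.int32)  # persistent

    num_scheduled_tokens = np.array([2, 5, 3], dtype=np.int32)
    num_reqs = len(num_scheduled_tokens)

    query_start_loc[0] = 0
    np.cumsum(num_scheduled_tokens, out=query_start_loc[1:num_reqs + 1])
    # query_start_loc[:num_reqs + 1] -> [0, 2, 7, 10]; tokens of request i
    # occupy the flat range [query_start_loc[i], query_start_loc[i + 1])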

# Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_start_loc = self.seq_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
