Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Add BlockTable class #11693

Merged
merged 6 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions vllm/v1/worker/block_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import List

import numpy as np
import torch

from vllm.logger import init_logger

logger = init_logger(__name__)


class BlockTable:

def __init__(
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
pin_memory: bool,
device: torch.device,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.pin_memory = pin_memory
self.device = device

self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.block_table_np = self.block_table_cpu.numpy()
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

def append_row(
self,
row_idx: int,
start: int,
block_ids: List[int],
) -> None:
num_blocks = len(block_ids)
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
self.num_blocks_per_row[row_idx] = start + num_blocks

def add_row(self, row_idx: int, block_ids: List[int]) -> None:
self.append_row(row_idx, 0, block_ids)

def move_row(self, src: int, tgt: int) -> None:
num_blocks = self.num_blocks_per_row[src]
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
src, :num_blocks]
self.num_blocks_per_row[tgt] = num_blocks

def commit(self, num_reqs: int) -> None:
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
non_blocking=True)

def clear(self) -> None:
self.block_table.fill_(0)
self.block_table_cpu.fill_(0)

def get_device_tensor(self) -> torch.Tensor:
"""Ruturns the device tensor of the block table."""
return self.block_table

def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table_cpu

def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
25 changes: 9 additions & 16 deletions vllm/v1/worker/gpu_input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange
Expand Down Expand Up @@ -70,19 +71,14 @@ def __init__(
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)

# Attention-related.
self.block_table = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device=self.device,
dtype=torch.int32,
)
self.block_table_cpu_tensor = torch.zeros(
(max_num_reqs, max_num_blocks_per_req),
device="cpu",
dtype=torch.int32,
# Block table.
self.block_table = BlockTable(
max_num_reqs=max_num_reqs,
max_model_len=max_model_len,
max_num_blocks_per_req=max_num_blocks_per_req,
pin_memory=pin_memory,
device=device,
)
self.block_table_cpu = self.block_table_cpu_tensor.numpy()

# Sampling-related.
self.temperature = torch.empty((max_num_reqs, ),
Expand Down Expand Up @@ -193,8 +189,7 @@ def add_request(
self.num_tokens[req_index] = request.num_tokens

self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
num_blocks = len(request.block_ids)
self.block_table_cpu[req_index, :num_blocks] = request.block_ids
self.block_table.add_row(req_index, request.block_ids)

sampling_params = request.sampling_params
self.temperature_cpu[req_index] = sampling_params.temperature
Expand Down Expand Up @@ -300,9 +295,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
self.num_prompt_tokens[last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
# TODO(woosuk): Optimize the copy of block_table_cpu.
self.block_table_cpu[empty_index] = self.block_table_cpu[
last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
Expand Down
16 changes: 7 additions & 9 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
if num_new_blocks == 0:
continue
start_index = len(req_state.block_ids)
end_index = start_index + num_new_blocks
req_state.block_ids.extend(req_data.new_block_ids)
self.input_batch.block_table_cpu[
req_index, start_index:end_index] = req_data.new_block_ids
self.input_batch.block_table.append_row(req_index, start_index,
req_data.new_block_ids)

req_ids_to_add: List[str] = []
# Add new requests to the cached states.
Expand Down Expand Up @@ -275,9 +274,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table[:num_reqs].copy_(
self.input_batch.block_table_cpu_tensor[:num_reqs],
non_blocking=True)
self.input_batch.block_table.commit(num_reqs)

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
Expand Down Expand Up @@ -333,8 +330,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_numbers = (self.input_batch.block_table_cpu_tensor.flatten()
[block_table_indices].numpy())
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
Expand Down Expand Up @@ -450,7 +447,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
Expand Down
Loading