
Commit

Factoring cu_seqlen_qk for better abstraction over every model.
Narsil committed Jul 1, 2024
1 parent 12a500c commit 573d88c
Showing 17 changed files with 54 additions and 75 deletions.
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/attention/__init__.py
@@ -1,6 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
import os

from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
31 changes: 31 additions & 0 deletions server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional


@dataclass
class Seqlen:
    input_lengths: torch.Tensor
    cu_seqlen_q: Optional[torch.Tensor]
    cu_seqlen_k: Optional[torch.Tensor]

    def __init__(self, input_lengths):
        self.input_lengths = input_lengths
        if FLASH_DECODING:
            device = self.input_lengths.device
            shape = self.input_lengths.shape
            cu_seqlen_q = torch.arange(
                shape[0] + 1,
                device=device,
                dtype=torch.int32,
            )
            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
            cu_seqlen_k[0] = 0
            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])

            self.cu_seqlen_q = cu_seqlen_q
            self.cu_seqlen_k = cu_seqlen_k
        else:
            self.cu_seqlen_q = None
            self.cu_seqlen_k = None
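
For illustration only (not part of the commit): a minimal sketch of the values Seqlen produces for a decode batch of three sequences with lengths 3, 5 and 2, assuming FLASH_DECODING is enabled. The values in the comments follow directly from the arange/cumsum logic above.

import torch

from text_generation_server.layers.attention import Seqlen

# One query token per sequence during decode.
input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
seqlen = Seqlen(input_lengths=input_lengths)

# seqlen.cu_seqlen_q -> tensor([0, 1, 2, 3])    one query row per sequence
# seqlen.cu_seqlen_k -> tensor([0, 3, 8, 10])   cumulative key lengths
# With FLASH_DECODING disabled, both fields are simply None.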
10 changes: 5 additions & 5 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,7 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@@ -40,8 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -66,7 +66,6 @@ def paged_attention(
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k

# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -88,8 +87,8 @@ def paged_attention(
key_cache,
value_cache,
None,
cu_seqlen_q,
cu_seqlen_k,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
None,
@@ -106,6 +105,7 @@ def paged_attention(
)
return out2[0]
else:
input_lengths = seqlen.input_lengths
from vllm._C import ops

use_v1 = max_s <= 8192 and (
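
A hedged sketch of the new interface from a caller's point of view. Only the tail of the argument list (block_tables, seqlen, max_s) is taken from this diff; the leading buffers and the head mapping are assumed placeholder names, not code from the commit.

from text_generation_server.layers.attention import Seqlen, paged_attention

seqlen = Seqlen(input_lengths=input_lengths)  # input_lengths: int32 tensor of per-sequence key lengths
out = paged_attention(
    out,              # output buffer (assumed)
    query,            # (assumed)
    key_cache,        # (assumed)
    value_cache,      # (assumed)
    kv_head_mapping,
    softmax_scale,
    block_tables,
    seqlen,           # one object instead of the old cu_seqlen_q / cu_seqlen_k pair
    max_s,
)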
Another changed file (filename not shown)
@@ -260,8 +260,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
):
@@ -314,8 +313,7 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
)

@@ -389,8 +387,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
):
@@ -404,8 +401,7 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
)
@@ -469,23 +465,6 @@ def forward(
)

residual = None
if cu_seqlen_prefill is None and FLASH_DECODING:
cu_seqlen_q = torch.arange(
input_lengths.shape[0] + 1,
device=input_ids.device,
dtype=torch.int32,
)
cu_seqlen_k = torch.cat(
[
torch.zeros(
(1,), device=input_lengths.device, dtype=input_lengths.dtype
),
input_lengths.cumsum(dim=-1),
]
).to(dtype=torch.int32)
else:
cu_seqlen_q = None
cu_seqlen_k = input_lengths

for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
@@ -496,8 +475,7 @@ def forward(
cu_seqlen_prefill,
kv_cache[i],
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
slots,
max_s,
)
Another changed file (filename not shown)
@@ -344,7 +344,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -253,7 +253,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -253,7 +253,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -173,8 +173,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
):
@@ -218,8 +217,7 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
)

@@ -354,8 +352,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
):
@@ -370,8 +367,7 @@ def forward(
kv_cache,
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
)
@@ -441,23 +437,6 @@ def forward(
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, max_s, hidden_states.dtype
)
if cu_seqlen_prefill is None and FLASH_DECODING:
cu_seqlen_q = torch.arange(
input_lengths.shape[0] + 1,
device=inputs_embeds.device,
dtype=torch.int32,
)
cu_seqlen_k = torch.cat(
[
torch.zeros(
(1,), device=input_lengths.device, dtype=input_lengths.dtype
),
input_lengths.cumsum(dim=-1),
]
).to(dtype=torch.int32)
else:
cu_seqlen_q = None
cu_seqlen_k = input_lengths

residual = None
for i, layer in enumerate(self.layers):
@@ -470,8 +449,7 @@ def forward(
kv_cache[i],
block_tables,
slots,
cu_seqlen_q,
cu_seqlen_k,
input_lengths,
max_s,
adapter_data,
)
Another changed file (filename not shown)
@@ -237,7 +237,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -299,7 +299,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -176,7 +176,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -215,7 +215,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -157,7 +157,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -225,7 +225,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
@@ -349,7 +348,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -309,7 +309,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
Another changed file (filename not shown)
@@ -263,7 +263,6 @@ def forward(
self.kv_head_mapping,
self.softmax_scale,
block_tables,
None,
input_lengths,
max_s,
)
10 changes: 6 additions & 4 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -31,10 +31,12 @@
from text_generation_server.models.globals import (
MEM_POOL,
FLASH_DECODING,
BLOCK_SIZE,
CUDA_GRAPHS,
get_adapter_to_index,
MODEL_ID,
)
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -47,9 +49,6 @@

tracer = trace.get_tracer(__name__)

BLOCK_SIZE: int = (
256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
)

# Will be set in init
SLIDING_WINDOW: Optional[int] = None
@@ -927,6 +926,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
"slots": slots,
"input_lengths": input_lengths,
}
input_lengths = Seqlen(input_lengths=input_lengths)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph

@@ -1086,6 +1086,7 @@ def tunableop_warmup(self, seqlen: int):

# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
seqlen = Seqlen(input_lengths=input_lengths)

# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@@ -1096,7 +1097,7 @@ def tunableop_warmup(self, seqlen: int):
),
kv_cache=self.kv_cache,
block_tables=None,
input_lengths=input_lengths,
seqlen=seqlen,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
@@ -1172,6 +1173,7 @@ def forward(
cuda_graph = None

if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
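
Condensed sketch of the batch-side wrapping this commit introduces in flash_causal_lm.py; the keyword list is abridged and the surrounding control flow is simplified for illustration, so treat it as a paraphrase rather than the exact code.

# FlashCausalLM.forward, eager path (sketch): wrap the raw lengths once ...
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
    input_ids=input_ids,
    position_ids=position_ids,
    cu_seqlen_prefill=cu_seqlen_prefill,
    kv_cache=kv_cache,
    block_tables=block_tables,
    slots=slots,
    input_lengths=input_lengths,  # ... so every model receives a ready-made Seqlen
    max_s=max_s,
)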
