From 573d88c44699b746ade96634fb1796abe0f73e7f Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 1 Jul 2024 10:55:00 +0000
Subject: [PATCH] Factoring cu_seqlen_qk for better abstracting over every model.

---
 .../layers/attention/__init__.py              |  2 ++
 .../layers/attention/common.py                | 31 ++++++++++++++++++
 .../layers/attention/cuda.py                  | 10 +++---
 .../custom_modeling/flash_cohere_modeling.py  | 32 +++----------------
 .../custom_modeling/flash_dbrx_modeling.py    |  1 -
 .../custom_modeling/flash_gemma_modeling.py   |  1 -
 .../custom_modeling/flash_gpt2_modeling.py    |  1 -
 .../custom_modeling/flash_llama_modeling.py   | 32 +++----------------
 .../custom_modeling/flash_mistral_modeling.py |  1 -
 .../custom_modeling/flash_mixtral_modeling.py |  1 -
 .../custom_modeling/flash_neox_modeling.py    |  1 -
 .../custom_modeling/flash_phi_modeling.py     |  1 -
 .../custom_modeling/flash_qwen2_modeling.py   |  1 -
 .../custom_modeling/flash_rw_modeling.py      |  2 --
 .../flash_santacoder_modeling.py              |  1 -
 .../flash_starcoder2_modeling.py              |  1 -
 .../models/flash_causal_lm.py                 | 10 +++---
 17 files changed, 54 insertions(+), 75 deletions(-)
 create mode 100644 server/text_generation_server/layers/attention/common.py

diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index e74180e7a86..c8bccefec89 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -1,6 +1,8 @@
 from text_generation_server.utils.import_utils import SYSTEM
 import os
 
+from .common import Seqlen
+
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
 if SYSTEM == "cuda":
diff --git a/server/text_generation_server/layers/attention/common.py b/server/text_generation_server/layers/attention/common.py
new file mode 100644
index 00000000000..ca74bdc2649
--- /dev/null
+++ b/server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from text_generation_server.models.globals import FLASH_DECODING
+import torch
+from typing import Optional
+
+
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+
+    def __init__(self, input_lengths):
+        self.input_lengths = input_lengths
+        if FLASH_DECODING:
+            device = self.input_lengths.device
+            shape = self.input_lengths.shape
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+            cu_seqlen_k = torch.empty(shape[-1] + 1, device=device, dtype=torch.int32)
+            cu_seqlen_k[0] = 0
+            torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])
+
+            self.cu_seqlen_q = cu_seqlen_q
+            self.cu_seqlen_k = cu_seqlen_k
+        else:
+            self.cu_seqlen_q = None
+            self.cu_seqlen_k = None
diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index e0f09847a8a..94b69899ef5 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -1,6 +1,7 @@
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
+from text_generation_server.layers.attention import Seqlen
 
 major, minor = torch.cuda.get_device_capability()
 is_sm75 = major == 7 and minor == 5
@@ -40,8 +41,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    seqlen: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -66,7 +66,6 @@ def paged_attention(
     block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -88,8 +87,8 @@ def paged_attention(
             key_cache,
             value_cache,
             None,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
             None,
             block_tables,
             None,
@@ -106,6 +105,7 @@ def paged_attention(
         )
         return out2[0]
     else:
+        input_lengths = seqlen.input_lengths
         from vllm._C import ops
 
         use_v1 = max_s <= 8192 and (
diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
index 3fea834a5f0..c51cce3b473 100644
--- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py
@@ -260,8 +260,7 @@ def forward(
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         slots,
         max_s,
     ):
@@ -314,8 +313,7 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -389,8 +387,7 @@ def forward(
         cu_seqlen_prefill,
         kv_cache,
         block_tables,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         slots,
         max_s,
     ):
@@ -404,8 +401,7 @@ def forward(
             cu_seqlen_prefill,
             kv_cache,
             block_tables,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             slots,
             max_s,
         )
@@ -469,23 +465,6 @@ def forward(
         )
 
         residual = None
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=input_ids.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
 
         for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
@@ -496,8 +475,7 @@ def forward(
                 cu_seqlen_prefill,
                 kv_cache[i],
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 slots,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 74dc9cf7cbe..9d56e4efd9d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -344,7 +344,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
index a885dad6d5d..82891823756 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -253,7 +253,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
index ef238297624..7e7510c737b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -253,7 +253,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 7aba63896aa..4ca8cd0ab4e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -173,8 +173,7 @@ def forward(
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
         adapter_data,
     ):
@@ -218,8 +217,7 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
             )
 
@@ -354,8 +352,7 @@ def forward(
         kv_cache,
         block_tables,
         slots,
-        cu_seqlen_q,
-        cu_seqlen_k,
+        input_lengths,
         max_s,
         adapter_data,
     ):
@@ -370,8 +367,7 @@ def forward(
             kv_cache,
             block_tables,
             slots,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            input_lengths,
             max_s,
             adapter_data,
         )
@@ -441,23 +437,6 @@ def forward(
         cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )
-        if cu_seqlen_prefill is None and FLASH_DECODING:
-            cu_seqlen_q = torch.arange(
-                input_lengths.shape[0] + 1,
-                device=inputs_embeds.device,
-                dtype=torch.int32,
-            )
-            cu_seqlen_k = torch.cat(
-                [
-                    torch.zeros(
-                        (1,), device=input_lengths.device, dtype=input_lengths.dtype
-                    ),
-                    input_lengths.cumsum(dim=-1),
-                ]
-            ).to(dtype=torch.int32)
-        else:
-            cu_seqlen_q = None
-            cu_seqlen_k = input_lengths
 
         residual = None
         for i, layer in enumerate(self.layers):
@@ -470,8 +449,7 @@ def forward(
                 kv_cache[i],
                 block_tables,
                 slots,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                input_lengths,
                 max_s,
                 adapter_data,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 673e501b995..d1ba5564201 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -237,7 +237,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 2d7b023f3f7..2e839d154ef 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -299,7 +299,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index d4e7713d4a6..b87fd4ca00e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -176,7 +176,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
index 5872b59f6c7..3f445f97ad9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py
@@ -215,7 +215,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
index 14aee59b3c0..69f38c3ac0d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -157,7 +157,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 735d0b90bc3..04d4ba51507 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -225,7 +225,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
@@ -349,7 +348,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 3873c65347e..badfc36727d 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -309,7 +309,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
index 4450deb897f..f6a2e15d2a8 100644
--- a/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py
@@ -263,7 +263,6 @@ def forward(
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
-                None,
                 input_lengths,
                 max_s,
             )
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index fa78ee2208e..7ad1c8c50c0 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -31,10 +31,12 @@
 from text_generation_server.models.globals import (
     MEM_POOL,
     FLASH_DECODING,
+    BLOCK_SIZE,
     CUDA_GRAPHS,
     get_adapter_to_index,
     MODEL_ID,
 )
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments
@@ -47,9 +49,6 @@
 
 tracer = trace.get_tracer(__name__)
 
-BLOCK_SIZE: int = (
-    256 if os.getenv("FLASH_DECODING", "").lower() in {"1", "true"} else 16
-)
 
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
@@ -927,6 +926,7 @@ def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
             "slots": slots,
             "input_lengths": input_lengths,
         }
+        input_lengths = Seqlen(input_lengths=input_lengths)
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
@@ -1086,6 +1086,7 @@ def tunableop_warmup(self, seqlen: int):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+        seqlen = Seqlen(input_lengths=input_lengths)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1096,7 +1097,7 @@ def tunableop_warmup(self, seqlen: int):
             ),
             kv_cache=self.kv_cache,
             block_tables=None,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
@@ -1172,6 +1173,7 @@ def forward(
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
+            input_lengths = Seqlen(input_lengths=input_lengths)
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,