diff --git a/server/Makefile-flash-att-v2 b/server/Makefile-flash-att-v2
index a9cdf782270..9a946d97f8b 100644
--- a/server/Makefile-flash-att-v2
+++ b/server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 0cfac25bd93..69a245ad5aa 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -5,6 +5,10 @@
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 
 import vllm._custom_ops as ops
@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #
 
+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
 
     num_seqs, num_heads, head_size = query.shape
     num_kv_heads = kv_cache.key.shape[1]
@@ -247,14 +284,15 @@ def attention(
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return flash_attn_2_cuda.varlen_fwd(
         query,
-        key,
-        value,
+        # flashdecoding: pass the KV caches, paged: pass the KV.
+        kv_cache.key if ATTENTION == "flashdecoding" else key,
+        kv_cache.value if ATTENTION == "flashdecoding" else value,
         out,
         seqlen.cu_seqlen_q,
-        seqlen.cu_seqlen_q,
-        None,
+        seqlen.cu_seqlen_k,
         None,
         None,
+        block_tables if ATTENTION == "flashdecoding" else None,
         None,
         seqlen.max_q,
         seqlen.max_k,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 5d37699074f..79cb299667c 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -1663,7 +1663,7 @@ def warmup(
 
             for seqlen in tuning_sequences:
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen)
+                self.tunableop_warmup(seqlen, max_total_tokens)
             torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
@@ -1710,7 +1710,7 @@ def warmup(
         assert max_total_tokens is not None
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, seqlen: int, max_bt: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
@@ -1724,11 +1724,15 @@ def tunableop_warmup(self, seqlen: int):
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
+
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(seqlen)
+        block_tables = block_tables.reshape((seqlen, max_bt))
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
-            cu_seqlen_q=cu_seqlen_prefill,
-            max_q=1,
             max_k=seqlen,
         )
 
@@ -1738,7 +1742,7 @@ def tunableop_warmup(self, seqlen: int):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
+            block_tables=block_tables,
             seqlen=seqlen,
             slots=slots,
             max_s=max_s,
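Note on the `block_size = BLOCK_SIZE` change in `paged_attention`: the old lookup `kv_cache.value.shape[3]` is only valid for the paged layout `[num_blocks, num_heads, head_size, block_size]` documented in the comment above it; with flashdecoding the value cache uses a different layout, so the block size is read from the `BLOCK_SIZE` global instead. A minimal shape sketch, assuming a flashdecoding layout of `[num_blocks, block_size, num_heads, head_size]` (an assumption based on the flash-attention varlen API; only the paged layout is stated in the diff):

    import torch

    num_blocks, num_heads, head_size, block_size = 16, 8, 64, 32

    # Paged layout, as documented in the diff's comment.
    paged_value = torch.empty(num_blocks, num_heads, head_size, block_size)
    # Assumed flashdecoding layout: block_size moves to dim 1.
    flashdecoding_value = torch.empty(num_blocks, block_size, num_heads, head_size)

    assert paged_value.shape[3] == block_size          # old lookup was valid here
    assert flashdecoding_value.shape[3] != block_size  # ...but would break here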
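The `attention` hunk also corrects the varlen call to pass `seqlen.cu_seqlen_k` where `seqlen.cu_seqlen_q` was previously passed twice. Both tensors are cumulative offsets over the ragged batch, and they only coincide when no tokens are cached. An illustrative sketch of the distinction (the lengths are made-up values, and the padding trick is just one way to build the offsets):

    import torch
    import torch.nn.functional as F

    query_lengths = torch.tensor([3, 5])         # new tokens per sequence
    cache_lengths = torch.tensor([4, 0])         # tokens already in the KV cache
    key_lengths = query_lengths + cache_lengths  # keys visible to attention

    cu_seqlen_q = F.pad(query_lengths.cumsum(0), (1, 0)).to(torch.int32)
    cu_seqlen_k = F.pad(key_lengths.cumsum(0), (1, 0)).to(torch.int32)

    print(cu_seqlen_q)  # tensor([0, 3, 8], dtype=torch.int32)
    print(cu_seqlen_k)  # tensor([0, 7, 12], dtype=torch.int32)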
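In `tunableop_warmup`, `max_bt` is the `max_total_tokens` value forwarded from `warmup`, and every warmup position receives an identical block-table row spanning `max_bt` blocks, which is enough to exercise the flashdecoding path during TunableOp tuning. A minimal sketch of the resulting shape (the `seqlen` and `max_bt` values are illustrative, and the device argument is dropped so it runs anywhere):

    import torch

    seqlen, max_bt = 4, 8  # hypothetical warmup length and max blocks per sequence

    # Tile [0, 1, ..., max_bt - 1] once per position, then fold to one row per
    # position: every row indexes the same physical blocks.
    block_tables = torch.arange(max_bt, dtype=torch.int32).repeat(seqlen)
    block_tables = block_tables.reshape((seqlen, max_bt))

    assert block_tables.shape == (seqlen, max_bt)
    assert bool((block_tables[0] == torch.arange(max_bt, dtype=torch.int32)).all())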