Add Flash decoding kernel ROCm #2855

Open · wants to merge 7 commits into base: main
server/Makefile-flash-att-v2 (2 changes: 1 addition & 1 deletion)

@@ -1,5 +1,5 @@
 flash_att_v2_commit_cuda := v2.6.1
-flash_att_v2_commit_rocm := 2092111b9f975b3347c652ff7fabd431130256c4
+flash_att_v2_commit_rocm := 47bd46e0204a95762ae48712fd1a3978827c77fd
 
 build-flash-attention-v2-cuda:
 	pip install -U packaging wheel
server/text_generation_server/layers/attention/rocm.py (48 changes: 43 additions & 5 deletions)

@@ -5,6 +5,10 @@
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.layers.attention import Seqlen
 from text_generation_server.utils.log import log_master
+from text_generation_server.models.globals import (
+    ATTENTION,
+    BLOCK_SIZE,
+)
 from loguru import logger
 import vllm._custom_ops as ops
 
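The new import pulls the attention backend selector (ATTENTION) and the KV-cache block size (BLOCK_SIZE) from text_generation_server.models.globals, so the ROCm kernels below can branch on the configured backend. As a hedged sketch of how those two globals fit together (the environment-variable name and the concrete block sizes are assumptions based on TGI's documented backends, not part of this diff):

import os

# Assumed selection logic, for illustration only; not taken from this PR.
ATTENTION = os.environ.get("ATTENTION", "paged")          # e.g. "paged" or "flashdecoding"
BLOCK_SIZE = 256 if ATTENTION == "flashdecoding" else 16  # assumed per-backend KV block size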
@@ -73,11 +77,44 @@ def paged_attention(
     # limitations under the License.
     #
 
+    if ATTENTION == "flashdecoding":
+        max_q = 1
+        max_k = max_s
+        import flash_attn_2_cuda
+
+        if softcap is None:
+            softcap = 0.0
+        out = flash_attn_2_cuda.varlen_fwd(
+            query,
+            kv_cache.key,
+            kv_cache.value,
+            None,
+            seqlen.cu_seqlen_q,
+            seqlen.cu_seqlen_k,
+            None,  # pad_k
+            None,
+            block_tables,
+            None,
+            max_q,
+            max_k,
+            0.0,  # dropout
+            softmax_scale,
+            False,  # zero_tensors
+            True,  # causal
+            -1,  # Window_left
+            -1,  # Window right
+            softcap,
+            False,  # return softmax
+            None,  # generator
+        )
+        return out[0]
+
     if softcap is not None:
         raise RuntimeError("Paged attention doesn't support softcapping")
 
     # value_cache => [num_blocks, num_heads, head_size, block_size]
-    block_size = kv_cache.value.shape[3]
+    # block_size = kv_cache.value.shape[3]
+    block_size = BLOCK_SIZE
     num_seqs, num_heads, head_size = query.shape
 
     num_kv_heads = kv_cache.key.shape[1]
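The flashdecoding branch returns early, handing the paged KV cache and the per-sequence block_tables directly to flash_attn_2_cuda.varlen_fwd; the remaining paged path now takes block_size from the global BLOCK_SIZE instead of the value cache's shape, presumably because the flashdecoding cache layout no longer stores block_size in the last dimension. A minimal, framework-free sketch of what a block table encodes follows; the [num_blocks, block_size, num_kv_heads, head_size] cache layout and the helper name are assumptions for illustration, not TGI's API:

import torch

def gather_kv_for_sequence(kv_cache, block_table, seq_len):
    # kv_cache: [num_blocks, block_size, num_kv_heads, head_size] (assumed layout)
    # block_table: physical block ids owned by one sequence, in logical order
    # Returns the contiguous [seq_len, num_kv_heads, head_size] slice that a
    # flash-decoding style kernel effectively reads through the block table.
    blocks = kv_cache[block_table]                  # [n_blocks, block_size, heads, dim]
    flat = blocks.reshape(-1, *kv_cache.shape[2:])  # [n_blocks * block_size, heads, dim]
    return flat[:seq_len]

# Toy example: a 10-token sequence scattered over physical blocks 5, 0 and 3.
cache = torch.randn(8, 4, 2, 64)  # 8 blocks of 4 tokens, 2 KV heads, head size 64
table = torch.tensor([5, 0, 3])
kv = gather_kv_for_sequence(cache, table, seq_len=10)
print(kv.shape)  # torch.Size([10, 2, 64])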
@@ -247,14 +284,15 @@ def attention(
         # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
         return flash_attn_2_cuda.varlen_fwd(
             query,
-            key,
-            value,
+            # flashdecoding: pass the KV caches, paged: pass the KV.
+            kv_cache.key if ATTENTION == "flashdecoding" else key,
+            kv_cache.value if ATTENTION == "flashdecoding" else value,
             out,
             seqlen.cu_seqlen_q,
-            seqlen.cu_seqlen_q,
-            None,
+            seqlen.cu_seqlen_k,
             None,
             None,
+            block_tables if ATTENTION == "flashdecoding" else None,
             None,
             seqlen.max_q,
             seqlen.max_k,
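The prefill path in attention() changes in the same spirit: under flashdecoding it passes the KV caches and block tables instead of the freshly computed key/value tensors, and the key-side cumulative lengths switch from cu_seqlen_q to cu_seqlen_k, since previously cached tokens can make a sequence's key length exceed its query length. A small illustrative sketch of how such cumulative-length tensors are laid out (local variable names only, not TGI's API):

import torch
import torch.nn.functional as F

input_lengths = torch.tensor([5, 3, 7])   # new (query) tokens per sequence
cache_lengths = torch.tensor([0, 4, 2])   # tokens already sitting in the KV cache

# Cumulative offsets with a leading zero, as varlen attention kernels expect.
cu_seqlen_q = F.pad(input_lengths.cumsum(0), (1, 0))
cu_seqlen_k = F.pad((input_lengths + cache_lengths).cumsum(0), (1, 0))
print(cu_seqlen_q.tolist())  # [0, 5, 8, 15]
print(cu_seqlen_k.tolist())  # [0, 5, 12, 21]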
server/text_generation_server/models/flash_causal_lm.py (14 changes: 9 additions & 5 deletions)

@@ -1663,7 +1663,7 @@ def warmup(
 
             for seqlen in tuning_sequences:
                 log_master(logger.info, f"Warming up TunableOp for seqlen={seqlen}")
-                self.tunableop_warmup(seqlen)
+                self.tunableop_warmup(seqlen, max_total_tokens)
             torch.cuda.tunable.write_file(tunableop_filepath)
             if os.environ.get("PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP") != "1":
                 torch.cuda.tunable.tuning_enable(False)
@@ -1710,7 +1710,7 @@ def warmup(
         assert max_total_tokens is not None
         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
 
-    def tunableop_warmup(self, seqlen: int):
+    def tunableop_warmup(self, seqlen: int, max_bt: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
@@ -1724,11 +1724,15 @@ def tunableop_warmup(self, seqlen: int):
             [0, seqlen], device=self.device, dtype=torch.int32
         )
         max_s = seqlen
+
+        block_tables = torch.arange(
+            max_bt, dtype=torch.int32, device=self.device
+        ).repeat(seqlen)
+        block_tables = block_tables.reshape((seqlen, max_bt))
+
         seqlen = Seqlen(
             input_lengths=input_lengths,
             cache_lengths=cache_lengths_tensor,
             cu_seqlen_q=cu_seqlen_prefill,
-            max_q=1,
-            max_k=seqlen,
         )
 
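The TunableOp warmup now builds a dummy block table so the warmup forward pass exercises the same flashdecoding code path as real requests; the warmup call earlier in this file's diff passes max_total_tokens as max_bt, and each of the seqlen rows simply enumerates block ids 0..max_bt-1. For example:

import torch

seqlen, max_bt = 3, 4  # toy sizes
block_tables = torch.arange(max_bt, dtype=torch.int32).repeat(seqlen)
block_tables = block_tables.reshape((seqlen, max_bt))
print(block_tables)
# tensor([[0, 1, 2, 3],
#         [0, 1, 2, 3],
#         [0, 1, 2, 3]], dtype=torch.int32)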
@@ -1738,7 +1742,7 @@ def tunableop_warmup(self, seqlen: int):
             position_ids=position_ids,
             cu_seqlen_prefill=cu_seqlen_prefill,
             kv_cache=self.kv_cache,
-            block_tables=None,
+            block_tables=block_tables,
             seqlen=seqlen,
             slots=slots,
             max_s=max_s,