There is a performance gap between AOT and JIT Triton for flash attention across most seqlen, n_heads, and head_dim settings
Suggestion Description
Thanks for this great work,
There is a performance gap between AOT Triton and JIT Triton for flash attention across most seqlen, n_heads, and head_dim combinations.
We tried tuning the flash attention kernel and got some performance improvement for head_dim=128; however, it is still slower than the JIT Triton kernel.
It looks like the two Triton kernel tuning spaces differ, and this is the main difference we found.
Triton kernel tuning space for AOTriton: https://github.com/ROCm/aotriton/blob/main/tritonsrc/attn_torch_function.py#L47
Triton kernel tuning space for JIT Triton: https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py#L84-L92
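To make the tuning-space comparison concrete, one option is to expand each space into its individual configurations and diff them. This minimal sketch assumes each space can be written as a mapping from parameter name (for example BLOCK_M, BLOCK_N, waves_per_eu, num_warps, num_stages) to a list of candidate values. The parameter values shown are hypothetical placeholders, not the actual contents of the two linked files, and should be replaced with the real ones.

```python
from itertools import product

# Hypothetical tuning spaces: parameter name -> candidate values.
# Placeholders only, NOT copied from attn_torch_function.py or
# 06-fused-attention.py; fill these in from the linked files.
AOT_SPACE = {
    "BLOCK_M": [128],
    "BLOCK_N": [64],
    "waves_per_eu": [2],
    "num_warps": [4],
    "num_stages": [1],
}
JIT_SPACE = {
    "BLOCK_M": [128, 256],
    "BLOCK_N": [64, 128],
    "waves_per_eu": [2, 3],
    "num_warps": [4, 8],
    "num_stages": [1],
}


def expand(space):
    """Return the set of concrete configs, each a sorted tuple of (key, value) pairs.

    Configs from spaces with different key sets never compare equal.
    """
    keys = sorted(space)
    return {
        tuple(zip(keys, values))
        for values in product(*(space[k] for k in keys))
    }


def compare(a, b):
    """Return (configs only in a, configs only in b, configs in both)."""
    ea, eb = expand(a), expand(b)
    return ea - eb, eb - ea, ea & eb


if __name__ == "__main__":
    only_aot, only_jit, shared = compare(AOT_SPACE, JIT_SPACE)
    print(f"only AOT: {len(only_aot)}, only JIT: {len(only_jit)}, shared: {len(shared)}")
    # Per-parameter view: candidate values the JIT space tries but the AOT space lacks.
    for key in sorted(set(AOT_SPACE) | set(JIT_SPACE)):
        missing = set(JIT_SPACE.get(key, [])) - set(AOT_SPACE.get(key, []))
        if missing:
            print(f"{key}: JIT-only values {sorted(missing)}")
```

The per-parameter output can help narrow down which knobs (for example a larger BLOCK_M or a different waves_per_eu) exist only in the JIT space and are worth benchmarking in the AOT build.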
Are there any other major differences that make the JIT Triton FA kernel faster than the AOT Triton FA kernel?
Operating System
Ubuntu 22
GPU
MI300
ROCm Component
rocBLAS