Skip to content

Commit

Permalink
optimize score_mod order
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat committed Oct 24, 2024
1 parent 0f183a3 commit 69c7f69
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 16 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ Float16:
| Method | Forward Time (s) | Gradient Time (s) |
|------------------------|----------------------------|----------------------------|
| FlaxAttention (Pure JAX) | 0.5692746052518487 | 0.8823547409847379 |
| FlaxAttention (Pallas) | **0.13677988620474935** | **0.5575501238927245** |
| FlaxAttention (Pallas) | **0.1342552239075303** | **0.5429903408512473** |
| Jax Attention (no score_mod) | 1.6788566000759602 | 1.0905949068255723 |
| FlexAttention (Torch)| **0.11708855209872127** | **0.5104729640297592** |

Expand Down
18 changes: 3 additions & 15 deletions flaxattention/kernel/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,10 @@ def body(start_k, carry):
span_q = start_q * block_q + jnp.arange(block_q)
span_k = start_k * block_k + jnp.arange(block_k)
# boolean mask for the current qk slice
qk = score_mod(qk, start_b, start_h, span_q, span_k)
qk = jnp.where(
mask_mod(start_b, start_h, span_q, span_k), qk, DEFAULT_MASK_VALUE
)
qk = jnp.where(
qk != DEFAULT_MASK_VALUE,
score_mod(qk, start_b, start_h, span_q, span_k),
DEFAULT_MASK_VALUE,
)
# Avoids Triton crash.
# if num_heads > 2:
# qk = qk.astype(q_ref.dtype)
Expand Down Expand Up @@ -431,14 +427,10 @@ def inner_loop_dkdv(start_q, carry):
qk_pre_mod = qk
span_q = start_q * block_q1 + jnp.arange(block_q1)
# boolean mask for the current qk slice
qk = score_mod(qk, start_b, start_h, span_q, span_k)
qk = jnp.where(
mask_mod(start_b, start_h, span_q, span_k), qk, DEFAULT_MASK_VALUE
)
qk = jnp.where(
qk != DEFAULT_MASK_VALUE,
score_mod(qk, start_b, start_h, span_q, span_k),
DEFAULT_MASK_VALUE,
)
if causal or segment_ids_ref is not None:
mask = None
if segment_ids_ref is not None:
Expand Down Expand Up @@ -511,14 +503,10 @@ def inner_loop_dq(start_k, dq):

span_k = start_k * block_k2 + jnp.arange(block_k2)
# boolean mask for the current qk slice
qk = score_mod(qk, start_b, start_h, span_q, span_k)
qk = jnp.where(
mask_mod(start_b, start_h, span_q, span_k), qk, DEFAULT_MASK_VALUE
)
qk = jnp.where(
qk != DEFAULT_MASK_VALUE,
score_mod(qk, start_b, start_h, span_q, span_k),
DEFAULT_MASK_VALUE,
)

if causal or segment_ids_ref is not None:
mask = None
Expand Down

0 comments on commit 69c7f69

Please sign in to comment.