
Commit

fix decoding
zinccat committed Oct 24, 2024
1 parent c636239 commit a6860a4
Showing 4 changed files with 183 additions and 22 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -104,3 +104,12 @@ Float16:
| FlexAttention (Torch) | **0.11708855209872127** | **0.5104729640297592** |

We can see that the forward pass is about 20% slower than the original implementation, and the backward pass is about 8% slower. There are still some optimizations to be done.

Decoding (Float16), query seq_len = 1:

- FlexAttention: 0.0103s
- FlaxAttention (this repo): 0.0145s
- FlaxAttention (without Pallas Flash Attention): **0.00650s**
- JAX Pallas decoding attention (no score_mod): 0.00998s

We can see that the pure JAX implementation is actually the fastest here, surpassing Pallas Flash Attention. The kernel also supports arbitrary query lengths; the inflection point is around 64, beyond which Pallas Flash Attention starts to outperform the pure JAX implementation.
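
Given this crossover, one option is to pick the attention path from the query length at call time. Below is a minimal sketch under the assumption that `flax_attention` is the pure JAX path and `flax_attention_pallas` the Pallas kernel, both taking `(query, key, value, score_mod)` as in the tests; the `dispatch_attention` helper and the threshold constant are illustrative, not part of the library.

```python
from flaxattention import flax_attention, flax_attention_pallas

# Hypothetical helper (not part of flaxattention): route by query length using
# the ~64-token crossover observed in the benchmarks above.
QUERY_LEN_THRESHOLD = 64

def dispatch_attention(query, key, value, score_mod=None):
    # query is assumed to be (batch, heads, seq_len_q, head_dim), as in the tests
    seq_len_q = query.shape[2]
    if seq_len_q < QUERY_LEN_THRESHOLD:
        # Short (decoding-style) queries: the pure JAX implementation is faster.
        return flax_attention(query, key, value, score_mod=score_mod)
    # Longer queries: the Pallas Flash Attention kernel wins.
    return flax_attention_pallas(query, key, value, score_mod=score_mod)
```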
15 changes: 1 addition & 14 deletions examples/benchmark_torch_decoding.py
@@ -44,17 +44,4 @@ def checkerboard_torch(score, batch, head, q_idx, k_idx):
 torch.cuda.synchronize()
 end = timer()

-print("Time taken for 100 iterations: ", end - start)
-
-start = timer()
-for _ in range(100):
-    output_torch = flex_attention(
-        query_torch,
-        key_torch,
-        value_torch,
-        score_mod=checkerboard_torch,
-    ).sum().backward()
-torch.cuda.synchronize()
-end = timer()
-
-print("Time taken for 100 iterations backprop: ", end - start)
+print("Time taken for 100 iterations: ", end - start)
15 changes: 8 additions & 7 deletions flaxattention/kernel/attention.py
@@ -203,13 +203,14 @@ def mha(
     score_mod_grad: _score_mod_signature | None = None,
 ):
     del backward_pass_impl
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    block_q = min(block_q, seq_len)
-    block_k = min(block_k, seq_len)
+    batch_size, seq_len_q, num_heads, head_dim = q.shape
+    seq_len_kv = k.shape[1]
+    block_q = min(block_q, seq_len_q)
+    block_k = min(block_k, seq_len_kv)
     # Heuristics.
     grid_ = grid
     if grid_ is None:
-        grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)  # seq, batch, head
+        grid_ = (pl.cdiv(seq_len_q, block_q), batch_size, num_heads)  # seq, batch, head

     num_warps_ = num_warps
     if num_warps_ is None:
@@ -229,13 +230,13 @@

     in_specs = [
         pl.BlockSpec((None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
-        pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
+        pl.BlockSpec((None, seq_len_kv, None, head_dim), lambda _, j, k: (j, 0, k, 0)),
     ]
     in_specs.append(
         None  # type: ignore[arg-type]
         if segment_ids is None
-        else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
+        else pl.BlockSpec((None, seq_len_kv), lambda _, j, k: (j, 0))
     )
     out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
     return pl.pallas_call(
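
The essence of the kernel change above is that the query and key/value lengths are no longer assumed equal, which is what allows decoding-shaped inputs. A small sketch of the resulting block and grid arithmetic for a single-token query (the sizes are illustrative):

```python
# Sketch of the updated block/grid computation for a decoding-shaped input.
from jax.experimental import pallas as pl

batch_size, num_heads, head_dim = 4, 8, 64
seq_len_q, seq_len_kv = 1, 2048      # decoding: one query token, long KV cache
block_q, block_k = 128, 64

block_q = min(block_q, seq_len_q)    # -> 1: a single query block
block_k = min(block_k, seq_len_kv)   # -> 64: K/V are still processed in blocks

grid = (pl.cdiv(seq_len_q, block_q), batch_size, num_heads)  # seq, batch, head
print(grid)  # (1, 4, 8): one program per (batch, head), each scanning the KV blocks
```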
166 changes: 165 additions & 1 deletion tests/attention_test.py
@@ -6,7 +6,7 @@
 import numpy as np
 from absl.testing import absltest  # pylint: disable=import-error

-from flaxattention import flax_attention
+from flaxattention import flax_attention, flax_attention_pallas
 from flaxattention.mods import generate_alibi_bias


@@ -61,6 +61,170 @@ def test_equivalence_with_torch(self):

np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_pallas_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

output_jax = flax_attention_pallas(
query,
key,
value,
)

query_torch = jax2torch(query)
key_torch = jax2torch(key)
value_torch = jax2torch(value)

output_torch = (
flex_attention(
query_torch,
key_torch,
value_torch,
)
.detach()
.cpu()
.numpy()
)

np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_decoding_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 1
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

output_jax = flax_attention(
query,
key,
value,
score_mod=generate_alibi_bias(num_heads),
)

query_torch = jax2torch(query)
key_torch = jax2torch(key)
value_torch = jax2torch(value)

def generate_alibi_bias_torch(H: int):
"""Returns an alibi bias score_mod given the number of heads H
Args:
H: number of heads
Returns:
alibi_bias: alibi bias score_mod
"""

def alibi_mod(score, b, h, q_idx, kv_idx):
scale = torch.exp2(-((h + 1) * 8.0 / H))
bias = (q_idx - kv_idx) * scale
return score + bias

return alibi_mod

output_torch = (
flex_attention(
query_torch,
key_torch,
value_torch,
score_mod=generate_alibi_bias_torch(num_heads),
)
.detach()
.cpu()
.numpy()
)

np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_pallas_decoding_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 1
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

output_jax = flax_attention_pallas(
query,
key,
value,
score_mod=generate_alibi_bias(num_heads),
)

query_torch = jax2torch(query)
key_torch = jax2torch(key)
value_torch = jax2torch(value)

def generate_alibi_bias_torch(H: int):
"""Returns an alibi bias score_mod given the number of heads H
Args:
H: number of heads
Returns:
alibi_bias: alibi bias score_mod
"""

def alibi_mod(score, b, h, q_idx, kv_idx):
scale = torch.exp2(-((h + 1) * 8.0 / H))
bias = (q_idx - kv_idx) * scale
return score + bias

return alibi_mod

output_torch = (
flex_attention(
query_torch,
key_torch,
value_torch,
score_mod=generate_alibi_bias_torch(num_heads),
)
.detach()
.cpu()
.numpy()
)

np.testing.assert_almost_equal(output_jax, output_torch, decimal=2)

def test_gqa(self):
# Prepare inputs
batch_size = 4
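
The JAX tests above pull `generate_alibi_bias` from `flaxattention.mods`. For reference, here is a minimal sketch of what such a score_mod can look like, assuming it mirrors the Torch helper defined in the tests; this is not the library's actual source.

```python
import jax.numpy as jnp

def generate_alibi_bias_sketch(H: int):
    """Hypothetical JAX mirror of generate_alibi_bias_torch above, for H heads."""

    def alibi_mod(score, b, h, q_idx, kv_idx):
        # Per-head slope, matching the Torch version: exp2(-(h + 1) * 8 / H)
        scale = jnp.exp2(-((h + 1) * 8.0 / H))
        return score + (q_idx - kv_idx) * scale

    return alibi_mod
```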
