Commit: update benchmark
zinccat committed Oct 22, 2024
1 parent 0b88e59 commit 221d097
Showing 4 changed files with 80 additions and 32 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -91,10 +91,11 @@ TF32:
- FlaxAttention (Without Pallas Flash Attention): 0.87s

Float16:
- FlexAttention: 0.11s
- FlaxAttention (This repo): 0.13s

We can see that the performance is about 20% slower than the original implementation. There are still some optimizations to be done.

## Issues
Autograd through the Pallas kernel is quite slow:
| Method | Forward Time (s) | Gradient Time (s) |
|------------------------|----------------------------|----------------------------|
| FlaxAttention (Pure JAX) | 0.5692746052518487 | 0.8823547409847379 |
| FlaxAttention (Pallas) | **0.13677988620474935** | **0.5575501238927245** |
| Flax (no score_mod) | 0.07136831432580948 | 0.03911650087684393 |
| FlexAttention (Torch)| **0.11708855209872127** | **0.5104729640297592** |

We can see that the forward pass is about 20% slower than the original FlexAttention implementation, while the backward pass is about 8% slower. There are still some optimizations to be done.
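
For reference, a minimal harness along these lines could produce numbers like the ones above — a sketch only, assuming `attn_fn` is one of the attention entry points benchmarked here (e.g. `flax_attention` with the `checkerboard` score_mod), and `time_forward_and_grad` is a hypothetical helper rather than part of this repo:

```python
import jax
from timeit import default_timer as timer

def time_forward_and_grad(attn_fn, query, key, value, iters=100):
    # Sum to a scalar so jax.grad gives d(loss)/d(query).
    loss_fn = lambda q, k, v: attn_fn(q, k, v).sum()
    grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

    # Warm up so compilation time is excluded from the measurement.
    attn_fn(query, key, value).block_until_ready()
    grad_fn(query, key, value).block_until_ready()

    start = timer()
    for _ in range(iters):
        attn_fn(query, key, value).block_until_ready()
    forward_time = timer() - start

    start = timer()
    for _ in range(iters):
        grad_fn(query, key, value).block_until_ready()
    grad_time = timer() - start
    return forward_time, grad_time
```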
82 changes: 65 additions & 17 deletions examples/benchmark.py
@@ -80,29 +80,30 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
end = timer()
print("Pure jax time taken:", end - start)

# try flax attention
from flax.nnx import dot_product_attention
def fn0(query, key, value):
    return flax_attention(
        query,
        key,
        value,
        score_mod=checkerboard,
    ).sum()
grad_fn0 = jax.grad(fn0, 0)
grad_fn0 = jax.jit(grad_fn0)
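# grad_fn0 is the jit-compiled gradient of the summed flax_attention output w.r.t. query (argnum 0)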

# warm up
output = dot_product_attention(
    query,
    key,
    value,
)
output.block_until_ready()
grad = grad_fn0(query, key, value)
grad.block_until_ready()

# print(grad[0, 0, 0])

start = timer()
for _ in range(100):
    output = dot_product_attention(
        query,
        key,
        value,
    )
    output.block_until_ready()
    grad = grad_fn0(query, key, value)
    grad.block_until_ready()
end = timer()
print("Flax attention time taken (no score_mod):", end - start)
print("Pure jax gradient time taken:", end - start)

# try mha kernel
# try mha pallas kernel

# warm up
output = flax_attention_pallas(
@@ -113,6 +114,7 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
# mask_mod=causal,
)
output.block_until_ready()
# print(output[0, 0, 0])

start = timer()
for _ in range(100):
@@ -141,12 +143,58 @@ def fn(query, key, value):
grad = grad_fn(query, key, value)
grad.block_until_ready()

# print(grad[0, 0, 0])

start = timer()
for _ in range(100):
    grad = grad_fn(query, key, value)
    grad.block_until_ready()
end = timer()
print("Gradient time taken:", end - start)
print("Pallas gradient time taken:", end - start)

# try flax attention
from flax.nnx import dot_product_attention

# warm up
output = dot_product_attention(
    query,
    key,
    value,
)
output.block_until_ready()

start = timer()
for _ in range(100):
    output = dot_product_attention(
        query,
        key,
        value,
    )
    output.block_until_ready()
end = timer()
print("Flax attention time taken (no score_mod):", end - start)

def fn1(query, key, value):
    return dot_product_attention(
        query,
        key,
        value,
    ).sum()
grad_fn1 = jax.grad(fn1, 0)
grad_fn1 = jax.jit(grad_fn1)
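# grad_fn1 is the same gradient benchmark for plain dot_product_attention (the no-score_mod baseline)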

# warm up
grad = grad_fn1(query, key, value)
grad.block_until_ready()

# print(grad[0, 0, 0])

start = timer()
for _ in range(100):
    grad = grad_fn1(query, key, value)
    grad.block_until_ready()
end = timer()
print("Flax gradient time taken (no score_mod):", end - start)

# original pallas attention
from jax.experimental.pallas.ops.gpu.attention import mha
13 changes: 6 additions & 7 deletions examples/benchmark_torch.py
@@ -31,10 +31,6 @@ def checkerboard_torch(score, batch, head, q_idx, k_idx):
score_mod=checkerboard_torch,
)

grad = torch.autograd.grad(
    output_torch.sum(), query_torch, create_graph=True
)[0]

# benchmark
from timeit import default_timer as timer
start = timer()
@@ -52,9 +48,12 @@ def checkerboard_torch(score, batch, head, q_idx, k_idx):

start = timer()
for _ in range(100):
    grad = torch.autograd.grad(
        output_torch.sum(), query_torch, create_graph=True
    )[0]
    output_torch = flex_attention(
        query_torch,
        key_torch,
        value_torch,
        score_mod=checkerboard_torch,
    ).sum().backward()
torch.cuda.synchronize()
end = timer()
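
Note that the updated loop times the forward and backward passes together, because `.backward()` re-runs the full forward computation and then backpropagates on every iteration. A hypothetical variant (not part of this commit) that isolates the backward pass would build the graph once and retain it, assuming `query_torch` has `requires_grad=True` as in the original setup:

```python
output_sum = flex_attention(
    query_torch,
    key_torch,
    value_torch,
    score_mod=checkerboard_torch,
).sum()
torch.cuda.synchronize()

start = timer()
for _ in range(100):
    # retain_graph=True lets us backprop through the same graph repeatedly
    (grad,) = torch.autograd.grad(output_sum, query_torch, retain_graph=True)
torch.cuda.synchronize()
end = timer()
print("Flex attention backward-only time taken:", end - start)
```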

2 changes: 1 addition & 1 deletion flaxattention/kernel/attention.py
@@ -277,7 +277,7 @@ def _mha_forward(
    mask_mod: _mask_mod_signature | None,
    score_mod_grad: _score_mod_signature | None,
):
    del backward_pass_impl
    del backward_pass_impl, score_mod_grad
    batch_size, seq_len, num_heads, head_dim = q.shape
    block_q = min(block_q, seq_len)
    block_k = min(block_k, seq_len)
