Commit

update benchmark

zinccat committed Oct 22, 2024
1 parent 6ae428b commit e458bb6
Showing 3 changed files with 143 additions and 4 deletions.
18 changes: 17 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

Porting [FlexAttention](https://github.com/pytorch-labs/attention-gym) to pure JAX.

Example usage (**For faster performance using Flash Attention, check examples/example.py**):
Example usage:

```python
import jax
@@ -58,6 +58,19 @@ if __name__ == "__main__":

    print(output.shape)
    # (8, 8, 2048, 64)

    # Autograd
    def fn(query, key, value):
        return flax_attention_pallas(
            query,
            key,
            value,
            score_mod=checkerboard,
        ).sum()
    grad_fn = jax.grad(fn, 0)
    grad_fn = jax.jit(grad_fn)

    grad = grad_fn(query, key, value)
```
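
The `checkerboard` score_mod passed above is defined earlier in the README (elided from this hunk). As a rough sketch of what it might look like, mirroring `checkerboard_torch` from examples/benchmark_torch.py below (the exact signature used by the JAX port is assumed here):

```python
import jax.numpy as jnp

# Hypothetical JAX counterpart of checkerboard_torch: halve the score at even
# (k_idx - q_idx) offsets and double it at odd offsets.
def checkerboard(score, batch, head, q_idx, k_idx):
    score = jnp.where((k_idx - q_idx) % 2 == 0, score * 0.5, score)
    score = jnp.where((k_idx - q_idx) % 2 == 1, score * 2.0, score)
    return score
```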

## Installation
@@ -82,3 +95,6 @@ Float16:
- FlaxAttention (This repo): 0.13s

The performance is about 20% slower than the original implementation; there are still some optimizations to be done.

## Issues
Autograd for Pallas is quite slow.
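
The slowdown shows up when timing the jitted gradient. A minimal sketch of the measurement pattern used in examples/benchmark.py (assuming `flax_attention_pallas`, `checkerboard`, and the `query`/`key`/`value` arrays from the usage example above are in scope):

```python
import jax
from timeit import default_timer as timer

# flax_attention_pallas and checkerboard are assumed to be defined as in the README example above
def loss(query, key, value):
    # scalar objective so jax.grad applies directly
    return flax_attention_pallas(query, key, value, score_mod=checkerboard).sum()

grad_fn = jax.jit(jax.grad(loss, argnums=0))

grad = grad_fn(query, key, value)  # warm up (compile) outside the timed region
grad.block_until_ready()

start = timer()
for _ in range(100):
    grad = grad_fn(query, key, value)
grad.block_until_ready()  # wait for the device before stopping the timer
end = timer()
print("Gradient time taken:", end - start)
```
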
68 changes: 65 additions & 3 deletions examples/benchmark.py
@@ -61,6 +61,7 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
    score_mod=checkerboard,
    # block_mask=block_mask,
)
output.block_until_ready()
# print(output[0, 0, 0])

# benchmark
@@ -75,7 +76,7 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
        score_mod=checkerboard,
        # block_mask=block_mask,
    )
output[0].block_until_ready()
output.block_until_ready()
end = timer()
print("Pure jax time taken:", end - start)

@@ -88,6 +89,8 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
    key,
    value,
)
output.block_until_ready()

start = timer()
for _ in range(100):
    output = dot_product_attention(
@@ -109,6 +112,7 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
    score_mod=checkerboard,
    # mask_mod=causal,
)
output.block_until_ready()

start = timer()
for _ in range(100):
@@ -122,5 +126,63 @@ def causal(batch: Array, head: Array, q_idx: Array, k_idx: Array) -> Array:
output.block_until_ready()
end = timer()
print("Pallas attention time taken:", end - start)
print(output.shape)
# print(output[0, 0, 0])

def fn(query, key, value):
    return flax_attention_pallas(
        query,
        key,
        value,
        score_mod=checkerboard,
    ).sum()
grad_fn = jax.grad(fn, 0)
grad_fn = jax.jit(grad_fn)

# warm up
grad = grad_fn(query, key, value)
grad.block_until_ready()

start = timer()
for _ in range(100):
    grad = grad_fn(query, key, value)
grad.block_until_ready()
end = timer()
print("Gradient time taken:", end - start)

# original Pallas attention
from jax.experimental.pallas.ops.gpu.attention import mha

def fn2(query, key, value):
    return mha(
        query,
        key,
        value,
        segment_ids=None,
        block_q=64,
        block_k=64
    ).sum()

grad_fn2 = jax.grad(fn2, 0)

# mha expects (batch, seq_len, num_heads, head_dim), so swap the head and sequence axes
query = jnp.moveaxis(query, 1, 2)
key = jnp.moveaxis(key, 1, 2)
value = jnp.moveaxis(value, 1, 2)

output = mha(query, key, value, segment_ids=None, block_q=64, block_k=64)
output.block_until_ready()

start = timer()
for _ in range(100):
    # match the warm-up call's block sizes so the timed loop is not recompiling
    output = mha(query, key, value, segment_ids=None, block_q=64, block_k=64)
output.block_until_ready()  # sync before stopping the timer
end = timer()
print("Original Pallas attention time taken:", end - start)

# warm up
grad = grad_fn2(query, key, value)
grad.block_until_ready()

start = timer()
for _ in range(100):
    grad = grad_fn2(query, key, value)
grad.block_until_ready()  # sync before stopping the timer
end = timer()
print("Original Pallas attention gradient time taken:", end - start)
61 changes: 61 additions & 0 deletions examples/benchmark_torch.py
@@ -0,0 +1,61 @@
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
import torch
import numpy as np

if __name__ == "__main__":
    batch_size = 8
    num_heads = 8
    seq_len_q = 2048
    seq_len_kv = 2048
    feature_size = 64

    query_torch = torch.randn(batch_size, num_heads, seq_len_q, feature_size, dtype=torch.float16).cuda()
    key_torch = torch.randn(batch_size, num_heads, seq_len_kv, feature_size, dtype=torch.float16).cuda()
    value_torch = torch.randn(batch_size, num_heads, seq_len_kv, feature_size, dtype=torch.float16).cuda()

    query_torch.requires_grad = True

    def checkerboard_torch(score, batch, head, q_idx, k_idx):
        score = torch.where((k_idx - q_idx) % 2 == 0, score * 0.5, score)
        score = torch.where((k_idx - q_idx) % 2 == 1, score * 2.0, score)
        return score

    flex_attention = torch.compile(flex_attention)

    # warmup

    output_torch = flex_attention(
        query_torch,
        key_torch,
        value_torch,
        score_mod=checkerboard_torch,
    )

    grad = torch.autograd.grad(
        output_torch.sum(), query_torch, create_graph=True
    )[0]

    # benchmark
    from timeit import default_timer as timer
    start = timer()
    for _ in range(100):
        output_torch = flex_attention(
            query_torch,
            key_torch,
            value_torch,
            score_mod=checkerboard_torch,
        )
    torch.cuda.synchronize()
    end = timer()

    print("Time taken for 100 iterations: ", end - start)

    start = timer()
    for _ in range(100):
        grad = torch.autograd.grad(
            output_torch.sum(), query_torch, create_graph=True
        )[0]
    torch.cuda.synchronize()
    end = timer()

    print("Time taken for 100 iterations backprop: ", end - start)
