remove issue
zinccat committed Oct 25, 2024
1 parent 7a63d02 commit 0212c99
Showing 1 changed file (README.md) with 0 additions and 1 deletion.
@@ -116,7 +116,6 @@ seq_len for query = 1,
We can see that the pure JAX implementation is actually the fastest, surpassing Pallas Flash Attention. The kernel also supports arbitrary query lengths; the inflection point is around 64: Pallas Flash Attention starts to outperform the pure JAX implementation once the query length exceeds 64. (For autograd, the inflection point is around 1024, which is quite bad.)

## Issues
Ensuring softmax safety in the pure JAX implementation costs ~15% in performance.
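The cost referred to here is the usual max-subtraction trick that keeps `exp` from overflowing. A minimal sketch of the two variants (the function names are illustrative, not the repo's actual API):

```python
import jax.numpy as jnp

def softmax_unsafe(x):
    # Naive softmax: exp overflows to inf for large logits, yielding NaNs.
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=-1, keepdims=True)

def softmax_safe(x):
    # Subtracting the row max keeps every exponent <= 0, so exp never
    # overflows. The extra max/subtract pass is the overhead the
    # ~15% figure above refers to.
    m = jnp.max(x, axis=-1, keepdims=True)
    e = jnp.exp(x - m)
    return e / jnp.sum(e, axis=-1, keepdims=True)

x = jnp.array([1000.0, 1001.0, 1002.0])
print(softmax_unsafe(x))  # NaNs: exp(1000) overflows
print(softmax_safe(x))    # well-defined probabilities summing to 1
```

`jax.nn.softmax` applies the same max-subtraction internally, which is why dropping safety requires hand-writing the naive version.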

## TODO
- [ ] Implement the block sparse optimization
