From 0212c99583ce6a2cac299e034ab484d1722866b1 Mon Sep 17 00:00:00 2001
From: ZincCat
Date: Thu, 24 Oct 2024 21:53:24 -0400
Subject: [PATCH] remove issue

---
 README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/README.md b/README.md
index b1a7925..556a4c8 100644
--- a/README.md
+++ b/README.md
@@ -116,7 +116,6 @@ seq_len for query = 1,
 We can see that the pure JAX implementation is actually the fastest, surpassing Pallas Flash Attention. The kernel also supports arbitrary query length; the inflection point is around 64, beyond which Pallas Flash Attention starts to outperform the pure JAX implementation. (For autograd, the inflection point is around 1024, which is quite bad.)
 
 ## Issues
-Ensure softmax safety in pure JAX implementation is costing ~15% performance.
 
 ## TODO
 - [ ] Implement the block sparse optimization
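For context on the "softmax safety" line the patch removes: a minimal sketch of the standard max-subtraction trick that makes softmax numerically safe, contrasted with the unguarded version. This is illustrative only, written against plain `jax.numpy`; it is not the repository's kernel code, and the ~15% figure in the removed line is the README's own measurement, not reproduced here.

```python
import jax.numpy as jnp

def safe_softmax(x, axis=-1):
    # Subtracting the row max keeps exp() from overflowing for large logits.
    # Mathematically a no-op, but it adds an extra max + subtract pass,
    # which is the kind of overhead the removed "Issues" line refers to.
    m = jnp.max(x, axis=axis, keepdims=True)
    e = jnp.exp(x - m)
    return e / jnp.sum(e, axis=axis, keepdims=True)

def unsafe_softmax(x, axis=-1):
    # Skips the max subtraction; cheaper, but exp() overflows to inf
    # for large logits, producing inf/inf = nan.
    e = jnp.exp(x)
    return e / jnp.sum(e, axis=axis, keepdims=True)

x = jnp.array([1e4, 0.0, -1e4])
print(safe_softmax(x))    # finite result
print(unsafe_softmax(x))  # contains nan from overflow
```

The safe variant stays finite for arbitrarily large logits; the unsafe one fails as soon as any logit exceeds roughly `log(float32 max) ≈ 88`.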