Hi!
I noticed an odd detail while checking your codebase. When evaluating DuoAttention with Mistral-32k, I found that this function serves as the implementation of attention:
duo-attention/duo_attn/patch/mistral.py
Lines 146 to 306 in aa97830
In particular, these lines perform a full attention computation:
duo-attention/duo_attn/patch/mistral.py
Lines 225 to 233 in aa97830
If I'm not mistaken, the code first computes (n-50)x(n-50) full attention for all KV heads and only then starts decoding, which conflicts with the description in the DuoAttention paper.
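For concreteness, here is a minimal mask-level sketch of the two pre-fill behaviours I am comparing: dense causal attention for every head versus the sink-plus-recent pattern the paper describes for streaming heads. This is my own illustration, not code from the repo, and the sink/window sizes are placeholders.

```python
import torch

def causal_mask(seq_len):
    # Dense causal attention: every query attends to all earlier tokens,
    # so pre-filling n tokens costs O(n^2) per head.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def streaming_mask(seq_len, num_sink=4, recent_window=50):
    # Streaming heads, as described in the paper, attend only to a few
    # "sink" tokens plus a recent window, so their pre-fill cost stays O(n).
    # num_sink and recent_window are placeholder values, not the repo's settings.
    q_idx = torch.arange(seq_len)[:, None]  # queries as rows
    k_idx = torch.arange(seq_len)[None, :]  # keys as columns
    keep_recent = (k_idx > q_idx - recent_window) & (k_idx <= q_idx)
    keep_sink = k_idx < num_sink
    return (keep_recent | keep_sink) & causal_mask(seq_len)
```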
Is there anything I got wrong?
Chunked pre-filling in our approach is handled at a higher level, rather than within the attention function implementation itself. Specifically, the code section you referenced here does perform full attention for q_len == kv_seq_len, using flash attention as a pre-filling mechanism. However, the chunked pre-filling, as applied in our experiments, is set up in the outer loop (see this code section). In our paper experiments, we used a chunk size of 32K to pre-fill the benchmarks.
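Concretely, chunked pre-filling in the outer loop has roughly the following shape. This is a sketch against a Hugging Face-style model with a KV cache, not the exact code in the linked section; the function name `chunked_prefill` and the `chunk_size` default are illustrative.

```python
import torch

@torch.no_grad()
def chunked_prefill(model, input_ids, chunk_size=32 * 1024):
    """Pre-fill the KV cache by feeding the prompt in fixed-size chunks,
    so each attention call sees at most chunk_size query tokens."""
    past_key_values = None
    for start in range(0, input_ids.shape[1], chunk_size):
        chunk = input_ids[:, start : start + chunk_size]
        outputs = model(
            input_ids=chunk,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
    return past_key_values
```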
Since most samples in the LongBench dataset are shorter than 32K, and our benchmarks run comfortably on a single A100 GPU, we disabled chunked pre-filling in the publicly released code. The results from this approach are very close to those obtained using 32K chunked pre-filling.
Thank you for your response.
Were the model accuracy results reported in the paper also measured with full attention used for pre-filling?