Seems like chunked pre-filling is inactive when evaluating longbench. #9

BirdChristopher opened this issue Nov 8, 2024 · 2 comments

@BirdChristopher

Hi!
I noticed something odd while checking your codebase. When evaluating DuoAttention with Mistral-32k, I found that this function serves as the attention implementation:

def mistral_duo_attention_forward_one_way_reordered(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    )

    # new data structure for past_key_value
    # past_key_value = (full_KV, streaming_KV)
    # full_KV: (2 x bsz, num_full_key_value_heads, full_kv_seq_len, head_dim)
    # streaming_KV: (2 x bsz, num_streaming_key_value_heads, cache_size, head_dim)
    kv_seq_len = key_states.shape[1]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[2]

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(
        query_states,
        key_states,
        cos,
        sin,
        unsqueeze_dim=2,  # unsqueeze_dim=2 for the flash attention
    )

    if not hasattr(self, "full_attn_head_mask") or self.full_attn_head_mask is None:
        self.full_attn_head_mask = self.full_attention_heads > 0.5
        self.num_full_attn_head = self.full_attn_head_mask.sum().item()
        self.num_streaming_attn_head = (
            self.num_key_value_heads - self.num_full_attn_head
        )

        self.num_full_query_head = self.num_full_attn_head * self.num_key_value_groups
        self.num_streaming_query_head = self.num_heads - self.num_full_query_head

    full_key_states = key_states[:, :, : self.num_full_attn_head, :]
    full_value_states = value_states[:, :, : self.num_full_attn_head, :]
    streaming_key_states = key_states[:, :, self.num_full_attn_head :, :]
    streaming_value_states = value_states[:, :, self.num_full_attn_head :, :]

    if past_key_value is not None:
        # reuse k, v, self_attention
        past_full_KV = past_key_value[0].transpose(1, 2)
        past_streaming_KV = past_key_value[1].transpose(1, 2)

        past_full_key_states = past_full_KV[:bsz]
        past_full_value_states = past_full_KV[bsz:]

        full_key_states = torch.cat([past_full_key_states, full_key_states], dim=1)
        full_value_states = torch.cat(
            [past_full_value_states, full_value_states], dim=1
        )

        past_streaming_key_states = past_streaming_KV[:bsz]
        past_streaming_value_states = past_streaming_KV[bsz:]

        streaming_key_states = torch.cat(
            [past_streaming_key_states, streaming_key_states], dim=1
        )
        streaming_value_states = torch.cat(
            [past_streaming_value_states, streaming_value_states], dim=1
        )

    if q_len == kv_seq_len:
        # pre-filling: use flash attention
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            causal=True,
            dropout_p=0.0,
        )
    else:
        # decoding or continuous filling
        if self.num_full_attn_head > 0:
            full_query_states = query_states[:, :, : self.num_full_query_head, :]
            full_attn_output = flash_attn_func(
                full_query_states,
                full_key_states,
                full_value_states,
                causal=True,
                dropout_p=0.0,
            )
        else:
            full_attn_output = None

        if self.num_streaming_attn_head > 0:
            streaming_query_states = query_states[:, :, self.num_full_query_head :, :]
            streaming_attn_output = flash_attn_func(
                streaming_query_states,
                streaming_key_states,
                streaming_value_states,
                causal=True,
                dropout_p=0.0,
            )
        else:
            streaming_attn_output = None

        if full_attn_output is None:
            attn_output = streaming_attn_output
        elif streaming_attn_output is None:
            attn_output = full_attn_output
        else:
            attn_output = torch.cat([full_attn_output, streaming_attn_output], dim=2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if streaming_key_states.shape[1] > self.recent_size + self.sink_size:
        recent_key_states = streaming_key_states[:, -self.recent_size :, :, :].clone()
        streaming_key_states[
            :, self.sink_size : self.sink_size + self.recent_size, :, :
        ].copy_(recent_key_states)
        streaming_key_states = streaming_key_states[
            :, : self.sink_size + self.recent_size, :, :
        ]

        recent_value_states = streaming_value_states[
            :, -self.recent_size :, :, :
        ].clone()
        streaming_value_states[
            :, self.sink_size : self.sink_size + self.recent_size, :, :
        ].copy_(recent_value_states)
        streaming_value_states = streaming_value_states[
            :, : self.sink_size + self.recent_size, :, :
        ]

    past_key_value = (
        (
            torch.cat([full_key_states, full_value_states], dim=0).transpose(1, 2),
            torch.cat([streaming_key_states, streaming_value_states], dim=0).transpose(
                1, 2
            ),
        )
        if use_cache
        else None
    )

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value

The following lines, taken from the function above, perform a full attention computation:

if q_len == kv_seq_len:
    # pre-filling: use flash attention
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        causal=True,
        dropout_p=0.0,
    )

If I'm not mistaken, your code will first compute (n-50)x(n-50) full attention over all kv_heads and only then switch to the split path for decoding, which conflicts with the description in the DuoAttention paper.

Is there anything I got wrong?
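
For concreteness, here is a minimal sketch (with hypothetical sequence lengths) of which branch the function above takes, based solely on the q_len == kv_seq_len check:

# Minimal sketch (hypothetical numbers) of which branch of the forward function
# above is taken, based solely on the q_len == kv_seq_len comparison.
def which_branch(q_len: int, past_len: int) -> str:
    kv_seq_len = q_len + past_len  # past_len = past_key_value[0].shape[2], or 0 with no cache
    if q_len == kv_seq_len:
        return "single-shot pre-fill: full flash attention over all KV heads"
    return "chunked pre-fill / decoding: separate full and streaming head paths"

print(which_branch(q_len=30000, past_len=0))    # whole prompt in one call -> full attention
print(which_branch(q_len=4096, past_len=4096))  # second chunk of a chunked pre-fill -> split path
print(which_branch(q_len=1, past_len=30000))    # decoding a single token -> split path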

@Guangxuan-Xiao
Contributor

Chunked pre-filling in our approach is handled at a higher level, rather than within the attention function implementation itself. Specifically, the code section you referenced does perform full attention when q_len == kv_seq_len, using flash attention as the pre-filling mechanism. However, chunked pre-filling, as applied in our experiments, is set up in the outer loop (see this code section). In our paper experiments, we used a chunk size of 32K to pre-fill the benchmarks.

Since most samples in the LongBench dataset are shorter than 32K, and our benchmarks run comfortably on a single A100 GPU, we disabled chunked pre-filling in the publicly released code. The results from this approach are very close to those obtained using 32K chunked pre-filling.
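
For readers following along, a minimal sketch of what such outer-loop chunked pre-filling can look like is shown below; the HuggingFace-style call signature and the helper name chunked_prefill are assumptions for illustration, not the repository's actual code:

# A minimal sketch of outer-loop chunked pre-filling, assuming a HuggingFace-style
# model whose forward pass accepts past_key_values and returns an updated cache.
# Names and the 32K default chunk size are illustrative, not the repository's API.
import torch

@torch.no_grad()
def chunked_prefill(model, input_ids: torch.Tensor, chunk_size: int = 32 * 1024):
    past_key_values = None
    out = None
    for start in range(0, input_ids.shape[1], chunk_size):
        chunk = input_ids[:, start : start + chunk_size]
        # After the first chunk, q_len < kv_seq_len inside the attention forward,
        # so the split full/streaming path is used instead of full attention.
        out = model(chunk, past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
    return out.logits, past_key_values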

@BirdChristopher
Author

Thank you for your response.
Are the model accuracy results reported in the paper also measured with full attention used for pre-filling?
