
Commit

Attention WIP
danieldk committed Jun 26, 2024
1 parent 49ac515 commit 08a09e3
Showing 1 changed file with 64 additions and 21 deletions.
@@ -250,10 +250,16 @@ def __init__(
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.num_heads
+        self.kv_lora_rank = config.kv_lora_rank
+        self.q_lora_rank = config.q_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.q_head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim

         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
-            dim=self.head_size,
+            dim=self.qk_rope_head_dim,
             base=config.rope_theta,
             device=weights.device,
         )
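The rotary embedding is sized to qk_rope_head_dim here because multi-head latent attention rotates only the trailing "rope" slice of each head and leaves the "nope" slice untouched (see the forward hunk below). A hedged, standalone sketch of that idea, using plain PyTorch and the common rotate-half formulation rather than the actual PositionRotaryEmbedding kernel:

import torch

def rotate_rope_slice(q, cos, sin, qk_nope_head_dim):
    # q: [tokens, heads, qk_nope_head_dim + qk_rope_head_dim]
    q_nope, q_pe = q[..., :qk_nope_head_dim], q[..., qk_nope_head_dim:]
    half = q_pe.shape[-1] // 2
    # rotate-half formulation; the real kernel may use a different layout
    rotated = torch.cat((-q_pe[..., half:], q_pe[..., :half]), dim=-1)
    q_pe = q_pe * cos + rotated * sin
    # the nope slice passes through unchanged
    return torch.cat((q_nope, q_pe), dim=-1)

# toy shapes: 4 tokens, 2 heads, nope=128, rope=64
q = torch.randn(4, 2, 192)
angles = torch.rand(4, 1, 32)  # half of the rope dim per token
cos = torch.cos(torch.cat((angles, angles), dim=-1))
sin = torch.sin(torch.cat((angles, angles), dim=-1))
out = rotate_rope_slice(q, cos, sin, qk_nope_head_dim=128)
assert out.shape == q.shape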
@@ -318,23 +324,60 @@ def forward(
         input_lengths,
         max_s,
     ):
-        qkv = self.query_key_value(hidden_states)
-        if self.clip_qkv is not None:
-            qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+        if self.q_lora_rank is None:
+            query = self.query(hidden_states)
+        else:
+            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
+        query = query.view(-1, self.num_heads, self.q_head_size)
+        from loguru import logger

-        query, kv = qkv.split(
-            [
-                self.head_size * self.num_heads,
-                2 * self.head_size * self.num_key_value_heads,
-            ],
-            dim=1,
+        query_nope, query_pe = torch.split(
+            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
         )
+
+        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
+        compressed_kv, k_pe = torch.split(
+            compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+        )
+        k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)
+        kv = (
+            self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
+                -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            # .transpose(1, 2)
+        )
+        k_nope, value_states = torch.split(
+            kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+        )
+
+        self.rotary_emb(query_pe, k_pe, cos, sin)
+
+        query[..., self.qk_nope_head_dim :] = query_pe
+        key = torch.empty_like(query)
+        key[..., : self.qk_nope_head_dim] = k_nope
+        key[..., self.qk_nope_head_dim :] = k_pe
+
+        logger.warning(
+            f"query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+        )
+
+        reshape_and_cache(key, value_states, kv_cache[0], kv_cache[1], slots)
+
+        logger.warning(
+            f"after reshape_and_cache -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+        )
-        query = query.view(-1, self.num_heads, self.head_size)
-        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

-        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
+        # We need to pad the heads because Flash Attention does not support
+        # qk and v with different head sizes.
+        query = torch.nn.functional.pad(query, (0, 256 - self.q_head_size), value=0)
+        key = torch.nn.functional.pad(key, (0, 256 - self.q_head_size), value=0)
+        value_states = torch.nn.functional.pad(
+            value_states, (0, 256 - self.v_head_dim), value=0
+        )

-        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
+        logger.warning(
+            f"after pad -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+        )

         # output tensor
         attn_output = torch.empty_like(query)
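To make the shapes in the hunk above easier to follow, here is a hedged, self-contained sketch of the compressed-KV path with toy sizes; the nn.Linear layers stand in for kv_a_proj_with_mqa and kv_b_proj, and the kv_a_layernorm step is omitted:

import torch

# Toy sizes for illustration only; the real values come from the model config.
num_tokens, num_heads, hidden_size = 4, 2, 32
kv_lora_rank, qk_rope_head_dim = 16, 8
qk_nope_head_dim, v_head_dim = 8, 8

# Stand-ins for kv_a_proj_with_mqa and kv_b_proj.
kv_a = torch.nn.Linear(hidden_size, kv_lora_rank + qk_rope_head_dim, bias=False)
kv_b = torch.nn.Linear(kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)

# Down-project once, then split off the shared rope key (one per token, not per head).
compressed_kv, k_pe = torch.split(
    kv_a(hidden_states), [kv_lora_rank, qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(-1, 1, qk_rope_head_dim)

# Up-project the latent into per-head "nope" keys and values.
kv = kv_b(compressed_kv).view(-1, num_heads, qk_nope_head_dim + v_head_dim)
k_nope, value_states = torch.split(kv, [qk_nope_head_dim, v_head_dim], dim=-1)
print(compressed_kv.shape, k_pe.shape, k_nope.shape, value_states.shape)
# torch.Size([4, 16]) torch.Size([4, 1, 8]) torch.Size([4, 2, 8]) torch.Size([4, 2, 8])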
@@ -344,8 +387,8 @@ def forward(
             # flash attention
             attention(
                 query,
-                torch.select(kv, dim=1, index=0),
-                torch.select(kv, dim=1, index=1),
+                key,
+                value_states,
                 attn_output,
                 cu_seqlen_prefill,
                 max_s,
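The zero-padding added in the forward hunk above exists because the flash attention kernel called here expects query/key and value to share a single head size, while MLA produces q/k heads of qk_nope_head_dim + qk_rope_head_dim and v heads of v_head_dim. A minimal illustration of the work-around, assuming the usual DeepSeek-V2 split of 128 + 64 = 192 for q/k and 128 for v, padded up to the 256 chosen by this commit:

import torch
import torch.nn.functional as F

q_head_size, v_head_dim, padded = 192, 128, 256
query = torch.randn(4, 2, q_head_size)   # [tokens, heads, q/k head size]
value = torch.randn(4, 2, v_head_dim)    # [tokens, heads, v head size]

# Pad only the last (head) dimension so all tensors match the kernel's head size.
query = F.pad(query, (0, padded - q_head_size), value=0)
value = F.pad(value, (0, padded - v_head_dim), value=0)
assert query.shape[-1] == value.shape[-1] == 256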
@@ -356,8 +399,8 @@ def forward(
             paged_attention(
                 attn_output,
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                key,
+                value_states,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -665,9 +708,9 @@ def forward(
             attn_output, res
         )

-        moe_output = self.moe(normed_attn_res_output)
+        output = self.mlp(normed_attn_res_output)

-        return moe_output, attn_res
+        return output, attn_res


 class DeepseekV2Model(torch.nn.Module):
@@ -711,7 +754,7 @@ def forward(

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
-        cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(
+        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
             position_ids, max_s, hidden_states.dtype
         )

