Warmup works, inference crashes
Probably due to an incorrect head size in the KV cache.
danieldk committed Jun 27, 2024
1 parent 451cdb8 commit a485f10
Showing 1 changed file with 5 additions and 3 deletions.
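
For context: in an MLA-style layer such as DeepSeek-V2's, the query/key head size (192) is larger than the value head size (128), and values are padded up to the query/key size before being written to the paged KV cache. Below is a minimal sketch of the suspected failure mode, with illustrative names, dimensions, and a simplified cache layout (the real kernels use a different blocked layout): if the cache is allocated with the value head size, the decode-time write of a padded slot fails even though warmup can pass.

```python
import torch

# Illustrative dimensions only; not taken from the repository.
qk_head_dim = 192   # query/key head size in an MLA-style layer
v_head_dim = 128    # true value head size
num_blocks, block_size, num_kv_heads = 64, 16, 8

# Values are padded up to the query/key head size before entering the paged
# cache, so the cache has to be allocated with qk_head_dim per head.
kv_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, qk_head_dim)

# If the cache is instead sized with v_head_dim, writing a padded slot fails:
bad_cache = torch.zeros(2, num_blocks, block_size, num_kv_heads, v_head_dim)
padded_slot = torch.zeros(num_kv_heads, qk_head_dim)
try:
    bad_cache[0, 0, 0] = padded_slot  # shape mismatch: 192 vs. 128
except RuntimeError as e:
    print(f"decode-time write fails: {e}")
```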
@@ -404,11 +404,14 @@ def forward(
            )
        # Decode
        else:
            logger.warning(
                f"paged attention -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
            )
            paged_attention(
                attn_output,
                query,
                key,
                value_states,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
@@ -420,7 +423,6 @@ def forward(
logger.warning(f"attention output: {attn_output.shape}")
attn_output = attn_output[..., : self.v_head_dim]
logger.warning(f"attention output after unpad: {attn_output.shape}")
logger.warning(f"v_head_dim: {self.v_head_dim}")

return self.o_proj(attn_output.reshape(-1, self.num_heads * self.v_head_dim))
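
The `attn_output[..., : self.v_head_dim]` slice exists because the kernel output carries the padded query/key head size rather than the true value head size. A minimal sketch of this pad-then-slice pattern, with illustrative dimensions (the real values come from the model config):

```python
import torch
import torch.nn.functional as F

num_tokens, num_heads = 4, 16
qk_head_dim, v_head_dim = 192, 128  # illustrative, not from the config

# Pad values up to the query/key head size so the attention kernel sees one
# uniform head size for q, k, and v.
value_states = torch.randn(num_tokens, num_heads, v_head_dim)
value_padded = F.pad(value_states, (0, qk_head_dim - v_head_dim))

# The kernel output then has the padded head size; slice it back down before
# the output projection, as the diff above does.
attn_output = torch.randn(num_tokens, num_heads, qk_head_dim)
attn_output = attn_output[..., :v_head_dim]
assert attn_output.shape == (num_tokens, num_heads, v_head_dim)
```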

