From a485f10d115c504d498158a726a07c6fd71911c7 Mon Sep 17 00:00:00 2001
From: Daniël de Kok
Date: Thu, 27 Jun 2024 10:47:31 +0200
Subject: [PATCH] Warmup works, inference crashes

Probably due to an incorrect head size in the cache
---
 .../models/custom_modeling/flash_deepseek_v2_modeling.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
index 01bcea7d39a..4ea1f469c98 100644
--- a/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py
@@ -404,11 +404,14 @@ def forward(
             )
         # Decode
         else:
+            logger.warning(
+                f"paged attention -> query: {query.shape}, key: {key.shape}, value: {value_states.shape}"
+            )
             paged_attention(
                 attn_output,
                 query,
-                key,
-                value_states,
+                kv_cache[0],
+                kv_cache[1],
                 self.kv_head_mapping,
                 self.softmax_scale,
                 block_tables,
@@ -420,7 +423,6 @@
         logger.warning(f"attention output: {attn_output.shape}")
         attn_output = attn_output[..., : self.v_head_dim]
         logger.warning(f"attention output after unpad: {attn_output.shape}")
-        logger.warning(f"v_head_dim: {self.v_head_dim}")
 
         return self.o_proj(attn_output.reshape(-1, self.num_heads * self.v_head_dim))
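
Note (not part of the patch): in the decode branch above, paged_attention now reads
from the cache tensors (kv_cache[0], kv_cache[1]) instead of the current step's
key/value_states, and the output is sliced back to v_head_dim because DeepSeek-V2's
query/key head size is larger than its value head size: values are padded up to the
shared cache head size before attention and unpadded afterwards. The standalone
Python sketch below illustrates that pad/unpad round trip; the concrete sizes (192
for query/key heads, 128 for value heads) and all tensor shapes are assumptions
for illustration, not values taken from this patch.

import torch

# Assumed DeepSeek-V2-style head sizes: qk_nope_head_dim + qk_rope_head_dim = 192
# for queries/keys, v_head_dim = 128 for values (illustrative, not from the patch).
num_tokens, num_heads = 4, 8
qk_head_dim, v_head_dim = 192, 128

key = torch.randn(num_tokens, num_heads, qk_head_dim)
value_states = torch.randn(num_tokens, num_heads, v_head_dim)

# Zero-pad values on the last dimension so keys and values can share one cache
# head size; a cache allocated with v_head_dim instead of qk_head_dim would be
# the "incorrect head size" the commit message suspects.
value_padded = torch.nn.functional.pad(value_states, (0, qk_head_dim - v_head_dim))
assert value_padded.shape == key.shape

# Stand-in for the attention kernel's output, which comes back with the padded
# head size.
attn_output = torch.randn(num_tokens, num_heads, qk_head_dim)

# The "unpad" step from the diff: drop the padding, keeping only the real value
# channels before the output projection.
attn_output = attn_output[..., :v_head_dim]
print(attn_output.shape)  # torch.Size([4, 8, 128])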