fix llama3.1/3.2 quantize kv check (intel-analytics#12302)
MeouSker77 authored Oct 31, 2024
commit 72605c7 (1 parent: 416c191)
Showing 2 changed files with 9 additions and 3 deletions.
python/llm/src/ipex_llm/transformers/models/llama32.py (7 changes: 5 additions & 2 deletions)

@@ -79,7 +79,10 @@ def llama_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
@@ -114,7 +117,7 @@ def llama_model_forward(

     # IPEX-LLM OPT start: use fused rope
     if (should_use_fuse_rope(hidden_states, position_ids, False)
-            and self.rotary_emb.rope_type == "llama3"):
+            and self.rotary_emb.rope_type in ["default", "llama3"]):
         position_embeddings = self.rotary_emb.inv_freq
     # IEPX_LLM OPT end
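The new third argument in the first hunk is the grouped-query-attention (GQA) ratio: Llama 3.1/3.2 use fewer key/value heads than attention heads, so the KV cache is already num_attention_heads // num_key_value_heads times smaller than a full multi-head cache, and a size-based quantize-kv check needs that ratio to estimate the cache correctly. A minimal sketch of such a check that folds in the ratio (illustration only; should_quantize_kv, its threshold, and the width proxy are hypothetical, not ipex_llm's actual use_quantize_kv_cache):

import torch


def should_quantize_kv(linear: torch.nn.Linear, x: torch.Tensor,
                       kv_group: int = 1) -> bool:
    """Hypothetical stand-in for ipex_llm's use_quantize_kv_cache.

    kv_group = num_attention_heads // num_key_value_heads. A GQA model
    (kv_group > 1) keeps a kv_group-times smaller KV cache, so a size
    heuristic should scale the probe layer's width down by the group
    instead of treating every model as full multi-head attention.
    """
    effective_kv_width = linear.in_features // kv_group
    # Illustrative rule only: quantize the cache to FP8 when running on
    # an XPU device and the effective cache width is large enough to
    # make the memory savings worthwhile.
    return x.device.type == "xpu" and effective_kv_width >= 2048

The second hunk is a related fix: fused rope was previously enabled only when rope_type == "llama3", which skipped checkpoints whose rope config reports the plain "default" type; the membership test covers both.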
python/llm/src/ipex_llm/transformers/models/mllama.py (5 changes: 4 additions & 1 deletion)

@@ -129,7 +129,10 @@ def mllama_text_model_forward(
     # IPEX-LLM OPT start: kv cache and quantize kv cache
     inputs = input_ids if input_ids is not None else inputs_embeds
     use_cache = True if inputs.device.type == "xpu" else use_cache
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_quantize_kv = use_quantize_kv_cache(
+        self.layers[0].mlp.down_proj, inputs,
+        self.config.num_attention_heads // self.config.num_key_value_heads
+    )
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
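Both call sites compute the ratio the same way, straight off the model config. For reference, the same value can be read from a Hugging Face config (the model id below is an example; the head counts in the comment are the published values for this checkpoint, quoted from memory):

from transformers import AutoConfig

# Any Llama 3.1/3.2 checkpoint works the same way; this id is an example.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
kv_group = config.num_attention_heads // config.num_key_value_heads
print(kv_group)  # 32 attention heads / 8 KV heads -> 4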
