diff --git a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
index 357eddb5e88..3dc05b6a978 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
@@ -85,7 +85,6 @@ def run(
         Returns:
             np.ndarray: result
         """
-        self.prefetchWeights(1, verify_size=False)
         self.set_input_tensor(X, 0)
         self.elapsed = backend_lib.run(self._mm)
         if len(self.out) == 1:
diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
index b576c0d3995..6266e5c6bb2 100644
--- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
+++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
@@ -990,35 +990,39 @@ def qwen2_fused_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
 
-    past_key_values_length = 0
+    if seq_length > 1:
+        past_key_values_length = 0
 
-    from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
+        from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
 
-    if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
-        past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
-    past_key_values_length = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
+            past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
+        past_key_values_length = past_key_values.get_seq_length()
 
-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
             past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
+            sliding_window=self.config.sliding_window,
         )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
     else:
-        position_ids = position_ids.view(-1, seq_length).long()
-
-    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-
-    attention_mask = _prepare_4d_causal_attention_mask(
-        attention_mask,
-        (batch_size, seq_length),
-        inputs_embeds,
-        past_key_values_length,
-        sliding_window=self.config.sliding_window,
-    )
+        attention_mask = None
+        position_ids = None
 
     # embed positions
     hidden_states = inputs_embeds
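
Note: the sketch below is an illustrative reconstruction of the prefill-vs-decode split introduced in qwen2_fused_model_forward above; it is not the ipex-llm implementation. The helper name build_decoder_inputs is hypothetical, and the only external call it relies on is transformers' _prepare_4d_causal_attention_mask, the same utility invoked in the diff. During prefill (seq_length > 1) the 4D causal mask and position_ids are materialized as before; during single-token decode they are skipped and None is passed down, presumably because the fused NPU decode path does not consume them.

# Minimal sketch of the branching pattern (assumed helper, not ipex-llm code).
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def build_decoder_inputs(attention_mask, position_ids, inputs_embeds,
                         past_key_values_length, sliding_window=None):
    """Mirror the seq_length > 1 branch added in qwen2_fused_model_forward."""
    batch_size, seq_length, _ = inputs_embeds.shape

    if seq_length > 1:
        # Prefill: build position_ids and the 4D causal mask, as the stock
        # transformers code path does.
        if position_ids is None:
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=inputs_embeds.device,
            ).unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
            sliding_window=sliding_window,
        )
    else:
        # Single-token decode: skip mask/position_ids construction entirely.
        attention_mask = None
        position_ids = None

    return attention_mask, position_ids


if __name__ == "__main__":
    prefill = torch.randn(1, 8, 64)   # 8-token prefill
    mask, pos = build_decoder_inputs(None, None, prefill, past_key_values_length=0)
    print(mask.shape, pos.shape)      # torch.Size([1, 1, 8, 8]) torch.Size([1, 8])

    decode = torch.randn(1, 1, 64)    # one-token decode step
    mask, pos = build_decoder_inputs(None, None, decode, past_key_values_length=8)
    print(mask, pos)                  # None None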