update code for NPU qwen2 (intel-analytics#12094)
* update code

* fix
rnwang04 authored Sep 20, 2024
1 parent db7500b commit 09b8c80
Showing 2 changed files with 27 additions and 24 deletions.
1 change: 0 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/lm_head.py
```diff
@@ -85,7 +85,6 @@ def run(
         Returns:
             np.ndarray: result
         """
-        self.prefetchWeights(1, verify_size=False)
         self.set_input_tensor(X, 0)
         self.elapsed = backend_lib.run(self._mm)
         if len(self.out) == 1:
```
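
The change above removes a per-call weight prefetch from the matmul hot path; whether and where the weights are prefetched instead (for example once at initialization) is not shown in this diff. The toy sketch below is not the ipex-llm implementation — `FakeBackend` and `LMHeadStub` are invented stand-ins — it only illustrates the effect of taking a one-time prefetch out of `run()`:

```python
import numpy as np


class FakeBackend:
    """Invented stand-in for the real NPU backend bindings (backend_lib)."""

    def __init__(self):
        self.prefetch_calls = 0
        self.run_calls = 0

    def prefetch_weights(self):
        self.prefetch_calls += 1  # expensive on real hardware

    def run(self, x):
        self.run_calls += 1
        return x  # pretend this is the LM-head matmul


class LMHeadStub:
    """Invented stand-in; only shows the call pattern, not the ipex-llm class."""

    def __init__(self, backend):
        self.backend = backend
        # Assumption: weights are prefetched once, outside the hot path.
        # The diff above only shows the removal of the per-call prefetch.
        self.backend.prefetch_weights()

    def run(self, x: np.ndarray) -> np.ndarray:
        # Post-patch hot path: bind the input and execute, no prefetch per call.
        return self.backend.run(x)


backend = FakeBackend()
head = LMHeadStub(backend)
for _ in range(100):
    head.run(np.ones((1, 4096), dtype=np.float16))
print(backend.prefetch_calls, backend.run_calls)  # 1 100
```

The prefetch is presumably the expensive step on real hardware, so keeping it out of the per-token path is the likely motivation for the deletion.
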
50 changes: 27 additions & 23 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
```diff
@@ -990,35 +990,39 @@ def qwen2_fused_model_forward(
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)

-    past_key_values_length = 0
+    if seq_length > 1:
+        past_key_values_length = 0

-    from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache
+        from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache

-    if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
-        past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
-        past_key_values_length = past_key_values.get_seq_length()
+        if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
+            past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
+            past_key_values_length = past_key_values.get_seq_length()

-    if position_ids is None:
-        device = input_ids.device if input_ids is not None else inputs_embeds.device
-        position_ids = torch.arange(
-            past_key_values_length,
-            seq_length + past_key_values_length,
-            dtype=torch.long,
-            device=device,
-        )
-        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-    else:
-        position_ids = position_ids.view(-1, seq_length).long()
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()

-    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+        from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

-    attention_mask = _prepare_4d_causal_attention_mask(
-        attention_mask,
-        (batch_size, seq_length),
-        inputs_embeds,
-        past_key_values_length,
-        sliding_window=self.config.sliding_window,
-    )
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=self.config.sliding_window,
+        )
+    else:
+        attention_mask = None
+        position_ids = None

     # embed positions
     hidden_states = inputs_embeds
```
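
The net effect of the qwen2_mp.py change is that host-side attention masks and position ids are only built for prefill (`seq_length > 1`); for single-token decode steps both are passed as None. Below is a minimal standalone sketch of that control flow with plain tensors: `prepare_decoder_inputs` is an invented helper name, the shapes and the `sliding_window` default are illustrative, and the real code keeps this logic inside `qwen2_fused_model_forward` with `DynamicFusedNormalCache` providing the past length.

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask


def prepare_decoder_inputs(inputs_embeds, past_kv_len, sliding_window=None):
    """Hypothetical helper mirroring the new seq_length > 1 branch of the patch."""
    batch_size, seq_length, _ = inputs_embeds.shape
    if seq_length > 1:
        # Prefill: build explicit position ids and a 4D causal mask on the host,
        # as the patched forward now does inside its if-branch.
        position_ids = torch.arange(
            past_kv_len,
            seq_length + past_kv_len,
            dtype=torch.long,
            device=inputs_embeds.device,
        ).unsqueeze(0).view(-1, seq_length)
        attention_mask = _prepare_4d_causal_attention_mask(
            None,  # no padding mask in this toy example
            (batch_size, seq_length),
            inputs_embeds,
            past_kv_len,
            sliding_window=sliding_window,
        )
    else:
        # Decode: the commit passes None for both, presumably because the fused
        # NPU decode graph derives mask/positions from the KV length itself.
        attention_mask = None
        position_ids = None
    return attention_mask, position_ids


# Prefill with 8 tokens, then a single-token decode step.
embeds = torch.randn(1, 8, 32)
mask, pos = prepare_decoder_inputs(embeds, past_kv_len=0)
print(mask.shape, pos.shape)  # torch.Size([1, 1, 8, 8]) torch.Size([1, 8])

step = torch.randn(1, 1, 32)
mask, pos = prepare_decoder_inputs(step, past_kv_len=8)
print(mask, pos)  # None None
```

Skipping mask and position construction on decode removes host-side work for every generated token; passing None is safe here only if the fused single-token path no longer consumes these tensors, which is what this commit appears to rely on.
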
