diff --git a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py index 611270f8b12..6f959aead87 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/mp_models_base.py @@ -399,22 +399,22 @@ def set_weights_async(self, op_id, weights): self.setWeights(offset, op_id, *weights) @staticmethod - def run_decoders(inputs, decoders): + def run_decoders(inputs, decoders, models_ptr=None): x_np = [elem.to(torch.float16).numpy() for elem in inputs] num_decoders = len(decoders) num_inputs = len(x_np) - with record_function(f"npu_factory"): - + if models_ptr is None: array_type = ctypes.POINTER(ctypes.c_char) * num_decoders models_ptr = array_type( *[decoders[i]._mm for i in range(num_decoders)] ) - inputs_ptr = (ctypes.c_void_p * num_inputs)( - *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np] - ) - backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs) + + inputs_ptr = (ctypes.c_void_p * num_inputs)( + *[x.ctypes.data_as(ctypes.c_void_p) for x in x_np] + ) + backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs) hidden_states = decoders[-1].torch_out[0] new_key_states = [] diff --git a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py index 6266e5c6bb2..ab4c19481bc 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py @@ -17,7 +17,7 @@ import os import torch import time - +import ctypes from typing import Optional, Sequence, List, Union, Any, Tuple import numpy as np @@ -379,6 +379,9 @@ def __init__( self.backend_decoders[i].set_weights(self.op_id, curr_parameters) offset = offset + curr_linear_ops + array_type = ctypes.POINTER(ctypes.c_char) * intra_stages + self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)]) + def forward( self, hidden_states: torch.Tensor, @@ -402,7 +405,8 @@ def forward( hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders( inputs, - decoders=self.backend_decoders) + self.backend_decoders, + self.models_ptr) if self.do_print: print("outputs:", hidden_states)