Skip to content

Commit

Permalink
optimize npu qwen2 (intel-analytics#12107)
Browse files Browse the repository at this point in the history
  • Loading branch information
rnwang04 authored Sep 20, 2024
1 parent 0239902 commit 03bd01c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -399,22 +399,22 @@ def set_weights_async(self, op_id, weights):
self.setWeights(offset, op_id, *weights)

@staticmethod
def run_decoders(inputs, decoders):
def run_decoders(inputs, decoders, models_ptr=None):
x_np = [elem.to(torch.float16).numpy() for elem in inputs]

num_decoders = len(decoders)
num_inputs = len(x_np)

with record_function(f"npu_factory"):

if models_ptr is None:
array_type = ctypes.POINTER(ctypes.c_char) * num_decoders
models_ptr = array_type(
*[decoders[i]._mm for i in range(num_decoders)]
)
inputs_ptr = (ctypes.c_void_p * num_inputs)(
*[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
)
backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)

inputs_ptr = (ctypes.c_void_p * num_inputs)(
*[x.ctypes.data_as(ctypes.c_void_p) for x in x_np]
)
backend_lib.run_decoders(models_ptr, inputs_ptr, num_decoders, num_inputs)

hidden_states = decoders[-1].torch_out[0]
new_key_states = []
Expand Down
8 changes: 6 additions & 2 deletions python/llm/src/ipex_llm/transformers/npu_models/qwen2_mp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import os
import torch
import time

import ctypes
from typing import Optional, Sequence, List, Union, Any, Tuple
import numpy as np

Expand Down Expand Up @@ -379,6 +379,9 @@ def __init__(
self.backend_decoders[i].set_weights(self.op_id, curr_parameters)
offset = offset + curr_linear_ops

array_type = ctypes.POINTER(ctypes.c_char) * intra_stages
self.models_ptr = array_type(*[self.backend_decoders[i]._mm for i in range(intra_stages)])

def forward(
self,
hidden_states: torch.Tensor,
Expand All @@ -402,7 +405,8 @@ def forward(

hidden_states, new_keys, new_values = LowBitQwenMultiDecoderlayer.run_decoders(
inputs,
decoders=self.backend_decoders)
self.backend_decoders,
self.models_ptr)

if self.do_print:
print("outputs:", hidden_states)
Expand Down

0 comments on commit 03bd01c

Please sign in to comment.