Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump MNN KleidiAI ukernel to qai8_qsi4_sme2 ukernel #3101

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions source/backend/cpu/CPURuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
// ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70
#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000)
#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002)

// Feature masks for the second hwcap word (the value of getauxval(AT_HWCAP2)).
// The SME2 mask is above bit 31, so it only matches when tested against a 64-bit value.
#define CPUINFO_ARM_LINUX_FEATURE2_SVE2 UINT32_C(0x00000002)
#define CPUINFO_ARM_LINUX_FEATURE2_SME2 UINT64_C(0x0000002000000000)
#endif

#include <algorithm>
Expand Down Expand Up @@ -1279,13 +1281,18 @@ static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) {
if (have_feature("hw.optional.arm.FEAT_I8MM")) {
cpuinfo_isa->i8mm = true;
}
if (have_feature("hw.optional.arm.FEAT_SME2")) {
cpuinfo_isa->sme2 = true;
}
}
#endif

#if defined(__linux__) && defined(__aarch64__)
static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
// Use AUX to get info for linux-aarch64
uint32_t isa_features = 0;
uint64_t isa_features2 = 0;

isa_features = (uint32_t)getauxval(AT_HWCAP);
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) {
cpuinfo_isa->dot = true;
Expand All @@ -1297,10 +1304,14 @@ static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) {
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) {
cpuinfo_isa->i8mm = true;
}
isa_features = (uint32_t)getauxval(AT_HWCAP2);
if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) {

// AT_HWCAP2 is a separate 64-bit word: test the FEATURE2_* masks against
// isa_features2, not against isa_features (which holds the AT_HWCAP bits and is
// only 32 bits wide, so the SME2 mask could never match it).
isa_features2 = (uint64_t)getauxval(AT_HWCAP2);
if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SVE2) {
    cpuinfo_isa->sve2 = true;
}
if (isa_features2 & CPUINFO_ARM_LINUX_FEATURE2_SME2) {
    cpuinfo_isa->sme2 = true;
}
}
#endif

Expand Down Expand Up @@ -1351,6 +1362,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
cpuinfo_isa->fp16arith = false;
cpuinfo_isa->i8mm = false;
cpuinfo_isa->sve2 = false;
cpuinfo_isa->sme2 = false;
// android
/**Get CPU Info*/
#ifdef __linux__
Expand Down Expand Up @@ -1447,6 +1459,7 @@ static void _fillInfo(MNNCPUInfo* cpuinfo_isa) {
cpuinfo_isa->dot = true;
#endif

MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2);
MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d, sme2: %d\n",
cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2, cpuinfo_isa->sme2);
return;
}
1 change: 1 addition & 0 deletions source/backend/cpu/CPURuntime.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct MNNCPUInfo {
bool dot;
bool i8mm;
bool sve2;
bool sme2;
std::vector<CPUGroup> groups;
int cpuNumber = 0;
};
Expand Down
10 changes: 10 additions & 0 deletions source/backend/cpu/arm/kleidiAI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS
endif()

# MNN-side KleidiAI wrapper: implementation files first, then their headers.
list(APPEND MNN_KleidiAI_SOURCES
    ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp
    ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.cpp
)
list(APPEND MNN_KleidiAI_HEADERS
    ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h
    ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai_util.h
)

add_library(
MNN_KleidiAI
Expand All @@ -41,6 +43,7 @@ include_directories(
# Sources that are added to MNN_KleidiAI on every platform. Judging by the file
# names these are the LHS quantize/pack routine and the RHS packing routines.
set(KLEIDIAI_FILES_SCALAR
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxpo_qsu4cxs1s0.c
)

set(KLEIDIAI_FILES_NEON_DOTPROD
Expand All @@ -51,13 +54,20 @@ set(KLEIDIAI_FILES_NEON_I8MM
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
)

# SME2 matmul micro-kernel sources. Despite the _MOPA suffix, this list holds
# both the *_sme2_mopa and the *_sme2_sdot kernel files.
set(KLEIDIAI_FILES_SME2_MOPA
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxpo4vlx4_1x4vl_sme2_sdot.c
)

# The scalar sources are always part of the target. The ISA-specific kernel
# groups are added only for non-MSVC aarch64/arm64 builds, where every group
# is compiled with its own -march option.
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR})
if(NOT MSVC AND (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64"))
    target_sources(MNN_KleidiAI PRIVATE
        ${KLEIDIAI_FILES_NEON_DOTPROD}
        ${KLEIDIAI_FILES_NEON_I8MM}
        ${KLEIDIAI_FILES_SME2_MOPA}
    )

    set_property(SOURCE ${KLEIDIAI_FILES_SCALAR} PROPERTY COMPILE_OPTIONS -march=armv8-a)
    set_property(SOURCE ${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+dotprod)
    set_property(SOURCE ${KLEIDIAI_FILES_NEON_I8MM} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+i8mm)
    set_property(SOURCE ${KLEIDIAI_FILES_SME2_MOPA} PROPERTY COMPILE_OPTIONS -march=armv8.2-a+sve2)
endif()
67 changes: 24 additions & 43 deletions source/backend/cpu/arm/kleidiAI/kai/kai_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,21 @@ inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
/// @param[in] f16 The f16 value
///
/// @return the f32 value
inline static float kai_cast_f32_f16(uint16_t f16) {
#if defined(__ARM_NEON)
inline static float kai_cast_f32_f16(uint16_t f16) {
__fp16 f32 = 0;
memcpy(&f32, &f16, sizeof(uint16_t));
return (float)f32;
#endif
}
#endif

/// Converts a scalar bf16 value to f32
/// @param[in] bf16 The bf16 value
///
/// @return the f32 value
inline static float kai_cast_f32_bf16(uint16_t bf16) {
    // The 16 input bits become the upper half of the float's 32-bit pattern.
    // Widen to uint32_t before shifting: a bare `bf16 << 16` is evaluated on a
    // promoted signed int and overflows when the sign bit (bit 15) is set.
    const uint32_t i32 = ((uint32_t)bf16 << 16);
    float f32 = 0;
    memcpy(&f32, &i32, sizeof(i32));
    return f32;
}
Expand All @@ -116,79 +116,60 @@ inline static uint16_t kai_cast_bf16_f32(float f32) {
/// @param[in] f32 The f32 value
///
/// @return the f16 value
inline static uint16_t kai_cast_f16_f32(float f32) {
#if defined(__ARM_NEON)
inline static uint16_t kai_cast_f16_f32(float f32) {
uint16_t f16 = 0;
__fp16 tmp = f32;
__fp16 tmp = (__fp16)f32;
memcpy(&f16, &tmp, sizeof(uint16_t));
return f16;
#endif
}
#endif

/// Rounds a up to the next multiple of b
/// @param[in] a The value to round up
/// @param[in] b The granularity; used as a divisor, so it must be non-zero
///
/// @return the smallest multiple of b that is not less than a
inline static size_t kai_roundup(size_t a, size_t b) {
    const size_t num_blocks = (a + b - 1) / b;
    return num_blocks * b;
}

#ifdef __ARM_FEATURE_SVE

#ifdef __ARM_FEATURE_SVE2
/// Gets the SME vector length for 8-bit elements.
inline static uint64_t kai_get_sme_vector_length_u8(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntb %0\n"
".inst 0xd503467f // SMSTOP\n"
".inst 0x04bf5827 // rdsvl x7, #1\n"
"mov %0, x7\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

: /* no inputs */
: "x7");
return res;
}

/// Gets the SME vector length for 16-bit elements.
inline static uint64_t kai_get_sme_vector_length_u16(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cnth %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
return kai_get_sme_vector_length_u8() / 2;
}

/// Gets the SME vector length for 32-bit elements.
inline static uint64_t kai_get_sme_vector_length_u32(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntw %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
return kai_get_sme_vector_length_u8() / 4;
}

#endif // __ARM_FEATURE_SVE
#endif // __ARM_FEATURE_SVE2

/// Extends the sign bit of int 4-bit value (stored in int8_t variable)
/// @param[in] value The 4-bit int value
///
/// @return the int8_t value with sign extended
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
return (value ^ 0x8) - 8;
// Make sure value holds correct int4 value
KAI_ASSERT(value <= 0xF);

return (value ^ 0x8) - 8; // NOLINT(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
}

/// Parameter struct for RHS matrix packing.
/// Carries one quantization zero-point per matmul operand: a signed 8-bit value
/// for the LHS and an unsigned 8-bit value for the RHS.
struct kai_rhs_pack_qs4cxs1s0_param {
int8_t lhs_zero_point; /**< LHS Matrix quantization zero-point */
uint8_t rhs_zero_point; /**< RHS Matrix quantization zero-point */
};

#ifdef __cplusplus
}
#endif
Loading