diff --git a/CMakeLists.txt b/CMakeLists.txt index c5add8239c2bd..d5d61a7d25ca2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ option(LLAMA_AVX512 "llama: enable AVX512" option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF) +option(LLAMA_AMX "llama: enable AMX" OFF) option(LLAMA_FMA "llama: enable FMA" ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512 if (NOT MSVC) @@ -1072,6 +1073,14 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW add_compile_definitions($<$:__AVX512BF16__>) add_compile_definitions($<$:__AVX512BF16__>) endif() + if (LLAMA_AMX) + add_compile_definitions($<$:__AMX_TILE__>) + add_compile_definitions($<$:__AMX_TILE__>) + add_compile_definitions($<$:__AMX_INT8__>) + add_compile_definitions($<$:__AMX_INT8__>) + add_compile_definitions($<$:__AMX_BF16__>) + add_compile_definitions($<$:__AMX_BF16__>) + endif() elseif (LLAMA_AVX2) list(APPEND ARCH_FLAGS /arch:AVX2) elseif (LLAMA_AVX) @@ -1106,6 +1115,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW if (LLAMA_AVX512_BF16) list(APPEND ARCH_FLAGS -mavx512bf16) endif() + if (LLAMA_AMX) + list(APPEND ARCH_FLAGS -mavx512vl -mavx512dq) + list(APPEND ARCH_FLAGS -mamx-tile -mamx-int8 -mamx-bf16) + endif() endif() elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64") message(STATUS "PowerPC detected") diff --git a/ggml.c b/ggml.c index 62d050e2b5b4b..1175db6a0fd44 100644 --- a/ggml.c +++ b/ggml.c @@ -27,9 +27,7 @@ #if defined(__gnu_linux__) #include #if defined(__AMX_TILE__) && defined(__AMX_BF16__) -#define ARCH_GET_XCOMP_PERM 0x1022 #define ARCH_REQ_XCOMP_PERM 0x1023 -#define XFEATURE_XTILECFG 17 #define XFEATURE_XTILEDATA 18 #endif #endif @@ -1904,24 +1902,40 @@ static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void } } } + +typedef struct __tile_config +{ + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[8]; + uint8_t reserved_1[16]; + uint8_t rows[8]; + uint8_t reserved_2[8]; +} __tile_config_t; #endif static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) { #if defined(__AMX_TILE__) && defined(__AMX_BF16__) if (nrc == AMX_TILE_MN) { assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0); - __tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4}; - __tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4}; - __tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)}; - __tile_zero(&tilezt); + // 0: zt, 1: yt, 2: xt + __tile_config_t cfg = { + .palette_id = 1, + .start_row = 0, + .colsb = {AMX_TILE_MN*sizeof(float), AMX_TILE_K*4, AMX_TILE_MN*4, 0,}, + .rows = {AMX_TILE_MN, AMX_TILE_K, AMX_TILE_MN, 0,}, + }; + _tile_loadconfig(&cfg); + _tile_zero(0); for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) { ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)]; ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K); - __tile_loadd(&tileyt, y + i, by); - __tile_loadd(&tilext, axt, AMX_TILE_MN*4); - __tile_dpbf16ps(&tilezt, tileyt, tilext); + _tile_loadd(1, y + i, by); + _tile_loadd(2, axt, AMX_TILE_MN*4); + _tile_dpbf16ps(0, 1, 2); } - __tile_stored(s, bs*sizeof(float), tilezt); + _tile_stored(0, s, bs*sizeof(float)); return; } #endif @@ -19485,10 +19499,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); -#if defined(__gnu_linux__) -#if defined(__AMX_TILE__) && defined(__AMX_BF16__) +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__gnu_linux__) + // refer to https://www.kernel.org/doc/Documentation/x86/xstate.rst syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); -#endif #endif int node_n = -1; @@ -23045,4 +23058,12 @@ int ggml_cpu_has_matmul_int8(void) { #endif } +int ggml_cpu_has_amx(void) { +#if defined(__AMX_TILE__) && defined(__AMX_BF16__) + return 1; +#else + return 0; +#endif +} + //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml.h b/ggml.h index f803ba7241fe1..72517142c6245 100644 --- a/ggml.h +++ b/ggml.h @@ -2421,6 +2421,7 @@ extern "C" { GGML_API int ggml_cpu_has_sycl (void); GGML_API int ggml_cpu_has_vsx (void); GGML_API int ggml_cpu_has_matmul_int8(void); + GGML_API int ggml_cpu_has_amx (void); // // Internal types and functions exposed for tests and benchmarks diff --git a/llama.cpp b/llama.cpp index f67cb7e232945..d3ef5238ec3a0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -18355,6 +18355,7 @@ const char * llama_print_system_info(void) { s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | "; + s += "AMX = " + std::to_string(ggml_cpu_has_amx()) + " | "; #ifdef GGML_USE_LLAMAFILE s += "LLAMAFILE = 1 | "; #else