Skip to content

Commit

Permalink
better toolchain compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ReinForce-II committed May 23, 2024
1 parent d13da05 commit ba1987f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 13 deletions.
13 changes: 13 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ option(LLAMA_AVX512 "llama: enable AVX512"
# Optional x86 instruction-set extensions. The AVX512 sub-features and AMX are
# opt-in (default OFF); FMA takes its default from INS_ENB, which is set elsewhere.
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
# LLAMA_AMX gates the two AMX blocks in the x86 section of this file: one defines
# __AMX_TILE__/__AMX_INT8__/__AMX_BF16__ for C and C++, the other appends
# -mamx-tile -mamx-int8 -mamx-bf16 (plus -mavx512vl -mavx512dq) to ARCH_FLAGS.
option(LLAMA_AMX "llama: enable AMX" OFF)
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
Expand Down Expand Up @@ -1073,6 +1074,14 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif()
if (LLAMA_AMX)
    # Define the AMX feature macros by hand for C and C++ translation units.
    # NOTE(review): this sits in the branch that uses /arch: flags, so presumably
    # the compiler does not predefine these macros itself - confirm.
    add_compile_definitions(
        $<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>
        $<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>
        $<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>
        $<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>
        $<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>
        $<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>
    )
endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX)
Expand Down Expand Up @@ -1107,6 +1116,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
endif()
if (LLAMA_AMX)
    # AMX builds enable AVX512-VL/DQ alongside the three AMX feature flags;
    # the flags are appended in the same order as before.
    list(APPEND ARCH_FLAGS
        -mavx512vl -mavx512dq
        -mamx-tile -mamx-int8 -mamx-bf16
    )
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected")
Expand Down
47 changes: 34 additions & 13 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@
#if defined(__gnu_linux__)
#include <syscall.h>
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
#define ARCH_GET_XCOMP_PERM 0x1022
#define ARCH_REQ_XCOMP_PERM 0x1023
#define XFEATURE_XTILECFG 17
#define XFEATURE_XTILEDATA 18
#endif
#endif
Expand Down Expand Up @@ -1912,24 +1910,40 @@ static void ggml_transpose_pack4(void * restrict d, const size_t bd, const void
}
}
}

// Tile configuration block that ggml_vec_dot_bf16 fills in and passes to
// _tile_loadconfig(). The members add up to exactly 64 bytes with no padding
// (1 + 1 + 14 + 8*2 + 16 + 8 + 8), so field order and sizes must not change.
// NOTE(review): the layout presumably mirrors the hardware tile-config memory
// format - confirm against the Intel documentation before editing.
typedef struct __tile_config
{
uint8_t palette_id; // the caller in this file sets this to 1
uint8_t start_row; // the caller in this file sets this to 0
uint8_t reserved_0[14]; // left zero by the designated initializer
uint16_t colsb[8]; // per-tile value, indexed by tile number (0: zt, 1: yt, 2: xt in the caller); presumably bytes per row
uint8_t reserved_1[16]; // left zero; presumably the colsb slots of tiles this struct does not expose - TODO confirm
uint8_t rows[8]; // per-tile row count, indexed by tile number like colsb
uint8_t reserved_2[8]; // left zero; presumably the rows slots of tiles this struct does not expose - TODO confirm
} __tile_config_t;
#endif

static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc) {
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
if (nrc == AMX_TILE_MN) {
assert(n % (AMX_TILE_K*4/sizeof(ggml_bf16_t)) == 0);
__tile1024i tileyt = {AMX_TILE_MN, AMX_TILE_K*4};
__tile1024i tilext = {AMX_TILE_K, AMX_TILE_MN*4};
__tile1024i tilezt = {AMX_TILE_MN, AMX_TILE_MN*sizeof(float)};
__tile_zero(&tilezt);
// 0: zt, 1: yt, 2: xt
__tile_config_t cfg = {
.palette_id = 1,
.start_row = 0,
.colsb = {AMX_TILE_MN*sizeof(float), AMX_TILE_K*4, AMX_TILE_MN*4, 0,},
.rows = {AMX_TILE_MN, AMX_TILE_K, AMX_TILE_MN, 0,},
};
_tile_loadconfig(&cfg);
_tile_zero(0);
for (int i = 0; i < n; i+=AMX_TILE_K*4/sizeof(ggml_bf16_t)) {
ggml_bf16_t axt[AMX_TILE_K*AMX_TILE_MN*4/sizeof(ggml_bf16_t)];
ggml_transpose_pack4(axt, AMX_TILE_MN*4, x + i, bx, AMX_TILE_MN, AMX_TILE_K);
__tile_loadd(&tileyt, y + i, by);
__tile_loadd(&tilext, axt, AMX_TILE_MN*4);
__tile_dpbf16ps(&tilezt, tileyt, tilext);
_tile_loadd(1, y + i, by);
_tile_loadd(2, axt, AMX_TILE_MN*4);
_tile_dpbf16ps(0, 1, 2);
}
__tile_stored(s, bs*sizeof(float), tilezt);
_tile_stored(0, s, bs*sizeof(float));
return;
}
#endif
Expand Down Expand Up @@ -20133,10 +20147,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {

set_numa_thread_affinity(state->ith);

#if defined(__gnu_linux__)
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
#if defined(__AMX_TILE__) && defined(__AMX_BF16__) && defined(__gnu_linux__)
// refer to https://www.kernel.org/doc/Documentation/x86/xstate.rst
syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA);
#endif
#endif

int node_n = -1;
Expand Down Expand Up @@ -23715,4 +23728,12 @@ int ggml_cpu_has_matmul_int8(void) {
#endif
}

// Reports whether this translation unit was compiled with both __AMX_TILE__
// and __AMX_BF16__ defined. This is a compile-time answer: returns 1 when both
// macros are present, 0 otherwise.
int ggml_cpu_has_amx(void) {
    int amx_compiled_in = 0;
#if defined(__AMX_TILE__) && defined(__AMX_BF16__)
    amx_compiled_in = 1;
#endif
    return amx_compiled_in;
}

////////////////////////////////////////////////////////////////////////////////
1 change: 1 addition & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2436,6 +2436,7 @@ extern "C" {
GGML_API int ggml_cpu_has_sycl (void);
GGML_API int ggml_cpu_has_vsx (void);
GGML_API int ggml_cpu_has_matmul_int8(void);
GGML_API int ggml_cpu_has_amx (void);

//
// Internal types and functions exposed for tests and benchmarks
Expand Down
1 change: 1 addition & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17891,6 +17891,7 @@ const char * llama_print_system_info(void) {
s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "MATMUL_INT8 = " + std::to_string(ggml_cpu_has_matmul_int8()) + " | ";
s += "AMX = " + std::to_string(ggml_cpu_has_amx()) + " | ";
#ifdef GGML_USE_LLAMAFILE
s += "LLAMAFILE = 1 | ";
#else
Expand Down

0 comments on commit ba1987f

Please sign in to comment.