Skip to content

Commit

Permalink
Add provisions for windows support for BF16 code including CMake prov…
Browse files Browse the repository at this point in the history
…ision for enabling AVX512_BF16 (ggerganov#7258)
  • Loading branch information
Srihari-mcw authored May 20, 2024
1 parent d359f30 commit 33c8d50
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 8 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ option(LLAMA_AVX2 "llama: enable AVX2"
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
Expand Down Expand Up @@ -1060,6 +1061,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
if (LLAMA_AVX512_BF16)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX)
Expand Down Expand Up @@ -1091,6 +1096,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_VNNI)
list(APPEND ARCH_FLAGS -mavx512vnni)
endif()
if (LLAMA_AVX512_BF16)
list(APPEND ARCH_FLAGS -mavx512bf16)
endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected")
Expand Down
12 changes: 12 additions & 0 deletions ggml-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

#if defined(_WIN32)

#define m512bh(p) p
#define m512i(p) p

#else

#define m512bh(p) (__m512bh)(p)
#define m512i(p) (__m512i)(p)

#endif

/**
* Converts brain16 to float32.
*
Expand Down
24 changes: 16 additions & 8 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
int i = 0;
#if defined(__AVX512BF16__)
for (; i + 32 <= n; i += 32) {
_mm512_storeu_ps(
(__m512 *)(y + i),
(__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
_mm512_loadu_ps(x + i)));
_mm512_storeu_si512(
(__m512i *)(y + i),
m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
_mm512_loadu_ps(x + i))));
}
#endif
for (; i < n; i++) {
Expand Down Expand Up @@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
__m512 c1 = _mm512_setzero_ps();
__m512 c2 = _mm512_setzero_ps();
for (; i + 64 <= n; i += 64) {
c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
(__m512bh)_mm512_loadu_ps((const float *)(y + i)));
c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
(__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
m512bh(_mm512_loadu_si512((y + i))));
c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
m512bh(_mm512_loadu_si512((y + i + 32))));
}
sumf += (ggml_float)_mm512_reduce_add_ps(c1);
sumf += (ggml_float)_mm512_reduce_add_ps(c2);
Expand Down Expand Up @@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) {
#endif
}

int ggml_cpu_has_avx512_bf16(void) {
#if defined(__AVX512BF16__)
return 1;
#else
return 0;
#endif
}

int ggml_cpu_has_fma(void) {
#if defined(__FMA__)
return 1;
Expand Down
1 change: 1 addition & 0 deletions ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2390,6 +2390,7 @@ extern "C" {
GGML_API int ggml_cpu_has_avx512 (void);
GGML_API int ggml_cpu_has_avx512_vbmi(void);
GGML_API int ggml_cpu_has_avx512_vnni(void);
GGML_API int ggml_cpu_has_avx512_bf16(void);
GGML_API int ggml_cpu_has_fma (void);
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_arm_fma (void);
Expand Down
1 change: 1 addition & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18074,6 +18074,7 @@ const char * llama_print_system_info(void) {
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
Expand Down

0 comments on commit 33c8d50

Please sign in to comment.