From 192090bae47960f0d38d4967abe398a5d190057e Mon Sep 17 00:00:00 2001 From: Justine Tunney Date: Mon, 22 Apr 2024 15:00:36 -0400 Subject: [PATCH] llamafile : improve sgemm.cpp (#6796) * llamafile : improve sgemm.cpp - Re-enable by default - Fix issue described in #6716 - Make code more abstract, elegant, and maintainable - Faster handling of weirdly shaped `m` and `n` edge cases * Address review comments * Help clang produce fma instructions * Address review comments --- CMakeLists.txt | 16 +- Makefile | 4 - ggml.c | 8 +- sgemm.cpp | 957 +++++++++++++++++++++---------- 4 files changed, 412 insertions(+), 573 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f134a153bb4ff..58a1805ba10fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,17 +43,11 @@ else() set(LLAMA_METAL_DEFAULT OFF) endif() -# TODO: fix this for Android CI -# https://github.com/ggerganov/llama.cpp/pull/6716#issuecomment-2061509191 -#if (CMAKE_SYSTEM_NAME MATCHES "ANDROID") -# set(LLAMA_LLAMAFILE_DEFAULT OFF) -#else() -# set(LLAMA_LLAMAFILE_DEFAULT ON) -#endif() - -# TODO: temporary disable until MoE is fixed -# https://github.com/ggerganov/llama.cpp/pull/6716 -set(LLAMA_LLAMAFILE_DEFAULT OFF) +if (CMAKE_SYSTEM_NAME MATCHES "ANDROID") + set(LLAMA_LLAMAFILE_DEFAULT OFF) +else() + set(LLAMA_LLAMAFILE_DEFAULT ON) +endif() # general option(BUILD_SHARED_LIBS "build shared libraries" OFF) diff --git a/Makefile b/Makefile index b0b2ea997ee12..24acb80136516 100644 --- a/Makefile +++ b/Makefile @@ -384,10 +384,6 @@ ifdef LLAMA_OPENBLAS MK_LDFLAGS += $(shell pkg-config --libs openblas) endif # LLAMA_OPENBLAS -# TODO: temporary disable until MoE is fixed -# https://github.com/ggerganov/llama.cpp/pull/6716 -LLAMA_NO_LLAMAFILE := 1 - ifndef LLAMA_NO_LLAMAFILE MK_CPPFLAGS += -DGGML_USE_LLAMAFILE OBJS += sgemm.o diff --git a/ggml.c b/ggml.c index a3b312e4aef59..086db96af7fcd 100644 --- a/ggml.c +++ b/ggml.c @@ -10825,7 +10825,7 @@ static void ggml_compute_forward_mul_mat( #endif #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type)) { + if (src1_cont) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), @@ -10878,15 +10878,13 @@ UseGgmlGemm1:; const size_t row_size = ggml_row_size(vec_dot_type, ne10); #if GGML_USE_LLAMAFILE - if (nb10 == ggml_type_size(src1->type) || src1->type != vec_dot_type) { + if (src1->type != vec_dot_type) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), - (const char *)wdata + ggml_row_size(vec_dot_type, - nb12/ggml_type_size(src1->type)*i12 + - nb13/ggml_type_size(src1->type)*i13), + (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/ggml_type_size(dst->type), diff --git a/sgemm.cpp b/sgemm.cpp index 6900f04cfb242..531e12af361cc 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -50,6 +50,7 @@ #pragma GCC diagnostic ignored "-Wignored-attributes" #include "sgemm.h" +#include <algorithm> #include "ggml-impl.h" #include "ggml-quants.h" @@ -65,22 +66,6 @@ #define VECTOR_REGISTERS 16 #endif -// there will be blocks -#define BEGIN_KERNEL(RM, RN) \ - int ytiles = (m - m0) / RM; \ - int xtiles = (n - n0) / RN; \ - int tiles = ytiles * xtiles; \ - int duty = (tiles + nth - 1) / nth; \ - int start = duty * ith; \ - int end =
start + duty; \ - if (end > tiles) - end = tiles; - for (int job = start; job < end; ++job) { \ - int i = m0 + job / xtiles * RM; \ - int j = n0 + job % xtiles * RN; - -#define END_KERNEL() } - #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) namespace { @@ -122,6 +107,45 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +//////////////////////////////////////////////////////////////////////////////////////////////////// +// VECTORIZED FUSED MULTIPLY ADD + +/** + * Computes a * b + c. + */ +template <typename T, typename U> +inline U madd(T a, T b, U c) { + return add(mul(a, b), c); +} + +#if defined(__FMA__) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +template <> +inline __m256 madd(__m256 a, __m256 b, __m256 c) { + return _mm256_fmadd_ps(a, b, c); +} +#endif +#if defined(__AVX512F__) +template <> +inline __m512 madd(__m512 a, __m512 b, __m512 c) { + return _mm512_fmadd_ps(a, b, c); +} +#endif +#endif + +#if defined(__ARM_FEATURE_FMA) +template <> +inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { + return vfmaq_f32(c, b, a); +} +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) +template <> +inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { + return vfmaq_f16(c, b, a); +} +#endif +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM @@ -213,36 +237,6 @@ template <> inline __m512 load(const ggml_fp16_t *p) { } #endif // __AVX512F__ -//////////////////////////////////////////////////////////////////////////////////////////////////// -// ABSTRACTIONS - -/** - * Computes a * b + c. - * - * This operation will become fused into a single arithmetic instruction - * if the hardware has support for this feature, e.g. Intel Haswell+ (c. - * 2013), AMD Bulldozer+ (c. 2011), etc. - */ -template <typename T, typename U> -inline U madd(T a, T b, U c) { - return add(mul(a, b), c); -} - -/** - * Computes a * b + c with error correction. - * - * @see W. Kahan, "Further remarks on reducing truncation errors," - * Communications of the ACM, vol. 8, no. 1, p. 40, Jan. 1965, - * doi: 10.1145/363707.363723.
- */ -template -inline U madder(T a, T b, U c, U *e) { - U y = sub(mul(a, b), *e); - U t = add(c, y); - *e = sub(sub(t, c), y); - return t; -} - //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION @@ -265,226 +259,179 @@ class tinyBLAS { private: NOINLINE void mnpack(int m0, int m, int n0, int n) { int mc, nc, mp, np; - if (m - m0 <= 0 || n - n0 <= 0) - return; - if (VECTOR_REGISTERS >= 32 && n - n0 >= 5 && m - m0 >= 5) { + switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) { +#if VECTOR_REGISTERS == 32 + case 0x55: + mc = 5; + nc = 5; + gemm<5, 5>(m0, m, n0, n); + break; + case 0x45: + mc = 4; + nc = 5; + gemm<4, 5>(m0, m, n0, n); + break; + case 0x54: mc = 5; + nc = 4; + gemm<5, 4>(m0, m, n0, n); + break; + case 0x44: + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); + break; + case 0x53: + mc = 5; + nc = 3; + gemm<5, 3>(m0, m, n0, n); + break; + case 0x35: + mc = 3; nc = 5; - gemm5x5(m0, m, n0, n); - } else if (n - n0 >= 4 && m - m0 >= 3) { + gemm<3, 5>(m0, m, n0, n); + break; + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; +#else + case 0x55: + case 0x54: + case 0x53: + case 0x45: + case 0x44: + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; + case 0x35: +#endif + case 0x34: mc = 3; nc = 4; - gemm3x4(m0, m, n0, n); - } else if (n - n0 >= 4) { - mc = 1; + gemm<3, 4>(m0, m, n0, n); + break; + case 0x52: + mc = 5; + nc = 2; + gemm<5, 2>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x25: + mc = 2; + nc = 5; + gemm<2, 5>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x24: + mc = 2; nc = 4; - gemm1x4(m0, m, n0, n); - } else if (m - m0 >= 4) { + gemm<2, 4>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x51: + mc = 5; + nc = 1; + gemm<5, 1>(m0, m, n0, n); + break; + case 0x41: mc = 4; nc = 1; - gemm4x1(m0, m, n0, n); - } else { + gemm<4, 1>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x15: + mc = 1; + nc = 5; + gemm<1, 5>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; + gemm<1, 4>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; nc = 1; - gemm1x1(m0, m, n0, n); + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); - mnpack(m0, mp, np, n); - mnpack(mp, m, np, n); - } - - NOINLINE void gemm5x5(int m0, int m, int n0, int n) { - BEGIN_KERNEL(5, 5) - D c00 = {0}; - D c01 = {0}; - D c02 = {0}; - D c03 = {0}; - D c04 = {0}; - D c10 = {0}; - D c11 = {0}; - D c12 = {0}; - D c13 = {0}; - D c14 = {0}; - D c20 = {0}; - D c21 = {0}; - D c22 = {0}; - D c23 = {0}; - D c24 = {0}; - D c30 = {0}; - D c31 = {0}; - D c32 = {0}; - D c33 = {0}; - D c34 = {0}; - D c40 = {0}; - D c41 = {0}; - D c42 = {0}; - D c43 = {0}; - D c44 = {0}; - for (int l = 0; l < k; l += KN) { - V k0 = load(B + ldb * (j + 0) + l); - V k1 = load(B + ldb * (j + 1) + l); - V k2 = load(B + ldb * (j + 2) + l); - V k3 = load(B + ldb * (j + 3) + l); - V k4 = 
load(B + ldb * (j + 4) + l); - V a0 = load(A + lda * (i + 0) + l); - c00 = madd(a0, k0, c00); - c01 = madd(a0, k1, c01); - c02 = madd(a0, k2, c02); - c03 = madd(a0, k3, c03); - c04 = madd(a0, k4, c04); - V a1 = load(A + lda * (i + 1) + l); - c10 = madd(a1, k0, c10); - c11 = madd(a1, k1, c11); - c12 = madd(a1, k2, c12); - c13 = madd(a1, k3, c13); - c14 = madd(a1, k4, c14); - V a2 = load(A + lda * (i + 2) + l); - c20 = madd(a2, k0, c20); - c21 = madd(a2, k1, c21); - c22 = madd(a2, k2, c22); - c23 = madd(a2, k3, c23); - c24 = madd(a2, k4, c24); - V a3 = load(A + lda * (i + 3) + l); - c30 = madd(a3, k0, c30); - c31 = madd(a3, k1, c31); - c32 = madd(a3, k2, c32); - c33 = madd(a3, k3, c33); - c34 = madd(a3, k4, c34); - V a4 = load(A + lda * (i + 4) + l); - c40 = madd(a4, k0, c40); - c41 = madd(a4, k1, c41); - c42 = madd(a4, k2, c42); - c43 = madd(a4, k3, c43); - c44 = madd(a4, k4, c44); - } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 0) + (i + 1)] = hsum(c10); - C[ldc * (j + 0) + (i + 2)] = hsum(c20); - C[ldc * (j + 0) + (i + 3)] = hsum(c30); - C[ldc * (j + 0) + (i + 4)] = hsum(c40); - C[ldc * (j + 1) + (i + 0)] = hsum(c01); - C[ldc * (j + 1) + (i + 1)] = hsum(c11); - C[ldc * (j + 1) + (i + 2)] = hsum(c21); - C[ldc * (j + 1) + (i + 3)] = hsum(c31); - C[ldc * (j + 1) + (i + 4)] = hsum(c41); - C[ldc * (j + 2) + (i + 0)] = hsum(c02); - C[ldc * (j + 2) + (i + 1)] = hsum(c12); - C[ldc * (j + 2) + (i + 2)] = hsum(c22); - C[ldc * (j + 2) + (i + 3)] = hsum(c32); - C[ldc * (j + 2) + (i + 4)] = hsum(c42); - C[ldc * (j + 3) + (i + 0)] = hsum(c03); - C[ldc * (j + 3) + (i + 1)] = hsum(c13); - C[ldc * (j + 3) + (i + 2)] = hsum(c23); - C[ldc * (j + 3) + (i + 3)] = hsum(c33); - C[ldc * (j + 3) + (i + 4)] = hsum(c43); - C[ldc * (j + 4) + (i + 0)] = hsum(c04); - C[ldc * (j + 4) + (i + 1)] = hsum(c14); - C[ldc * (j + 4) + (i + 2)] = hsum(c24); - C[ldc * (j + 4) + (i + 3)] = hsum(c34); - C[ldc * (j + 4) + (i + 4)] = hsum(c44); - END_KERNEL() - } - - NOINLINE void gemm3x4(int m0, int m, int n0, int n) { - BEGIN_KERNEL(3, 4) - D c00 = {0}; - D c01 = {0}; - D c02 = {0}; - D c03 = {0}; - D c10 = {0}; - D c11 = {0}; - D c12 = {0}; - D c13 = {0}; - D c20 = {0}; - D c21 = {0}; - D c22 = {0}; - D c23 = {0}; - for (int l = 0; l < k; l += KN) { - V k0 = load(B + ldb * (j + 0) + l); - V k1 = load(B + ldb * (j + 1) + l); - V k2 = load(B + ldb * (j + 2) + l); - V k3 = load(B + ldb * (j + 3) + l); - V a0 = load(A + lda * (i + 0) + l); - c00 = madd(a0, k0, c00); - c01 = madd(a0, k1, c01); - c02 = madd(a0, k2, c02); - c03 = madd(a0, k3, c03); - V a1 = load(A + lda * (i + 1) + l); - c10 = madd(a1, k0, c10); - c11 = madd(a1, k1, c11); - c12 = madd(a1, k2, c12); - c13 = madd(a1, k3, c13); - V a2 = load(A + lda * (i + 2) + l); - c20 = madd(a2, k0, c20); - c21 = madd(a2, k1, c21); - c22 = madd(a2, k2, c22); - c23 = madd(a2, k3, c23); - } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 0) + (i + 1)] = hsum(c10); - C[ldc * (j + 0) + (i + 2)] = hsum(c20); - C[ldc * (j + 1) + (i + 0)] = hsum(c01); - C[ldc * (j + 1) + (i + 1)] = hsum(c11); - C[ldc * (j + 1) + (i + 2)] = hsum(c21); - C[ldc * (j + 2) + (i + 0)] = hsum(c02); - C[ldc * (j + 2) + (i + 1)] = hsum(c12); - C[ldc * (j + 2) + (i + 2)] = hsum(c22); - C[ldc * (j + 3) + (i + 0)] = hsum(c03); - C[ldc * (j + 3) + (i + 1)] = hsum(c13); - C[ldc * (j + 3) + (i + 2)] = hsum(c23); - END_KERNEL() - } - - NOINLINE void gemm1x4(int m0, int m, int n0, int n) { - BEGIN_KERNEL(1, 4) - D c00 = {0}, e00 = {0}; - D c01 = {0}, e01 = {0}; - D c02 = {0}, e02 = {0}; - D c03 = {0}, e03 
= {0}; - for (int l = 0; l < k; l += KN) { - V a = load(A + lda * (i + 0) + l); - c00 = madder(a, load(B + ldb * (j + 0) + l), c00, &e00); - c01 = madder(a, load(B + ldb * (j + 1) + l), c01, &e01); - c02 = madder(a, load(B + ldb * (j + 2) + l), c02, &e02); - c03 = madder(a, load(B + ldb * (j + 3) + l), c03, &e03); - } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 1) + (i + 0)] = hsum(c01); - C[ldc * (j + 2) + (i + 0)] = hsum(c02); - C[ldc * (j + 3) + (i + 0)] = hsum(c03); - END_KERNEL() - } - - NOINLINE void gemm4x1(int m0, int m, int n0, int n) { - BEGIN_KERNEL(4, 1) - D c00 = {0}, e00 = {0}; - D c10 = {0}, e10 = {0}; - D c20 = {0}, e20 = {0}; - D c30 = {0}, e30 = {0}; - for (int l = 0; l < k; l += KN) { - V b = load(B + ldb * (j + 0) + l); - c00 = madder(load(A + lda * (i + 0) + l), b, c00, &e00); - c10 = madder(load(A + lda * (i + 1) + l), b, c10, &e10); - c20 = madder(load(A + lda * (i + 2) + l), b, c20, &e20); - c30 = madder(load(A + lda * (i + 3) + l), b, c30, &e30); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int m0, int m, int n0, int n) { + int ytiles = (m - m0) / RM; + int xtiles = (n - n0) / RN; + int tiles = xtiles * ytiles; + int duty = (tiles + nth - 1) / nth; + int start = duty * ith; + int end = start + duty; + if (end > tiles) + end = tiles; + for (int job = start; job < end; ++job) { + int ii = m0 + job / xtiles * RM; + int jj = n0 + job % xtiles * RN; + D Cv[RN][RM] = {}; + for (int l = 0; l < k; l += KN) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = madd(load(A + lda * (ii + i) + l), + load(B + ldb * (jj + j) + l), + Cv[j][i]); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 0) + (i + 1)] = hsum(c10); - C[ldc * (j + 0) + (i + 2)] = hsum(c20); - C[ldc * (j + 0) + (i + 3)] = hsum(c30); - END_KERNEL() - } - - NOINLINE void gemm1x1(int m0, int m, int n0, int n) { - BEGIN_KERNEL(1, 1) - D c = {0}, e = {0}; - for (int l = 0; l < k; l += KN) - c = madder(load(A + lda * i + l), - load(B + ldb * j + l), c, &e); - C[ldc * j + i] = hsum(c); - END_KERNEL() } const TA *const A; @@ -521,120 +468,97 @@ class tinyBLAS_Q0_ARM { private: NOINLINE void mnpack(int m0, int m, int n0, int n) { int mc, nc, mp, np; - if (m - m0 <= 0 || n - n0 <= 0) - return; - if (m - m0 >= 3 && n - n0 >= 3) { + switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) { + case 0x33: mc = 3; nc = 3; - gemm3x3(m0, m, n0, n); - } else { + gemm<3, 3>(m0, m, n0, n); + break; + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: mc = 1; nc = 1; - gemm1x1(m0, m, n0, n); + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); - mnpack(m0, mp, np, n); - mnpack(mp, m, np, n); - } - - NOINLINE void gemm3x3(int m0, int m, int n0, int n) { - BEGIN_KERNEL(3, 3) - int32x4_t zero = vdupq_n_s32(0); - float32x4_t c00 = vdupq_n_f32(0.f); - float32x4_t c01 = vdupq_n_f32(0.f); - float32x4_t c02 = vdupq_n_f32(0.f); - float32x4_t 
c10 = vdupq_n_f32(0.f); - float32x4_t c11 = vdupq_n_f32(0.f); - float32x4_t c12 = vdupq_n_f32(0.f); - float32x4_t c20 = vdupq_n_f32(0.f); - float32x4_t c21 = vdupq_n_f32(0.f); - float32x4_t c22 = vdupq_n_f32(0.f); - const TA *Ap0 = A + lda * (i + 0); - const TA *Ap1 = A + lda * (i + 1); - const TA *Ap2 = A + lda * (i + 2); - const block_q8_0 *Bp0 = B + ldb * (j + 0); - const block_q8_0 *Bp1 = B + ldb * (j + 1); - const block_q8_0 *Bp2 = B + ldb * (j + 2); - for (int l = 0; l < k; ++l) { - c00 = vmlaq_n_f32( - c00, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp0 + l)), - load_hi(Ap0 + l), load_hi(Bp0 + l))), - unhalf(Ap0[l].d) * unhalf(Bp0[l].d)); - c01 = vmlaq_n_f32( - c01, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp1 + l)), - load_hi(Ap0 + l), load_hi(Bp1 + l))), - unhalf(Ap0[l].d) * unhalf(Bp1[l].d)); - c02 = vmlaq_n_f32( - c02, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap0 + l), load_lo(Bp2 + l)), - load_hi(Ap0 + l), load_hi(Bp2 + l))), - unhalf(Ap0[l].d) * unhalf(Bp2[l].d)); - c10 = vmlaq_n_f32( - c10, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp0 + l)), - load_hi(Ap1 + l), load_hi(Bp0 + l))), - unhalf(Ap1[l].d) * unhalf(Bp0[l].d)); - c11 = vmlaq_n_f32( - c11, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp1 + l)), - load_hi(Ap1 + l), load_hi(Bp1 + l))), - unhalf(Ap1[l].d) * unhalf(Bp1[l].d)); - c12 = vmlaq_n_f32( - c12, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap1 + l), load_lo(Bp2 + l)), - load_hi(Ap1 + l), load_hi(Bp2 + l))), - unhalf(Ap1[l].d) * unhalf(Bp2[l].d)); - c20 = vmlaq_n_f32( - c20, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp0 + l)), - load_hi(Ap2 + l), load_hi(Bp0 + l))), - unhalf(Ap2[l].d) * unhalf(Bp0[l].d)); - c21 = vmlaq_n_f32( - c21, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp1 + l)), - load_hi(Ap2 + l), load_hi(Bp1 + l))), - unhalf(Ap2[l].d) * unhalf(Bp1[l].d)); - c22 = vmlaq_n_f32( - c22, - vcvtq_f32_s32(vdotq_s32(vdotq_s32(zero, load_lo(Ap2 + l), load_lo(Bp2 + l)), - load_hi(Ap2 + l), load_hi(Bp2 + l))), - unhalf(Ap2[l].d) * unhalf(Bp2[l].d)); - } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 0) + (i + 1)] = hsum(c10); - C[ldc * (j + 0) + (i + 2)] = hsum(c20); - C[ldc * (j + 1) + (i + 0)] = hsum(c01); - C[ldc * (j + 1) + (i + 1)] = hsum(c11); - C[ldc * (j + 1) + (i + 2)] = hsum(c21); - C[ldc * (j + 2) + (i + 0)] = hsum(c02); - C[ldc * (j + 2) + (i + 1)] = hsum(c12); - C[ldc * (j + 2) + (i + 2)] = hsum(c22); - END_KERNEL() - } - - NOINLINE void gemm1x1(int m0, int m, int n0, int n) { - BEGIN_KERNEL(1, 1) - float32x4_t acc = vdupq_n_f32(0.f); - const TA *Ap = A + lda * i; - const block_q8_0 *Bp = B + ldb * j; - for (int l = 0; l < k; ++l) { - acc = vmlaq_n_f32(acc, - vcvtq_f32_s32(vdotq_s32( - vdotq_s32(vdupq_n_s32(0), load_lo(Ap + l), load_lo(Bp + l)), - load_hi(Ap + l), load_hi(Bp + l))), - unhalf(Ap[l].d) * unhalf(Bp[l].d)); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int m0, int m, int n0, int n) { + int ytiles = (m - m0) / RM; + int xtiles = (n - n0) / RN; + int tiles = xtiles * ytiles; + int duty = (tiles + nth - 1) / nth; + int start = duty * ith; + int end = start + duty; + if (end > tiles) + end = tiles; + for (int job = start; job < end; ++job) { + int ii = m0 + job / xtiles * RM; + int jj = n0 + job % xtiles * RN; + float32x4_t Cv[RN][RM] = {}; + for (int l = 0; l < k; ++l) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = 
vmlaq_n_f32(Cv[j][i], + vcvtq_f32_s32(vdotq_s32( + vdotq_s32(vdupq_n_s32(0), + load_lo(A + lda * (ii + i) + l), + load_lo(B + ldb * (jj + j) + l)), + load_hi(A + lda * (ii + i) + l), + load_hi(B + ldb * (jj + j) + l))), + unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } - C[ldc * j + i] = hsum(acc); - END_KERNEL() } inline int8x16_t load_lo(const block_q8_0 *b) { return vld1q_s8(b->qs); } + inline int8x16_t load_hi(const block_q8_0 *b) { return vld1q_s8(b->qs + 16); } @@ -644,6 +568,7 @@ class tinyBLAS_Q0_ARM { vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } + inline int8x16_t load_hi(const block_q4_0 *b) { return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); @@ -679,217 +604,143 @@ class tinyBLAS_Q0_AVX2 { } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { + void mnpack(int m0, int m, int n0, int n) { int mc, nc, mp, np; - if (m - m0 <= 0 || n - n0 <= 0) - return; - if (m - m0 >= 4 && n - n0 >= 3) { + switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) { +#if VECTOR_REGISTERS == 32 + case 0x44: + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); + break; + case 0x43: + mc = 4; + nc = 3; + gemm<4, 3>(m0, m, n0, n); + break; + case 0x34: + mc = 3; + nc = 4; + gemm<3, 4>(m0, m, n0, n); + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x42: mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; +#else + case 0x44: + case 0x43: + case 0x42: + mc = 4; + nc = 2; + gemm<4, 2>(m0, m, n0, n); + break; + case 0x34: + case 0x24: + mc = 2; + nc = 4; + gemm<2, 4>(m0, m, n0, n); + break; + case 0x33: +#endif + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; nc = 3; - gemm4x3(m0, m, n0, n); - } else if (m - m0 >= 4 && n - n0 >= 1) { + gemm<2, 3>(m0, m, n0, n); + break; + case 0x41: mc = 4; nc = 1; - gemm4x1(m0, m, n0, n); - } else if (m - m0 >= 1 && n - n0 >= 4) { + gemm<4, 1>(m0, m, n0, n); + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x14: mc = 1; nc = 4; - gemm1x4(m0, m, n0, n); - } else { + gemm<1, 4>(m0, m, n0, n); + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; nc = 1; - gemm1x1(m0, m, n0, n); + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); - mnpack(m0, mp, np, n); - mnpack(mp, m, np, n); - } - - NOINLINE void gemm4x3(int m0, int m, int n0, int n) { - BEGIN_KERNEL(4, 3) - __m256 c00 = _mm256_setzero_ps(); - __m256 c10 = _mm256_setzero_ps(); - __m256 c20 = _mm256_setzero_ps(); - __m256 c30 = _mm256_setzero_ps(); - __m256 c01 = _mm256_setzero_ps(); - __m256 c11 = _mm256_setzero_ps(); - __m256 c21 = _mm256_setzero_ps(); - __m256 c31 = _mm256_setzero_ps(); - __m256 c02 = _mm256_setzero_ps(); - __m256 c12 = _mm256_setzero_ps(); - __m256 c22 = _mm256_setzero_ps(); - __m256 c32 = _mm256_setzero_ps(); - const TA *Ap0 = A + lda * (i + 0); - const TA *Ap1 = A + lda * (i + 1); - const TA *Ap2 = A + lda * (i + 2); - const TA *Ap3 = A + lda * (i + 3); - const TB *Bp0 = B + ldb * (j + 0); - const TB 
*Bp1 = B + ldb * (j + 1); - const TB *Bp2 = B + ldb * (j + 2); - for (int l = 0; l < k; ++l) { - float da0 = unhalf(Ap0[l].d); - float da1 = unhalf(Ap1[l].d); - float da2 = unhalf(Ap2[l].d); - float da3 = unhalf(Ap3[l].d); - __m256i e0 = load(Ap0 + l); - __m256i e1 = load(Ap1 + l); - __m256i e2 = load(Ap2 + l); - __m256i e3 = load(Ap3 + l); - float db0 = unhalf(Bp0[l].d); - __m256 d00 = _mm256_set1_ps(da0 * db0); - __m256 d10 = _mm256_set1_ps(da1 * db0); - __m256 d20 = _mm256_set1_ps(da2 * db0); - __m256 d30 = _mm256_set1_ps(da3 * db0); - __m256i f0 = load(Bp0 + l); - __m256i u0 = _mm256_sign_epi8(f0, f0); - __m256i s00 = _mm256_sign_epi8(e0, f0); - __m256i s10 = _mm256_sign_epi8(e1, f0); - __m256i s20 = _mm256_sign_epi8(e2, f0); - __m256i s30 = _mm256_sign_epi8(e3, f0); - c00 = madd(d00, updot(u0, s00), c00); - c10 = madd(d10, updot(u0, s10), c10); - c20 = madd(d20, updot(u0, s20), c20); - c30 = madd(d30, updot(u0, s30), c30); - float db1 = unhalf(Bp1[l].d); - __m256 d01 = _mm256_set1_ps(da0 * db1); - __m256 d11 = _mm256_set1_ps(da1 * db1); - __m256 d21 = _mm256_set1_ps(da2 * db1); - __m256 d31 = _mm256_set1_ps(da3 * db1); - __m256i f1 = load(Bp1 + l); - __m256i u1 = _mm256_sign_epi8(f1, f1); - __m256i s01 = _mm256_sign_epi8(e0, f1); - __m256i s11 = _mm256_sign_epi8(e1, f1); - __m256i s21 = _mm256_sign_epi8(e2, f1); - __m256i s31 = _mm256_sign_epi8(e3, f1); - c01 = madd(d01, updot(u1, s01), c01); - c11 = madd(d11, updot(u1, s11), c11); - c21 = madd(d21, updot(u1, s21), c21); - c31 = madd(d31, updot(u1, s31), c31); - float db2 = unhalf(Bp2[l].d); - __m256 d02 = _mm256_set1_ps(da0 * db2); - __m256 d12 = _mm256_set1_ps(da1 * db2); - __m256 d22 = _mm256_set1_ps(da2 * db2); - __m256 d32 = _mm256_set1_ps(da3 * db2); - __m256i f2 = load(Bp2 + l); - __m256i u2 = _mm256_sign_epi8(f2, f2); - __m256i s02 = _mm256_sign_epi8(e0, f2); - __m256i s12 = _mm256_sign_epi8(e1, f2); - __m256i s22 = _mm256_sign_epi8(e2, f2); - __m256i s32 = _mm256_sign_epi8(e3, f2); - c02 = madd(d02, updot(u2, s02), c02); - c12 = madd(d12, updot(u2, s12), c12); - c22 = madd(d22, updot(u2, s22), c22); - c32 = madd(d32, updot(u2, s32), c32); - } - C[ldc * (j + 0) + (i + 0)] = hsum(c00); - C[ldc * (j + 0) + (i + 1)] = hsum(c10); - C[ldc * (j + 0) + (i + 2)] = hsum(c20); - C[ldc * (j + 0) + (i + 3)] = hsum(c30); - C[ldc * (j + 1) + (i + 0)] = hsum(c01); - C[ldc * (j + 1) + (i + 1)] = hsum(c11); - C[ldc * (j + 1) + (i + 2)] = hsum(c21); - C[ldc * (j + 1) + (i + 3)] = hsum(c31); - C[ldc * (j + 2) + (i + 0)] = hsum(c02); - C[ldc * (j + 2) + (i + 1)] = hsum(c12); - C[ldc * (j + 2) + (i + 2)] = hsum(c22); - C[ldc * (j + 2) + (i + 3)] = hsum(c32); - END_KERNEL() - } - - NOINLINE void gemm4x1(int m0, int m, int n0, int n) { - BEGIN_KERNEL(4, 1) - __m256 c0 = _mm256_setzero_ps(); - __m256 c1 = _mm256_setzero_ps(); - __m256 c2 = _mm256_setzero_ps(); - __m256 c3 = _mm256_setzero_ps(); - const TA *Ap0 = A + lda * (i + 0); - const TA *Ap1 = A + lda * (i + 1); - const TA *Ap2 = A + lda * (i + 2); - const TA *Ap3 = A + lda * (i + 3); - const TB *Bp = B + ldb * j; - for (int l = 0; l < k; ++l) { - float db0 = unhalf(Bp[l].d); - __m256i f = load(Bp + l); - __m256i u = _mm256_sign_epi8(f, f); - __m256 d0 = _mm256_set1_ps(unhalf(Ap0[l].d) * db0); - __m256 d1 = _mm256_set1_ps(unhalf(Ap1[l].d) * db0); - __m256 d2 = _mm256_set1_ps(unhalf(Ap2[l].d) * db0); - __m256 d3 = _mm256_set1_ps(unhalf(Ap3[l].d) * db0); - __m256i e0 = load(Ap0 + l); - __m256i e1 = load(Ap1 + l); - __m256i e2 = load(Ap2 + l); - __m256i e3 = load(Ap3 + l); - __m256i s0 = 
_mm256_sign_epi8(e0, f); - __m256i s1 = _mm256_sign_epi8(e1, f); - __m256i s2 = _mm256_sign_epi8(e2, f); - __m256i s3 = _mm256_sign_epi8(e3, f); - __m256 g0 = updot(u, s0); - __m256 g1 = updot(u, s1); - __m256 g2 = updot(u, s2); - __m256 g3 = updot(u, s3); - c0 = madd(d0, g0, c0); - c1 = madd(d1, g1, c1); - c2 = madd(d2, g2, c2); - c3 = madd(d3, g3, c3); - } - C[ldc * j + (i + 0)] = hsum(c0); - C[ldc * j + (i + 1)] = hsum(c1); - C[ldc * j + (i + 2)] = hsum(c2); - C[ldc * j + (i + 3)] = hsum(c3); - END_KERNEL() - } - - NOINLINE void gemm1x4(int m0, int m, int n0, int n) { - BEGIN_KERNEL(1, 4) - __m256 c0 = _mm256_setzero_ps(); - __m256 c1 = _mm256_setzero_ps(); - __m256 c2 = _mm256_setzero_ps(); - __m256 c3 = _mm256_setzero_ps(); - const TB *Bp0 = B + ldb * (j + 0); - const TB *Bp1 = B + ldb * (j + 1); - const TB *Bp2 = B + ldb * (j + 2); - const TB *Bp3 = B + ldb * (j + 3); - const TA *Ap = A + lda * i; - for (int l = 0; l < k; ++l) { - float da0 = unhalf(Ap[l].d); - __m256i f = load(Ap + l); - __m256i u = _mm256_sign_epi8(f, f); - __m256 d0 = _mm256_set1_ps(unhalf(Bp0[l].d) * da0); - __m256 d1 = _mm256_set1_ps(unhalf(Bp1[l].d) * da0); - __m256 d2 = _mm256_set1_ps(unhalf(Bp2[l].d) * da0); - __m256 d3 = _mm256_set1_ps(unhalf(Bp3[l].d) * da0); - __m256 g0 = updot(u, _mm256_sign_epi8(load(Bp0 + l), f)); - __m256 g1 = updot(u, _mm256_sign_epi8(load(Bp1 + l), f)); - __m256 g2 = updot(u, _mm256_sign_epi8(load(Bp2 + l), f)); - __m256 g3 = updot(u, _mm256_sign_epi8(load(Bp3 + l), f)); - c0 = madd(d0, g0, c0); - c1 = madd(d1, g1, c1); - c2 = madd(d2, g2, c2); - c3 = madd(d3, g3, c3); - } - C[ldc * (j + 0) + i] = hsum(c0); - C[ldc * (j + 1) + i] = hsum(c1); - C[ldc * (j + 2) + i] = hsum(c2); - C[ldc * (j + 3) + i] = hsum(c3); - END_KERNEL() - } - - NOINLINE void gemm1x1(int m0, int m, int n0, int n) { - BEGIN_KERNEL(1, 1) - __m256 c = _mm256_setzero_ps(); - const TA *Ap = A + lda * i; - const TB *Bp = B + ldb * j; - for (int l = 0; l < k; ++l) { - __m256 d = _mm256_set1_ps(unhalf(Ap[l].d) * unhalf(Bp[l].d)); - __m256i e = load(Ap + l); - __m256i f = load(Bp + l); - __m256 g = updot(_mm256_sign_epi8(e, e), _mm256_sign_epi8(f, e)); - c = madd(d, g, c); + mnpack(m0, m, np, n); + } + + template + NOINLINE void gemm(int m0, int m, int n0, int n) { + int ytiles = (m - m0) / RM; + int xtiles = (n - n0) / RN; + int tiles = xtiles * ytiles; + int duty = (tiles + nth - 1) / nth; + int start = duty * ith; + int end = start + duty; + if (end > tiles) + end = tiles; + for (int job = start; job < end; ++job) { + int ii = m0 + job / xtiles * RM; + int jj = n0 + job % xtiles * RN; + __m256 Cv[RN][RM] = {}; + for (int l = 0; l < k; ++l) + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * + unhalf(B[ldb * (jj + j) + l].d)), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))), + Cv[j][i]); + for (int j = 0; j < RN; ++j) + for (int i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } - C[ldc * j + i] = hsum(c); - END_KERNEL() } inline __m256i load(const block_q8_0 *b) { @@ -911,10 +762,10 @@ class tinyBLAS_Q0_AVX2 { } static inline __m256i denibble(const uint8_t *p) { - const __m128i tmp = _mm_loadu_si128((const __m128i *)p); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8(15); - return _mm256_and_si256(lowMask, bytes); + __m128i x = 
_mm_loadu_si128((const __m128i *)p); + return _mm256_and_si256(_mm256_set1_epi8(15), + _mm256_insertf128_si256(_mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), 1)); } const TA *const A;
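The core idea of the rewrite above is that mnpack() now packs the clamped remainders into one switch key, `(std::min(m - m0, 5) << 4) | std::min(n - n0, 5)`, looks up a register-blocked kernel size in a fixed table, dispatches the matching gemm<RM, RN> instantiation, and then recurses twice to mop up the leftover rows and columns. The standalone sketch below (not part of the patch) simulates that control flow with the table clamped to 2x2 instead of 5x5 and prints which region each kernel call would cover rather than computing it; the printf and the main() driver are illustrative additions only.

```cpp
// Illustrative miniature of the new mnpack() dispatch (sizes reduced for
// clarity; the real kernels are the gemm<RM, RN> templates in sgemm.cpp).
#include <algorithm>
#include <cstdio>

static void mnpack(int m0, int m, int n0, int n) {
    int mc, nc;
    // High nibble = remaining rows clamped to 2, low nibble = remaining
    // columns clamped to 2, so e.g. 0x21 means "2+ rows, exactly 1 column left".
    switch ((std::min(m - m0, 2) << 4) | std::min(n - n0, 2)) {
    case 0x22: mc = 2; nc = 2; break;   // real code: gemm<2, 2>(m0, m, n0, n)
    case 0x21: mc = 2; nc = 1; break;   // real code: gemm<2, 1>(m0, m, n0, n)
    case 0x12: mc = 1; nc = 2; break;   // real code: gemm<1, 2>(m0, m, n0, n)
    case 0x11: mc = 1; nc = 1; break;   // real code: gemm<1, 1>(m0, m, n0, n)
    default:   return;                  // one dimension is exhausted
    }
    int mp = m0 + (m - m0) / mc * mc;   // rows covered by whole mc-row tiles
    int np = n0 + (n - n0) / nc * nc;   // columns covered by whole nc-column tiles
    std::printf("gemm<%d,%d> covers rows [%d,%d) x cols [%d,%d)\n",
                mc, nc, m0, mp, n0, np);
    mnpack(mp, m, n0, np);              // leftover rows, under the columns just covered
    mnpack(m0, m, np, n);               // leftover columns, across the full row range
}

int main() {
    mnpack(0, 5, 0, 7);                 // an awkwardly shaped 5x7 output matrix
    return 0;
}
```

For a 5x7 output this prints a 2x2-tiled interior plus a one-row strip and a one-column strip, which is how the rewrite gets its faster handling of weirdly shaped `m` and `n` edge cases.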
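Inside each templated kernel the old BEGIN_KERNEL/END_KERNEL macro is replaced by an explicit static schedule: complete RM x RN tiles are numbered row-major and each of the nth threads takes a contiguous run of at most ceil(tiles / nth) of them. Below is a minimal sketch of just that scheduling logic, with a printf standing in for the accumulate-and-store loop; the schedule() helper and the 3-thread driver are made up for illustration.

```cpp
// Illustrative sketch of the static work split used inside gemm<RM, RN>.
#include <cstdio>

static void schedule(int m0, int m, int n0, int n,
                     int RM, int RN, int ith, int nth) {
    int ytiles = (m - m0) / RM;            // complete tile rows
    int xtiles = (n - n0) / RN;            // complete tile columns
    int tiles  = xtiles * ytiles;
    int duty   = (tiles + nth - 1) / nth;  // ceil(tiles / nth) tiles per thread
    int start  = duty * ith;
    int end    = start + duty;
    if (end > tiles)
        end = tiles;
    for (int job = start; job < end; ++job) {
        int ii = m0 + job / xtiles * RM;   // top row of this tile
        int jj = n0 + job % xtiles * RN;   // left column of this tile
        // The real kernel accumulates Cv[RN][RM] over k here and then stores
        // hsum() of each accumulator into C; this sketch only reports the split.
        std::printf("thread %d -> rows [%d,%d) x cols [%d,%d)\n",
                    ith, ii, ii + RM, jj, jj + RN);
    }
}

int main() {
    for (int ith = 0; ith < 3; ++ith)      // 3 threads, 8x8 output, 4x4 tiles
        schedule(0, 8, 0, 8, 4, 4, ith, 3);
    return 0;
}
```

With four 4x4 tiles and three threads, threads 0 and 1 each take two tiles and thread 2 takes none, which is what the `if (end > tiles) end = tiles;` clamp in the patch is for.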