From 35cc5567c8cce16443e97789ce428ce275922f34 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 13 Aug 2024 18:00:06 -0400 Subject: [PATCH] ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support --- ggml/src/ggml-quants.c | 204 ++++++++++------------------------------- 1 file changed, 47 insertions(+), 157 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 0574bf4451614..15e5297227854 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5667,7 +5667,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * const int nb = n / QK_K; -#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD +#if defined(__ARM_NEON) float sumf = 0.0f; uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; @@ -5675,8 +5675,13 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * const uint8x16_t shift = vld1q_u8(k_shift); for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) int32x4_t sumi0 = vdupq_n_s32(0); int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif // first 32 bytes of 5 elements { @@ -5714,6 +5719,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); +#if defined(__ARM_FEATURE_DOTPROD) sumi0 = vdotq_s32(sumi0, sqx0, qy0); sumi1 = vdotq_s32(sumi1, sqx1, qy1); sumi0 = vdotq_s32(sumi0, sqx2, qy2); @@ -5724,103 +5730,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * sumi1 = vdotq_s32(sumi1, sqx7, qy7); sumi0 = vdotq_s32(sumi0, sqx8, qy8); sumi1 = vdotq_s32(sumi1, sqx9, qy9); - } - - // last 16 bytes of 5-element, along with the 4 bytes of 4 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); - uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned - uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); - qx5 = vmulq_u8(qx5, shift); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); - - sumi0 = vdotq_s32(sumi0, sqx0, qy0); - sumi1 = vdotq_s32(sumi1, sqx1, qy1); - sumi0 = vdotq_s32(sumi0, sqx2, qy2); - sumi1 = vdotq_s32(sumi1, sqx3, qy3); - sumi0 = vdotq_s32(sumi0, sqx4, qy4); - sumi1 = vdotq_s32(sumi1, sqx5, qy5); - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d * (float) vaddvq_s32(sumi0); - } - - *s = sumf; - -#elif defined __ARM_NEON - float sumf = 0.0f; - - uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; - - const uint8x16_t shift = vld1q_u8(k_shift); - - for (int i = 0; i < nb; ++i) { - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); - - // first 32 bytes of 5 elements - { - uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); - uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); - uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); - uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); - uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); - uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); - uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); - uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); - uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); - uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); - - // multiply by 3 and keep the 2 bits above 8 bits - int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); - int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); - int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); - const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); - const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); - +#else sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); @@ -5841,6 +5751,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif } // last 16 bytes of 5-element, along with the 4 bytes of 4 elements @@ -5870,6 +5781,14 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); @@ -5882,22 +5801,30 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif } const int16x8_t ysum0 = vld1q_s16(y[i].bsums); const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else sumi0 = vaddq_s16(sumi0, sumi1); sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * (float) vaddlvq_s16(sumi0); +#endif } *s = sumf; -#elif defined __AVX2__ +#elif defined(__AVX2__) __m256 sumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { @@ -6063,14 +5990,19 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * const int nb = n / QK_K; -#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD +#if defined(__ARM_NEON) float sumf = 0.0f; const uint8x16_t m3 = vdupq_n_u8(3); for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) int32x4_t sumi0 = vdupq_n_s32(0); int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif for (size_t j = 0; j < sizeof(x->qs); j += 32) { uint8x16_t qx0 = vld1q_u8(x[i].qs + j); @@ -6100,6 +6032,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); +#if defined(__ARM_FEATURE_DOTPROD) sumi0 = vdotq_s32(sumi0, sqx0, qy0); sumi1 = vdotq_s32(sumi1, sqx1, qy1); sumi0 = vdotq_s32(sumi0, sqx2, qy2); @@ -6108,58 +6041,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * sumi1 = vdotq_s32(sumi1, sqx5, qy5); sumi0 = vdotq_s32(sumi0, sqx6, qy6); sumi1 = vdotq_s32(sumi1, sqx7, qy7); - } - - const int16x8_t ysum0 = vld1q_s16(y[i].bsums); - const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); - - sumi0 = vaddq_s32(sumi0, sumi1); - sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); - - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - - sumf += d * (float) vaddvq_s32(sumi0); - } - - *s = sumf; - -#elif defined __ARM_NEON - float sumf = 0.0f; - - const uint8x16_t m3 = vdupq_n_u8(3); - - for (int i = 0; i < nb; ++i) { - int16x8_t sumi0 = vdupq_n_s16(0); - int16x8_t sumi1 = vdupq_n_s16(0); - - for (size_t j = 0; j < sizeof(x->qs); j += 32) { - uint8x16_t qx0 = vld1q_u8(x[i].qs + j); - uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); - uint8x16_t qx2 = vshrq_n_u8(qx0, 2); - uint8x16_t qx3 = vshrq_n_u8(qx1, 2); - uint8x16_t qx4 = vshrq_n_u8(qx0, 4); - uint8x16_t qx5 = vshrq_n_u8(qx1, 4); - uint8x16_t qx6 = vshrq_n_u8(qx0, 6); - uint8x16_t qx7 = vshrq_n_u8(qx1, 6); - - int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); - int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); - int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); - int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); - int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); - int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); - int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); - int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); - - const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); - const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); - const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); - const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); - const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); - const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); - const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); - const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); - +#else sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); @@ -6176,22 +6058,30 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif } const int16x8_t ysum0 = vld1q_s16(y[i].bsums); const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else sumi0 = vaddq_s16(sumi0, sumi1); sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); - const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * (float) vaddlvq_s16(sumi0); +#endif } *s = sumf; -#elif defined __AVX2__ +#elif defined(__AVX2__) __m256 sumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) {