diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp
index b16e270b984fca..3b052e7094d34c 100644
--- a/src/plugins/intel_cpu/src/config.cpp
+++ b/src/plugins/intel_cpu/src/config.cpp
@@ -408,7 +408,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
                                val.as<std::string>(),
                                " for property key ",
                                ov::hint::value_cache_precision.name(),
-                               ". Supported values: u4, s4, u8, bf16, f16, f32");
+                               ". Supported values: u4, u8, bf16, f16, f32");
             }
         } else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) {
             try {
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
index 40ed27bf73ea97..e25b204e670218 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
@@ -272,106 +272,6 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp)
     }
 }
 
-template <typename T>
-static void quant_i4(const T* src, void* dst, size_t n, float& scale) {
-    auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
-        uint8_t shift = high_half ? 0 : 4;
-        if (high_half)
-            val &= 0x0F;
-        return dst | (uint8_t)(val << shift);
-    };
-    auto dst_ptr = reinterpret_cast<uint8_t*>(dst);
-    size_t i = 0;
-    float max = -FLT_MAX;
-    float min = FLT_MAX;
-    find_minmax(src, n, min, max);
-    float max_abs = std::max(std::abs(min), std::abs(max));
-    scale = max_abs / ((1 << 3) - 1);
-    if (scale == 0)
-        scale = 0.0001f;
-#if defined(HAVE_AVX512F)
-    auto v_scale = _mm512_set1_ps(1 / scale);
-    auto v_upper = _mm512_set1_epi32(7);
-    auto v_lower = _mm512_set1_epi32(-8);
-    for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) {
-        auto v0 = mm512_uni_loadu_ps(src + i);
-        auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512);
-        v0 = _mm512_mul_ps(v0, v_scale);
-        v1 = _mm512_mul_ps(v1, v_scale);
-        auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-        auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
-
-        v0_i32 = _mm512_max_epi32(v0_i32, v_lower);
-        v1_i32 = _mm512_max_epi32(v1_i32, v_lower);
-        v0_i32 = _mm512_min_epi32(v0_i32, v_upper);
-        v1_i32 = _mm512_min_epi32(v1_i32, v_upper);
-
-        __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
-        __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
-        auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32);
-        auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32);
-
-        auto mask = _mm512_set1_epi32(0x0F);
-        second_half = _mm512_and_epi32(second_half, mask);
-        first_half = _mm512_slli_epi32(first_half, 4);
-        auto combined = _mm512_or_epi32(first_half, second_half);
-        _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
-    }
-#endif
-#if defined(HAVE_AVX2)
-    auto v256_lower = _mm256_set1_epi32(-8);
-    auto v256_upper = _mm256_set1_epi32(7);
-    auto v256_scale = _mm256_set1_ps(1 / scale);
-    for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) {
-        auto v0 = mm256_uni_loadu_ps(src + i);
-        auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2);
-        v0 = _mm256_mul_ps(v0, v256_scale);
-        v1 = _mm256_mul_ps(v1, v256_scale);
-        v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
-        v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);
-
-        auto v0_i32 = _mm256_cvtps_epi32(v0);
-        auto v1_i32 = _mm256_cvtps_epi32(v1);
-        v0_i32 = _mm256_max_epi32(v0_i32, v256_lower);
-        v1_i32 = _mm256_max_epi32(v1_i32, v256_lower);
-        v0_i32 = _mm256_min_epi32(v0_i32, v256_upper);
-        v1_i32 = _mm256_min_epi32(v1_i32, v256_upper);
-        auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
-        v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1);
-        v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1);
-
-        // 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15
-        //        _mm256_permutevar8x32_epi32
-        // 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15
-        //        _mm256_permute2x128_si256
-        // 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15
-        //        shift + mask + or
-        // [0,1],[2,3], ..., [12,13], [14,15]
-        auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20);
-        auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31);
-        first_half = _mm256_slli_epi32(first_half, 4);
-        auto mask = _mm256_set1_epi32(0x0F);
-        second_half = _mm256_and_si256(second_half, mask);
-        auto combined = _mm256_or_si256(first_half, second_half);
-
-        auto high4 = _mm256_extractf128_si256(combined, 1);
-        auto low4 = _mm256_castsi256_si128(combined);
-        // keep sign bit for s4 case
-        auto packed = _mm_packs_epi32(low4, high4);
-        packed = _mm_packs_epi16(packed, packed);
-        _mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed);
-    }
-#endif
-    for (; i < n; i++) {
-        float tmp = src[i];
-        int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale));
-        src_val = std::max((int8_t)(-8), src_val);
-        uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2];
-        dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2));
-        dst_ptr[i / 2] = dst_val;
-    }
-}
-
 template <typename T,
           ov::element::Type_t DST_PREC,
           typename std::enable_if<DST_PREC == ov::element::u8, bool>::type = true>
@@ -386,13 +286,6 @@ static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
     quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1));
 }
 
-template <typename T,
-          ov::element::Type_t DST_PREC,
-          typename std::enable_if<DST_PREC == ov::element::i4, bool>::type = true>
-static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
-    quant_i4(src, dst, n, *scale_zp);
-}
-
 template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC>
 static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
                           const ov::intel_cpu::PlainTensor& v_src,
@@ -500,17 +393,14 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src,
     static constexpr function_type funcs_fp32[] = {
         paged_attn_quant_mt<float, ov::element::u8, ov::element::u8>,
         paged_attn_quant_mt<float, ov::element::u8, ov::element::u4>,
-        paged_attn_quant_mt<float, ov::element::u8, ov::element::i4>,
     };
     static constexpr function_type funcs_bf16[] = {
         paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::u8>,
         paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::u4>,
-        paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::i4>,
     };
     static constexpr function_type funcs_f16[] = {
         paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::u8>,
         paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::u4>,
-        paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::i4>,
     };
     if (k_dst.get_precision() != ov::element::u8) {
         OPENVINO_THROW("unsupport src type: ",
diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
index 587c535cc1fb05..bd659cb1b164f7 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -463,99 +463,6 @@ static void attn_acc_value_block(float* out,
     }
 }
 
-template <typename TA,
-          ov::element::Type_t SRC_PREC,
-          typename std::enable_if<SRC_PREC == ov::element::i4, bool>::type = true>
-static void attn_acc_value_block(float* out,
-                                 float* weight,
-                                 void* v,
-                                 const size_t S,
-                                 const size_t block_size,
-                                 const size_t group_size) {
-    size_t src_offset = 0;
-    size_t dst_offset = 0;
-    const size_t params_offset = sizeof(float);
-    auto sub_byte_multiplier = 8 / 4;
-    uint8_t* v_ptr = reinterpret_cast<uint8_t*>(v);
-    const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset);
-    auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t {
-        uint8_t shift = high_half ? 0 : 4;
-
-        return (uint8_t)((val >> shift) & 0x000F);
-    };
-
-    for (size_t j = 0; j < block_size; j++) {
-        dst_offset = 0;
-        src_offset = 0;
-        while (dst_offset < S) {
-            auto v0 = reinterpret_cast<float*>(v_ptr + src_offset);
-            size_t i = 0;
-#    if defined(HAVE_AVX512F)
-            auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]);
-            for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) {
-                auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset));
-                auto v_i32 = _mm512_cvtepi8_epi32(data);
-                // cvt to f32
-                auto v_256_low_half = _mm512_srai_epi32(v_i32, 4);
-                auto v_256_high_half = _mm512_slli_epi32(v_i32, 28);
-                v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28);
-
-                auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half);
-                auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half);
-
-                __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
-                __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
-                __m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half);
-                __m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half);
-                auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i);
-                auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512);
-                v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0);
-                v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1);
-                mm512_uni_storeu_ps(out + dst_offset + i, v_out0);
-                mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1);
-            }
-#    elif defined(HAVE_AVX2)
-            auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]);
-            for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) {
-                auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset));
-
-                auto v_i32 = _mm256_cvtepi8_epi32(data);
-                auto v_256_low_half = _mm256_srai_epi32(v_i32, 4);
-                auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half);
-
-                auto v_256_high_half = _mm256_slli_epi32(v_i32, 28);
-                v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28);
-                auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half);
-
-                auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i);
-                auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2);
-                __m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20);
-                auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
-                first_half = _mm256_permutevar8x32_ps(first_half, idx1);
-                __m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31);
-                second_half = _mm256_permutevar8x32_ps(second_half, idx1);
-                v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0);
-                v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1);
-                mm256_uni_storeu_ps(out + dst_offset + i, v_out0);
-                mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1);
-            }
-#    endif
-            for (; i < group_size; i += 2) {
-                uint8_t data = v_ptr[i / 2 + src_offset + params_offset];
-                float tmp0 = extract_half_byte(data, static_cast<bool>(i % 2));
-                tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0;
-                float tmp1 = extract_half_byte(data, static_cast<bool>((i + 1) % 2));
-                tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1;
-                out[dst_offset + i] += weight[j] * (tmp0)*v0[0];
-                out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0];
-            }
-            dst_offset += group_size;
-            src_offset += group_size / sub_byte_multiplier + params_offset;
-        }
-        v_ptr += src_stride;
-    }
-}
-
 template <typename TA, typename TB>
 static void dot_product_block(TA* a,
                               TB* b,
@@ -1296,47 +1203,6 @@ static void pack_32NxK(TDST* dst,
                                                 src_stride,
                                                 group_size);
 }
-
-template <typename TDST,
-          ov::element::Type_t SRC_PREC,
-          typename std::enable_if<precision_of<TDST>::value != ov::element::f32 && (SRC_PREC == ov::element::i4),
-                                  bool>::type = true>
-static void pack_32NxK(TDST* dst,
-                       void* src,
-                       TDST* tmp,
-                       const size_t N,
-                       const size_t K,
-                       const size_t dst_stride,
-                       const size_t src_stride,
-                       const size_t group_size) {
-    // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
-    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
-    auto s = reinterpret_cast<uint8_t*>(src);
-    auto t = tmp;
-    // if group_size not set, the whole row is used as a group
-    const size_t sub_byte_mulitplier = 2;
-    for (size_t n = 0; n < N; n++) {
-        size_t src_offset = 0;
-        size_t dst_offset = 0;
-        while (dst_offset < K) {
-            auto f = reinterpret_cast<float*>(s + src_offset);
-            attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, group_size, f[0]);
-            src_offset += group_size / sub_byte_mulitplier + sizeof(float);
-            dst_offset += group_size;
-        }
-        s += src_offset;
-        t += src_stride;
-    }
-    pack_32NxK<TDST, precision_of<TDST>::value>(dst,
-                                                tmp,
-                                                reinterpret_cast<TDST*>(0),
-                                                N,
-                                                K,
-                                                dst_stride,
-                                                src_stride,
-                                                group_size);
-}
 #    endif
 
 template <typename TDST,
@@ ... @@
         constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
-        constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == precision_of<KVCACHE_TYPE>::value;
+        constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == VALUE_PREC;
         auto attn_work_count = _workitems.attn_work_size();
         auto reorder_work_count = _workitems.reorder_work_size();
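For context, the s4 code paths removed above all follow the same per-group scheme: scale = max|x| / 7, round(x / scale) clamped to [-8, 7], two signed nibbles packed per byte (even index in the high half, odd index in the low half), with the f32 group scale stored ahead of the packed data. Below is a minimal scalar sketch of that round trip; the helper names are hypothetical and this is not the OpenVINO API or the removed SIMD code verbatim.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustrative scalar reference for the removed s4 path (assumed layout: two nibbles per byte).
// Returns the per-group scale; values are clamped to [-8, 7].
static float quantize_s4_ref(const float* src, uint8_t* dst, size_t n) {
    float max_abs = 0.0f;
    for (size_t i = 0; i < n; i++)
        max_abs = std::max(max_abs, std::abs(src[i]));
    float scale = max_abs / 7.0f;
    if (scale == 0.0f)
        scale = 0.0001f;
    for (size_t i = 0; i < n; i++) {
        int v = std::min(7, std::max(-8, static_cast<int>(std::round(src[i] / scale))));
        uint8_t nibble = static_cast<uint8_t>(v) & 0x0F;
        // Even index goes to the high nibble, odd index to the low nibble,
        // matching insert_half_byte in the removed quant_i4 above.
        if (i % 2 == 0)
            dst[i / 2] = static_cast<uint8_t>(nibble << 4);
        else
            dst[i / 2] |= nibble;
    }
    return scale;
}

// Illustrative dequantization: extract the nibble, sign-extend, multiply by the group scale.
static void dequantize_s4_ref(const uint8_t* src, float* dst, size_t n, float scale) {
    for (size_t i = 0; i < n; i++) {
        uint8_t nibble = (i % 2 == 0) ? static_cast<uint8_t>(src[i / 2] >> 4) : static_cast<uint8_t>(src[i / 2] & 0x0F);
        int v = (nibble > 7) ? static_cast<int>(nibble) - 16 : static_cast<int>(nibble);
        dst[i] = static_cast<float>(v) * scale;
    }
}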