[CPU] Remove unused s4 code
Signed-off-by: Zhang Yi <[email protected]>
zhangYiIntel committed Jan 3, 2025
1 parent c362399 commit 28bcf7b
Showing 3 changed files with 2 additions and 246 deletions. The commit drops the unused signed-int4 (s4/i4) path for the quantized KV cache in the CPU plugin: the quant_i4 routine and its quantize overload, the i4 entries of the paged-attention dispatch tables, the i4 specializations of attn_acc_value_block and pack_32NxK, and the "s4" entry in the value_cache_precision error message.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/config.cpp
@@ -408,7 +408,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
val.as<std::string>(),
" for property key ",
ov::hint::value_cache_precision.name(),
- ". Supported values: u4, s4, u8, bf16, f16, f32");
+ ". Supported values: u4, u8, bf16, f16, f32");
}
} else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) {
try {
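
With this change, requesting s4 for ov::hint::value_cache_precision is rejected with the message above, and u4 remains the only sub-byte option for the value cache. A minimal usage sketch of the property (editor's illustration; the model path is hypothetical, the property names come from the hunk above):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");  // hypothetical model path
    // u4 is still accepted for the value cache; the removed s4/i4 precision
    // now fails with the "Supported values: u4, u8, bf16, f16, f32" error.
    auto compiled = core.compile_model(model,
                                       "CPU",
                                       ov::hint::value_cache_precision(ov::element::u4));
    return 0;
}
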
110 changes: 0 additions & 110 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
@@ -272,106 +272,6 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp)
}
}

template <typename T>
static void quant_i4(const T* src, void* dst, size_t n, float& scale) {
auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
uint8_t shift = high_half ? 0 : 4;
if (high_half)
val &= 0x0F;
return dst | (uint8_t)(val << shift);
};
auto dst_ptr = reinterpret_cast<uint8_t*>(dst);
size_t i = 0;
float max = -FLT_MAX;
float min = FLT_MAX;
find_minmax(src, n, min, max);
float max_abs = std::max(std::abs(min), std::abs(max));
scale = max_abs / ((1 << 3) - 1);
if (scale == 0)
scale = 0.0001f;
#if defined(HAVE_AVX512F)
auto v_scale = _mm512_set1_ps(1 / scale);
auto v_upper = _mm512_set1_epi32(7);
auto v_lower = _mm512_set1_epi32(-8);
for (; i + 2 * vec_len_f32_avx512 <= n; i += 2 * vec_len_f32_avx512) {
auto v0 = mm512_uni_loadu_ps(src + i);
auto v1 = mm512_uni_loadu_ps(src + i + vec_len_f32_avx512);
v0 = _mm512_mul_ps(v0, v_scale);
v1 = _mm512_mul_ps(v1, v_scale);
auto v0_i32 = _mm512_cvt_roundps_epi32(v0, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
auto v1_i32 = _mm512_cvt_roundps_epi32(v1, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);

v0_i32 = _mm512_max_epi32(v0_i32, v_lower);
v1_i32 = _mm512_max_epi32(v1_i32, v_lower);
v0_i32 = _mm512_min_epi32(v0_i32, v_upper);
v1_i32 = _mm512_min_epi32(v1_i32, v_upper);

__m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
__m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32);
auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32);

auto mask = _mm512_set1_epi32(0x0F);
second_half = _mm512_and_epi32(second_half, mask);
first_half = _mm512_slli_epi32(first_half, 4);
auto combined = _mm512_or_epi32(first_half, second_half);
_mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
}
#endif
#if defined(HAVE_AVX2)
auto v256_lower = _mm256_set1_epi32(-8);
auto v256_upper = _mm256_set1_epi32(7);
auto v256_scale = _mm256_set1_ps(1 / scale);
for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) {
auto v0 = mm256_uni_loadu_ps(src + i);
auto v1 = mm256_uni_loadu_ps(src + i + vec_len_f32_avx2);
v0 = _mm256_mul_ps(v0, v256_scale);
v1 = _mm256_mul_ps(v1, v256_scale);
v0 = _mm256_round_ps(v0, _MM_ROUND_NEAREST);
v1 = _mm256_round_ps(v1, _MM_ROUND_NEAREST);

auto v0_i32 = _mm256_cvtps_epi32(v0);
auto v1_i32 = _mm256_cvtps_epi32(v1);
v0_i32 = _mm256_max_epi32(v0_i32, v256_lower);
v1_i32 = _mm256_max_epi32(v1_i32, v256_lower);
v0_i32 = _mm256_min_epi32(v0_i32, v256_upper);
v1_i32 = _mm256_min_epi32(v1_i32, v256_upper);
auto idx1 = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0);
v0_i32 = _mm256_permutevar8x32_epi32(v0_i32, idx1);
v1_i32 = _mm256_permutevar8x32_epi32(v1_i32, idx1);

// 0,1,2,3,4,5,6,7 | 8,9,10,11,12,13,14,15
// _mm256_permutevar8x32_epi32
// 0,2,4,6,1,3,5,7 | 8,10,12,14,9,11,13,15
// _mm256_permute2x128_si256
// 0,2,4,6,8,10,12,14 | 1,3,5,7,9,11,13,15
// shift + mask + or
// [0,1],[2,3], ..., [12,13], [14,15]
auto first_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x20);
auto second_half = _mm256_permute2x128_si256(v0_i32, v1_i32, 0x31);
first_half = _mm256_slli_epi32(first_half, 4);
auto mask = _mm256_set1_epi32(0x0F);
second_half = _mm256_and_si256(second_half, mask);
auto combined = _mm256_or_si256(first_half, second_half);

auto high4 = _mm256_extractf128_si256(combined, 1);
auto low4 = _mm256_castsi256_si128(combined);
// keep sign bit for s4 case
auto packed = _mm_packs_epi32(low4, high4);
packed = _mm_packs_epi16(packed, packed);
_mm_storel_epi64(reinterpret_cast<__m128i*>(dst_ptr + i / 2), packed);
}
#endif
for (; i < n; i++) {
float tmp = src[i];
int8_t src_val = std::min((int8_t)(7), (int8_t)std::round(tmp / scale));
src_val = std::max((int8_t)(-8), src_val);
uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2];
dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2));
dst_ptr[i / 2] = dst_val;
}
}
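
For reference, the removed quant_i4 implements plain symmetric int4 quantization: one float scale per group (max|x| / 7), values rounded and clamped to [-8, 7], and two values packed per byte with the even-indexed element in the high nibble. A minimal scalar sketch of the same scheme (editor's reconstruction; quant_i4_ref is an illustrative name, not part of the sources):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Scalar equivalent of the removed quant_i4: symmetric int4 quantization
// with one scale per group and no zero point.
static void quant_i4_ref(const float* src, uint8_t* dst, size_t n, float& scale) {
    float max_abs = 0.0f;
    for (size_t i = 0; i < n; i++)
        max_abs = std::max(max_abs, std::abs(src[i]));
    scale = max_abs / 7.0f;  // int4 range is [-8, 7]; dividing by 7 keeps the mapping symmetric
    if (scale == 0.0f)
        scale = 0.0001f;     // same guard as above: avoid dividing by zero for an all-zero group
    for (size_t i = 0; i < n; i += 2) {
        int8_t v0 = (int8_t)std::clamp((int)std::lround(src[i] / scale), -8, 7);
        int8_t v1 = (i + 1 < n) ? (int8_t)std::clamp((int)std::lround(src[i + 1] / scale), -8, 7) : 0;
        // even index goes to the high nibble, odd index to the low nibble,
        // matching insert_half_byte in the removed code
        dst[i / 2] = (uint8_t)(((uint8_t)v0 << 4) | ((uint8_t)v1 & 0x0F));
    }
}
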

template <typename T,
ov::element::Type_t DST_PREC,
typename std::enable_if<DST_PREC == ov::element::u8, bool>::type = true>
@@ -386,13 +286,6 @@ static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
quant_u4(src, dst, n, *scale_zp, *(scale_zp + 1));
}

template <typename T,
ov::element::Type_t DST_PREC,
typename std::enable_if<DST_PREC == ov::element::i4, bool>::type = true>
static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
quant_i4(src, dst, n, *scale_zp);
}

template <typename T, typename T2>
static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
@@ -500,17 +393,14 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src,
static constexpr function_type funcs_fp32[] = {
paged_attn_quant_mt<float, ov::element::u8, ov::element::u8>,
paged_attn_quant_mt<float, ov::element::u8, ov::element::u4>,
paged_attn_quant_mt<float, ov::element::u8, ov::element::i4>,
};
static constexpr function_type funcs_bf16[] = {
paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::u8>,
paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::u4>,
paged_attn_quant_mt<ov::bfloat16, ov::element::u8, ov::element::i4>,
};
static constexpr function_type funcs_f16[] = {
paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::u8>,
paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::u4>,
paged_attn_quant_mt<ov::float16, ov::element::u8, ov::element::i4>,
};
if (k_dst.get_precision() != ov::element::u8) {
OPENVINO_THROW("unsupport src type: ",
136 changes: 1 addition & 135 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -463,99 +463,6 @@ static void attn_acc_value_block(float* out,
}
}

template <typename T,
ov::element::Type_t SRC_PREC,
typename std::enable_if<SRC_PREC == ov::element::i4, bool>::type = true>
static void attn_acc_value_block(float* out,
float* weight,
void* v,
const size_t S,
const size_t block_size,
const size_t group_size) {
size_t src_offset = 0;
size_t dst_offset = 0;
const size_t params_offset = sizeof(float);
auto sub_byte_multiplier = 8 / 4;
uint8_t* v_ptr = reinterpret_cast<uint8_t*>(v);
const size_t src_stride = S / group_size * (group_size / sub_byte_multiplier + params_offset);
auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t {
uint8_t shift = high_half ? 0 : 4;

return (uint8_t)((val >> shift) & 0x000F);
};

for (size_t j = 0; j < block_size; j++) {
dst_offset = 0;
src_offset = 0;
while (dst_offset < S) {
auto v0 = reinterpret_cast<float*>(v_ptr + src_offset);
size_t i = 0;
# if defined(HAVE_AVX512F)
auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]);
for (; i + vec_len_f32_avx512 * 2 <= group_size; i += vec_len_f32_avx512 * 2) {
auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset));
auto v_i32 = _mm512_cvtepi8_epi32(data);
// cvt to f32
auto v_256_low_half = _mm512_srai_epi32(v_i32, 4);
auto v_256_high_half = _mm512_slli_epi32(v_i32, 28);
v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28);

auto v_f32_low_half = _mm512_cvtepi32_ps(v_256_low_half);
auto v_f32_high_half = _mm512_cvtepi32_ps(v_256_high_half);

__m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
__m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
__m512 first_half = _mm512_permutex2var_ps(v_f32_low_half, idx1, v_f32_high_half);
__m512 second_half = _mm512_permutex2var_ps(v_f32_low_half, idx2, v_f32_high_half);
auto v_out0 = mm512_uni_loadu_ps(out + dst_offset + i);
auto v_out1 = mm512_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx512);
v_out0 = _mm512_fmadd_ps(attn_w_vec0, first_half, v_out0);
v_out1 = _mm512_fmadd_ps(attn_w_vec0, second_half, v_out1);
mm512_uni_storeu_ps(out + dst_offset + i, v_out0);
mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1);
}
# elif defined(HAVE_AVX2)
auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]);
for (; i + vec_len_f32_avx2 * 2 <= group_size; i += vec_len_f32_avx2 * 2) {
auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_ptr + i / 2 + src_offset + params_offset));

auto v_i32 = _mm256_cvtepi8_epi32(data);
auto v_256_low_half = _mm256_srai_epi32(v_i32, 4);
auto v_f32_low_half = _mm256_cvtepi32_ps(v_256_low_half);

auto v_256_high_half = _mm256_slli_epi32(v_i32, 28);
v_256_high_half = _mm256_srai_epi32(v_256_high_half, 28);
auto v_f32_high_half = _mm256_cvtepi32_ps(v_256_high_half);

auto v_out0 = mm256_uni_loadu_ps(out + dst_offset + i);
auto v_out1 = mm256_uni_loadu_ps(out + dst_offset + i + vec_len_f32_avx2);
__m256 first_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x20);
auto idx1 = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
first_half = _mm256_permutevar8x32_ps(first_half, idx1);
__m256 second_half = _mm256_permute2f128_ps(v_f32_low_half, v_f32_high_half, 0x31);
second_half = _mm256_permutevar8x32_ps(second_half, idx1);
v_out0 = _mm256_fmadd_ps(v256_attn_w_vec0, first_half, v_out0);
v_out1 = _mm256_fmadd_ps(v256_attn_w_vec0, second_half, v_out1);
mm256_uni_storeu_ps(out + dst_offset + i, v_out0);
mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1);
}
# endif
for (; i < group_size; i += 2) {
uint8_t data = v_ptr[i / 2 + src_offset + params_offset];
float tmp0 = extract_half_byte(data, static_cast<bool>(i % 2));
tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0;
float tmp1 = extract_half_byte(data, static_cast<bool>((i + 1) % 2));
tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1;
out[dst_offset + i] += weight[j] * (tmp0)*v0[0];
out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0];
}
dst_offset += group_size;
src_offset += group_size / sub_byte_multiplier + params_offset;
}
v_ptr += src_stride;
}
}
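
The removed block above is the SIMD version of a simple per-group fused dequantize-accumulate: each group stores a float scale followed by group_size/2 bytes of packed int4, and the output is updated as out[i] += weight * scale * q[i]. A scalar sketch of one group (editor's reconstruction; acc_value_group_i4_ref is an illustrative name):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Scalar equivalent of one group of the removed i4 attn_acc_value_block:
// read the per-group scale header, then unpack and accumulate the nibbles.
static void acc_value_group_i4_ref(float* out, float weight, const uint8_t* group, size_t group_size) {
    float scale;
    std::memcpy(&scale, group, sizeof(float));  // |scale(f32)| precedes the packed data
    const uint8_t* q = group + sizeof(float);
    for (size_t i = 0; i < group_size; i++) {
        uint8_t nibble = (i % 2 == 0) ? (uint8_t)(q[i / 2] >> 4) : (uint8_t)(q[i / 2] & 0x0F);
        int v = (nibble >= 8) ? (int)nibble - 16 : (int)nibble;  // two's-complement int4 decode
        out[i] += weight * scale * (float)v;
    }
}
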

template <typename TA, typename TB>
static void dot_product_block(TA* a,
TB* b,
@@ -1296,47 +1203,6 @@ static void pack_32NxK(TDST* dst,
src_stride,
group_size);
}

template <typename TDST,
ov::element::Type_t SRC_PREC,
typename std::enable_if<precision_of<TDST>::value != ov::element::f32 && (SRC_PREC == ov::element::i4),
bool>::type = true>
static void pack_32NxK(TDST* dst,
void* src,
TDST* tmp,
const size_t N,
const size_t K,
const size_t dst_stride,
const size_t src_stride,
const size_t group_size) {
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
// feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
auto s = reinterpret_cast<uint8_t*>(src);
auto t = tmp;
// if group_size not set, the whole row is used as a group
const size_t sub_byte_mulitplier = 2;
for (size_t n = 0; n < N; n++) {
size_t src_offset = 0;
size_t dst_offset = 0;
while (dst_offset < K) {
auto f = reinterpret_cast<float*>(s + src_offset);
attn_dequant_s4_kernel(s + (src_offset + sizeof(float)), t + dst_offset, group_size, f[0]);
src_offset += group_size / sub_byte_mulitplier + sizeof(float);
dst_offset += group_size;
}
s += src_offset;
t += src_stride;
}
pack_32NxK<TDST, precision_of<TDST>::value>(dst,
tmp,
reinterpret_cast<TDST*>(0),
N,
K,
dst_stride,
src_stride,
group_size);
}
# endif
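
For completeness, the attn_dequant_s4_kernel call above widens one group of packed int4 back to the destination type before repacking. A scalar sketch of that step (editor's reconstruction of the behaviour implied by the call site, not the OpenVINO kernel itself):

#include <cstddef>
#include <cstdint>

// Scalar sketch: unpack n int4 values that follow a group's float scale
// and widen them into dst (TDST would be bf16/f16 here).
template <typename TDST>
static void attn_dequant_s4_ref(const uint8_t* src, TDST* dst, size_t n, float scale) {
    for (size_t i = 0; i < n; i++) {
        uint8_t nibble = (i % 2 == 0) ? (uint8_t)(src[i / 2] >> 4) : (uint8_t)(src[i / 2] & 0x0F);
        int v = (nibble >= 8) ? (int)nibble - 16 : (int)nibble;
        dst[i] = static_cast<TDST>(scale * (float)v);
    }
}
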

template <typename TDST,
@@ -2085,7 +1951,7 @@ struct MHA {
auto Hk = v_cache.m_dims[1];

constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
- constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == precision_of<KEY_CACHE_TYPE>::value;
+ constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == VALUE_PREC;
auto attn_work_count = _workitems.attn_work_size();
auto reorder_work_count = _workitems.reorder_work_size();

