Skip to content

Commit

Permalink
[CPU] apply review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Zhang Yi <[email protected]>
  • Loading branch information
zhangYiIntel committed Jan 3, 2025
1 parent dddb4d9 commit 244f7cc
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 162 deletions.
57 changes: 28 additions & 29 deletions src/plugins/intel_cpu/src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,43 +373,42 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
ov::hint::kv_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::hint::key_cache_precision.name() || key == ov::hint::value_cache_precision.name()) {
} else if (key == ov::hint::key_cache_precision.name()) {
try {
kvCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (key == ov::hint::key_cache_precision.name()) {
if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
keyCachePrecision = prec;
} else {
OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
}
if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
keyCachePrecision = prec;
} else {
if (one_of(prec,
ov::element::f32,
ov::element::f16,
ov::element::bf16,
ov::element::u8,
ov::element::u4,
ov::element::i4)) {
valueCachePrecision = prec;
} else {
OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
}
OPENVINO_THROW("keyCachePrecision doesn't support value ", prec);
}
} catch (ov::Exception&) {
if (key == ov::hint::key_cache_precision.name()) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::key_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::key_cache_precision.name(),
". Supported values: u8, bf16, f16, f32");
}
} else if (key == ov::hint::value_cache_precision.name()) {
try {
kvCachePrecisionSetExplicitly = true;
auto const prec = val.as<ov::element::Type>();
if (one_of(prec,
ov::element::f32,
ov::element::f16,
ov::element::bf16,
ov::element::u8,
ov::element::u4)) {
valueCachePrecision = prec;
} else {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::value_cache_precision.name(),
". Supported values: u4, s4, u8, bf16, f16, f32");
OPENVINO_THROW("valueCachePrecision doesn't support value ", prec);
}
} catch (ov::Exception&) {
OPENVINO_THROW("Wrong value ",
val.as<std::string>(),
" for property key ",
ov::hint::value_cache_precision.name(),
". Supported values: u4, s4, u8, bf16, f16, f32");
}
} else if (key == ov::hint::key_cache_group_size.name() || key == ov::hint::value_cache_group_size.name()) {
try {
Expand Down
163 changes: 33 additions & 130 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp)
_mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
}
#endif
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
#if defined(HAVE_AVX2)
auto v256_zero = _mm256_set1_epi32(0);
auto v256_upper = _mm256_set1_epi32(15);
auto v256_scale = _mm256_set1_ps(1 / scale);
Expand Down Expand Up @@ -273,7 +273,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp)
}

template <typename T>
static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
static void quant_i4(const T* src, void* dst, size_t n, float& scale) {
auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t {
uint8_t shift = high_half ? 0 : 4;
if (high_half)
Expand Down Expand Up @@ -318,7 +318,7 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
_mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined);
}
#endif
#if defined(HAVE_AVX2) || defined(HAVE_AVX512F)
#if defined(HAVE_AVX2)
auto v256_lower = _mm256_set1_epi32(-8);
auto v256_upper = _mm256_set1_epi32(7);
auto v256_scale = _mm256_set1_ps(1 / scale);
Expand Down Expand Up @@ -372,6 +372,27 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) {
}
}

// Dispatch helper selected when the destination precision is u8: forwards to the
// u8 quantization kernel, which writes the per-group scale to scale_zp[0] and the
// zero-point to scale_zp[1] while emitting quantized bytes into dst.
template <typename T,
          ov::element::Type_t DST_PREC,
          std::enable_if_t<DST_PREC == ov::element::u8, bool> = true>
static void quantize(const T* src, uint8_t* dst, size_t n, float* scale_zp) {
    quant_u8(src, dst, n, scale_zp[0], scale_zp[1]);
}

// Dispatch helper selected when the destination precision is u4: forwards to the
// u4 quantization kernel, which writes the per-group scale to scale_zp[0] and the
// zero-point to scale_zp[1]; dst is untyped because u4 packs two values per byte.
template <typename T,
          ov::element::Type_t DST_PREC,
          std::enable_if_t<DST_PREC == ov::element::u4, bool> = true>
static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
    quant_u4(src, dst, n, scale_zp[0], scale_zp[1]);
}

// Dispatch helper selected when the destination precision is i4: forwards to the
// signed-4-bit quantization kernel. Only a scale is produced (scale_zp[0]) — the
// symmetric i4 scheme has no zero-point, unlike the u8/u4 overloads.
template <typename T,
          ov::element::Type_t DST_PREC,
          std::enable_if_t<DST_PREC == ov::element::i4, bool> = true>
static void quantize(const T* src, void* dst, size_t n, float* scale_zp) {
    quant_i4(src, dst, n, scale_zp[0]);
}

template <typename T, typename T2>
static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
Expand All @@ -389,10 +410,7 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
});
}

template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::u8, bool>::type = true>
template <typename T, ov::element::Type_t KEY_DST_PREC, ov::element::Type_t VALUE_DST_PREC>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
Expand All @@ -402,6 +420,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const size_t value_group_size) {
size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
size_t block_size = k_dst.m_dims[2];
size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
auto slot = slot_mapping.ptr<int32_t>(b)[m];
if (slot < 0)
Expand All @@ -418,76 +437,15 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
key_group_size,
p_k[0],
p_k[1]);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
src_offset += value_group_size, dst_offset += value_group_size + sizeof(float) + sizeof(float)) {
auto p_v = reinterpret_cast<float*>(
v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(v_src.ptr<T>(b, h, m, src_offset),
v_dst.ptr<typename ov::element_type_traits<VALUE_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
value_group_size,
p_v[0],
p_v[1]);
}
});
}

template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::u4, bool>::type = true>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
const ov::intel_cpu::PlainTensor& v_dst,
const ov::intel_cpu::PlainTensor& slot_mapping,
const size_t key_group_size,
const size_t value_group_size) {
size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
size_t block_size = k_dst.m_dims[2];
size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
auto slot = slot_mapping.ptr<int32_t>(b)[m];
if (slot < 0)
return;
auto block_number = slot / block_size;
auto block_offset = slot % block_size;
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
// feature(u8,idx_S)|
for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(
quantize<T, KEY_DST_PREC>(
k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
key_group_size,
p_k[0],
p_k[1]);
dst_offset) +
sizeof(float) + sizeof(float),
key_group_size,
p_k);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += value_group_size,
Expand All @@ -499,62 +457,7 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
dst_offset);
auto p_v = reinterpret_cast<float*>(v_base);
uint8_t* v_ptr = v_base + sizeof(float) * 2;
quant_u4(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v[0], p_v[1]);
}
});
}

template <typename T,
ov::element::Type_t KEY_DST_PREC,
ov::element::Type_t VALUE_DST_PREC,
typename std::enable_if<VALUE_DST_PREC == ov::element::i4, bool>::type = true>
static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src,
const ov::intel_cpu::PlainTensor& v_src,
const ov::intel_cpu::PlainTensor& k_dst,
const ov::intel_cpu::PlainTensor& v_dst,
const ov::intel_cpu::PlainTensor& slot_mapping,
const size_t key_group_size,
const size_t value_group_size) {
size_t B = k_src.m_dims[0], H = k_src.m_dims[1], L1 = k_src.m_dims[2], S = k_src.m_dims[3], SV = v_src.m_dims[3];
size_t block_size = k_dst.m_dims[2];
size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth();
parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) {
auto slot = slot_mapping.ptr<int32_t>(b)[m];
if (slot < 0)
return;
auto block_number = slot / block_size;
auto block_offset = slot % block_size;
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
// feature(u8,idx_S)|
for (size_t src_offset = 0, dst_offset = 0; src_offset < S;
src_offset += key_group_size, dst_offset += key_group_size + sizeof(float) + sizeof(float)) {
auto p_k = reinterpret_cast<float*>(
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset));
quant_u8(k_src.ptr<T>(b, h, m, src_offset),
k_dst.ptr<typename ov::element_type_traits<KEY_DST_PREC>::value_type>(block_number,
h,
block_offset,
dst_offset) +
sizeof(float) + sizeof(float),
key_group_size,
p_k[0],
p_k[1]);
}

for (size_t src_offset = 0, dst_offset = 0; src_offset < SV;
src_offset += value_group_size, dst_offset += value_group_size / sub_byte_multiplier + sizeof(float)) {
uint8_t* v_base = reinterpret_cast<uint8_t*>(
v_dst.m_ptr.get() +
(block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) /
sub_byte_multiplier +
dst_offset);
auto p_v = reinterpret_cast<float*>(v_base);
uint8_t* v_ptr = v_base + sizeof(float);
quant_s4(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v[0]);
quantize<T, VALUE_DST_PREC>(v_src.ptr<T>(b, h, m, src_offset), v_ptr, value_group_size, p_v);
}
});
}
Expand Down
9 changes: 7 additions & 2 deletions src/plugins/intel_cpu/src/nodes/paged_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,13 @@ bool PagedAttention::isSupportedOperation(const std::shared_ptr<const ov::Node>&
try {
auto vCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_VCACHE);
auto kCachePrecision = op->get_input_element_type(PagedAttentionExecutor::ID_KCACHE);
if (one_of(vCachePrecision, ov::element::i4, ov::element::u4, ov::element::u8)) {
if (kCachePrecision != ov::element::u8) {
if (one_of(vCachePrecision,
ov::element::u4,
ov::element::u8,
ov::element::f32,
ov::element::f16,
ov::element::bf16)) {
if (!one_of(kCachePrecision, ov::element::u8, ov::element::f16, ov::element::f32, ov::element::bf16)) {
errorMessage = "PageAttn key value cache compression doesn't support key cache prec " +
kCachePrecision.to_string() + " value cache prec " + vCachePrecision.to_string();
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,7 @@ void ScaledDotProductAttention::updatePastkv(const MemoryPtr& mem_cur_k, const M
m_v_state->assign_internal_state(new_internal_mem_v);
m_k_state->assign_internal_state_max_size(2 * (L0 + L1) * B * H * S);
m_v_state->assign_internal_state_max_size(2 * (L0 + L1) * B * H * SV);
if (kvcache_precision == ov::element::u8) {
if (kvcache_precision == ov::element::u8) {
auto& old_scale_zp_k = m_k_state->get_scale_zp();
auto& old_scale_zp_v = m_v_state->get_scale_zp();
PlainTensor new_scale_zp_k, new_scale_zp_v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,36 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckKVCachePrecision) {
ASSERT_EQ(kv_cache_precision_value, ov::element::f32);
}

// Check that key and value cache precisions can be tuned independently and that
// the compiled model reports back exactly the values that were set.
TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCachePrecision) {
    ov::Core core;

    core.set_property(deviceName, ov::hint::key_cache_precision(ov::element::f16));
    core.set_property(deviceName, ov::hint::value_cache_precision(ov::element::u4));
    ov::CompiledModel compiledModel = core.compile_model(model, deviceName);

    auto keyPrec = ov::element::undefined;
    auto valuePrec = ov::element::undefined;
    OV_ASSERT_NO_THROW(keyPrec = compiledModel.get_property(ov::hint::key_cache_precision));
    OV_ASSERT_NO_THROW(valuePrec = compiledModel.get_property(ov::hint::value_cache_precision));
    ASSERT_EQ(keyPrec, ov::element::f16);
    ASSERT_EQ(valuePrec, ov::element::u4);
}

// Check that key and value cache group sizes can be tuned independently and that
// the compiled model reports back exactly the values that were set.
TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkFinetuneKVCacheGroupSize) {
    ov::Core core;

    core.set_property(deviceName, ov::hint::key_cache_group_size(32));
    core.set_property(deviceName, ov::hint::value_cache_group_size(16));
    ov::CompiledModel compiledModel = core.compile_model(model, deviceName);

    auto keyGroupSize = 0;
    auto valueGroupSize = 0;
    OV_ASSERT_NO_THROW(keyGroupSize = compiledModel.get_property(ov::hint::key_cache_group_size));
    OV_ASSERT_NO_THROW(valueGroupSize = compiledModel.get_property(ov::hint::value_cache_group_size));
    ASSERT_EQ(keyGroupSize, 32);
    ASSERT_EQ(valueGroupSize, 16);
}

TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantizationGroupSize) {
ov::Core core;

Expand Down

0 comments on commit 244f7cc

Please sign in to comment.