diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp
index 72fc692fff5d70..7190074c8ae30b 100644
--- a/src/common/snippets/src/op/brgemm.cpp
+++ b/src/common/snippets/src/op/brgemm.cpp
@@ -87,7 +87,8 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
     const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
     const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
     const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
-    if (is_f32 || is_bf16) {
+    const bool is_f16 = utils::everyone_is(element::f16, in_type0, in_type1);
+    if (is_f32 || is_bf16 || is_f16) {
         return element::f32;
     } else if (is_int8) {
         return element::i32;
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp
index 96a80153bba4b6..d937e646b603da 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp
@@ -17,13 +17,12 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {
                               const ov::snippets::lowered::ExpressionPtr& expr,
                               const snippets::KernelExecutorTablePtr& kernel_table,
                               const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
-
     size_t get_inputs_num() const override { return 1; }
     static std::set<std::vector<element::Type>> get_supported_precisions(
         const std::shared_ptr<ov::Node>& node = nullptr) {
-        return {{element::i8}, {element::bf16}, {element::f32}};
+        return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}};
     }
 private:
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
index 172a1cc0b98284..8d343cec908732 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp
@@ -79,7 +79,8 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
     } else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) {
         return {{element::i8, element::i8, element::u8},
                 {element::u8, element::i8, element::u8},
-                {element::bf16, element::bf16, element::u8}};
+                {element::bf16, element::bf16, element::u8},
+                {element::f16, element::f16, element::u8}};
     }
     OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
 }
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index 94e01cd89a39fa..2b0c7b55fb043d 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -458,11 +458,12 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
     config.inConfs.resize(inputShapes.size());
     for (size_t i = 0; i < inputShapes.size(); i++) {
         const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
-        const auto precision = ((originalInputPrecision == ov::element::f32) &&
-                                context->getConfig().inferencePrecision == ov::element::bf16 &&
-                                subgraph_attrs->snippet->has_domain_sensitive_ops())
-                                   ? static_cast<ov::element::Type>(ov::element::bf16)
-                                   : originalInputPrecision;
+        const auto precision =
+            ((originalInputPrecision == ov::element::f32) &&
+             one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
+             subgraph_attrs->snippet->has_domain_sensitive_ops())
+                ? context->getConfig().inferencePrecision
+                : originalInputPrecision;
         if (supportedPrecisions.count(precision) == 0)
             OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");
@@ -653,7 +654,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
                                          ov::snippets::pass::Canonicalization,
                                          ov::snippets::pass::AnalyzeBroadcastableInputs,
                                          broadcastable_inputs);
-    if (context->getConfig().inferencePrecision == ov::element::bf16 &&
+    if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
         subgraph_attrs->snippet->has_domain_sensitive_ops()) {
         // enforce BF16 precisions to supported operations
         // MatMul has to be decomposed to Brgemm operations before enforcement
@@ -663,7 +664,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
                                              ov::snippets::pass::MatMulToBrgemm,
                                              pass::EnforcePrecision,
                                              element::f32,
-                                             element::bf16);
+                                             context->getConfig().inferencePrecision);
     }
     SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
                                            ov::snippets::pass::PropagatePrecision,
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
index a513299a516f5f..7e52905145869f 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
@@ -91,7 +91,7 @@ void BrgemmCopyB::validate_and_infer_types() {
 }
 
 void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) {
-    OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::i8),
+    OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::f16, element::i8),
                     "BrgemmCopyB doesn't support element type " + element_type.get_type_name());
 }
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
index e1802d2914127a..386941fd94bb98 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp
@@ -35,7 +35,12 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {
     // Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details
     if (is_with_amx) {
-        SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
+        if (dt_in0 == ov::element::f16)
+            SUPPORT_ONE(avx512_core_amx_fp16,
+                        "Unsupported hardware configuration: amx fp16 is supported only on avx512_core_amx_fp16 platforms")
+        else
+            SUPPORT_ONE(avx512_core_amx,
+                        "Unsupported hardware configuration: amx is supported only on avx512 platforms")
     } else if (dt_in0 == ov::element::bf16) {
         SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
     } else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
@@ -59,12 +64,15 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp
         return transpose_b ? BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE;
 
     OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16),
-                    "BF16 precision is not supported on this hardware");
+                    "BrgemmCPU BF16 precision is not supported on systems without avx512_core_bf16");
+    OPENVINO_ASSERT(element_type_a != element::f16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16),
+                    "BrgemmCPU FP16 precision is not supported on systems without avx512_core_amx_fp16");
 
     if (one_of(element_type_a, element::u8, element::i8, element::bf16) &&
         dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx))
         return BRGEMM_TYPE::WITH_AMX;
-
+    if (element_type_a == ov::element::f16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
+        return BRGEMM_TYPE::WITH_AMX;
     // Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the
     // backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp
     if (element_type_a == ov::element::i8)
@@ -96,6 +104,8 @@ size_t compute_inner_n_block(const ov::element::Type& precision) {
         return 64;
     case element::bf16:
         return 32;
+    case element::f16:
+        return 32;
     case element::f32:
         return 16;
     default:
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
index 46428828e7139c..b5f470c1c695ba 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp
@@ -15,8 +15,8 @@ namespace intel_cpu {
 namespace brgemm_utils {
 
 enum class BRGEMM_TYPE {
-    STAND_ALONE,  // No extra requirements, used for f32|f32
-    WITH_AMX,     // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
+    STAND_ALONE,  // No extra requirements, used for f32|f32
+    WITH_AMX,  // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad
     WITH_COMPENSATIONS,  // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
     REPACKING_ONLY,  // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on
                      // second input for data repacking
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp
index 2cbf2d7e087919..9475171b24f65d 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp
@@ -25,7 +25,7 @@ namespace pass {
  *              \     Buffer (with repacked data)  Buffer (with compensations)
  *               \             |                    /
  *                         BrgemmCPU
- * - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system:
+ * - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system:
  *              \  BrgemmCopyB
  *               \     Buffer (with repacked data)   Buffer (with new memory)
  *                \            |                    /
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
index 92b5be2692f3b2..6b7d5d31a5b12f 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp
@@ -121,9 +121,12 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {
 
 std::set<std::vector<element::Type>> EnforcePrecision::get_supported_precisions_default(
     const std::shared_ptr<ov::Node>& op) noexcept {
-    if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) &&
-        ov::is_type<snippets::op::Brgemm>(op)) {
-        return {{element::bf16, element::bf16}};
+    std::set<std::vector<element::Type>> types;
+    if (ov::is_type<snippets::op::Brgemm>(op)) {
+        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
+            types.insert({element::f16, element::f16});
+        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16))
+            types.insert({element::bf16, element::bf16});
     }
-    return {};
+    return types;
 }
diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 909f6b7531d421..4013c1c3cd84f9 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -866,7 +866,7 @@ void Transformations::PostLpt() {
         postLPTPassManager,
         [](const std::shared_ptr<const ov::Node>& node) -> bool {
             if (!ov::is_type(node) &&
-                node->get_output_element_type(0) != node->get_input_element_type(0))
+                node->get_output_element_type(0).size() > node->get_input_element_type(0).size())
                 return true;
             if (node->get_input_size() >= 2) {
                 return node->get_input_element_type(1) == ov::element::i8 ||
@@ -986,7 +986,7 @@ void Transformations::MainSnippets(void) {
     // MatMul and Result. However there may be Convert [f32->bf16] before Result since:
     // - bf16 Brgemm has f32 output;
     // - CPU Node Subgraph requires bf16 on output when inference precision is bf16.
-    // To avoid sitations when Transpose is not alone node between MatMul and Result,
+    // To avoid situations when Transpose is not the only node between MatMul and Result,
    // Plugin disables Transpose tokenization on output
    bool mha_token_enable_transpose_on_output = one_of(config.inferencePrecision, element::f32, element::undefined);
    size_t concurrency = config.streamExecutorConfig.get_threads_per_stream();
@@ -1023,6 +1023,7 @@ void Transformations::MainSnippets(void) {
     ov::pass::Manager snippetsManager("CPU:Snippets");
     snippetsManager.set_per_pass_validation(false);
+    // If a callback is needed for better performance, enable SnippetsMarkSkipped and disable TokenizeFCSnippets.
     if (!ignoreCallback) {
 #if defined(OPENVINO_ARCH_ARM64)
         CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped);
@@ -1033,9 +1034,7 @@ void Transformations::MainSnippets(void) {
     }
     CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
-    // - MHA has BRGEMM that is supported only on AVX512 platforms
-    // - CPU Plugin Subgraph supports only f32, bf16 (and quantized) BRGEMM
-    // [122494] Need to add support of f16
+    // - CPU Plugin Subgraph supports f32, bf16, quantized and f16 (on the avx512_core_amx_fp16 target) BRGEMM
     const bool isMHASupported =
 #if defined(OPENVINO_ARCH_ARM64)
         false;
 #else
@@ -1043,7 +1042,9 @@ void Transformations::MainSnippets(void) {
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
          one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
         (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
-         one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined));
+         one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
+        (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
+         one_of(config.inferencePrecision, ov::element::f16));
 #endif
     if (!isMHASupported) {
         CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
@@ -1059,13 +1060,13 @@ void Transformations::MainSnippets(void) {
         const auto in_type1 = matmul->get_input_element_type(1);
         const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 &&
                               one_of(config.inferencePrecision, element::f32, element::undefined));
-        const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16);
+        const auto is_fp16 =
+            (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
+            (in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16);
         const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
                              ((in_type0 == element::f32 && in_type1 == ov::element::f32 &&
                                config.inferencePrecision == ov::element::bf16));
         const auto is_int8 = in_type0 == ov::element::i8;
-        if (is_fp16)
-            return false;
         if (is_fp32)
             return true;
         // Only FP32 dynamic MHA is supported
@@ -1076,13 +1077,14 @@ void Transformations::MainSnippets(void) {
         // brgemm_copy_b kernel
         if (matmul->get_transpose_a() || matmul->get_transpose_b())
             return false;
-        // [150842] The execution of Brgemm INT8/BF16 on AMX platforms depends on the value of "K % VNNIFactor".
+        // [150842] The execution of Brgemm INT8/BF16/FP16 on AMX platforms depends on the value of "K % VNNIFactor".
         // For more details, please take a look at the ticket 150842
         if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
             const auto& b_shape = matmul->get_input_partial_shape(1);
             const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin();
-            if (is_bf16)
-                return K.is_static() && (K.get_length() % 2 == 0);
+            const size_t brgemm_vnni_factor_for_real16 = 2;  // 4-byte VNNI granularity / 2-byte bf16/f16 element size
+            if (is_bf16 || is_fp16)
+                return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0);
             if (is_int8)
                 return K.is_static();
         }
@@ -1091,6 +1093,8 @@
             dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni);
         if (is_bf16)
             return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
+        if (is_fp16)
+            return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
         return true;
     };
     auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr<const ov::Node>& n,
diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
index 8517612a348f68..a94f52be91df02 100644
--- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp
@@ -189,7 +189,7 @@ class MHATest : public testing::WithParamInterface, virtual public Sub
     for (size_t i = 0; i < funcInputs.size(); ++i) {
         const auto& funcInput = funcInputs[i];
         ov::Tensor tensor;
-        if (funcInput.get_element_type() == ov::element::bf16) {
+        if (funcInput.get_element_type() == ov::element::bf16 || funcInput.get_element_type() == ov::element::f16) {
             ov::test::utils::InputGenerateData in_data;
             in_data.start_from = -1;
             in_data.range = 2;
@@ -232,6 +232,9 @@
         configuration.insert({ov::hint::inference_precision(ov::element::bf16)});
     }
 
+    if (inputPrecisions[0] == ElementType::f16)
+        configuration.insert({ov::hint::inference_precision(ov::element::f16)});
+
     // Snippets MHA tokenization has limitations to avoid performance degradations. These limitations depend on
     // target machine. Just for testing, we disable these limitations to allow Snippets to tokenize pattern on all
     // machines for validation.
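Note: the hunk above forces the f16 inference-precision hint from inside the test fixture. Outside the test suite, the same behavior is reached through the regular OpenVINO runtime API. A minimal sketch (the model path and file are hypothetical; the properties API is the real one):

    #include <openvino/openvino.hpp>

    int main() {
        ov::Core core;
        // Hypothetical IR containing an MHA pattern that Snippets can tokenize.
        auto model = core.read_model("mha_model.xml");
        // With inference precision f16, the CPU plugin enforces f16 inside
        // tokenized MHA subgraphs on avx512_core_amx_fp16-capable machines.
        auto compiled = core.compile_model(model, "CPU", ov::hint::inference_precision(ov::element::f16));
        return 0;
    }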
@@ -253,6 +256,9 @@ TEST_P(MHATest, CompareWithRefs) {
     if (inputPrecisions[0] == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
         GTEST_SKIP();
 
+    if (inputPrecisions[0] == ElementType::f16 && !ov::with_cpu_x86_avx512_core_amx_fp16())
+        GTEST_SKIP();
+
     if (!ov::with_cpu_x86_avx512_core())
         GTEST_SKIP();
@@ -308,6 +314,20 @@ INSTANTIATE_TEST_SUITE_P(
         ::testing::Values(ov::test::utils::DEVICE_CPU)),
     MHATest::getTestCaseName);
 
+INSTANTIATE_TEST_SUITE_P(
+    smoke_MHA_FP16,
+    MHATest,
+    ::testing::Combine(
+        ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
+        ::testing::Values(
+            std::vector<ElementType>{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}),
+        ::testing::ValuesIn(matMulIn0Precisions),
+        ::testing::ValuesIn(patternTypes),
+        ::testing::Values(ExpectedNodes{{"Subgraph", 1},
+                                        {"Transpose", 1}}),  // Plugin disables tokenization of Transpose on output
+        ::testing::Values(ov::test::utils::DEVICE_CPU)),
+    MHATest::getTestCaseName);
+
 }  // namespace
 
 static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes,
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
index 089a03b4d6bba7..e9b38fedc0b4e5 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp
@@ -565,6 +565,11 @@ std::vector<std::string> disabledTestPatterns() {
         retVector.emplace_back(R"(.*smoke_Snippets_MHA.*EnforceBF16.*)");
         retVector.emplace_back(R"(.*ConcatSDPTest.*bf16.*)");
     }
+    // MHA FP16 precision is only supported on platforms with amx_fp16 support
+    if (!ov::with_cpu_x86_avx512_core_amx_fp16()) {
+        retVector.emplace_back(R"(.*smoke_Snippets_MHA.*FP16.*)");
+    }
+
 #ifdef SNIPPETS_LIBXSMM_TPP
     // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234)
     retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)");
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
index 63f5176684ccc1..df0b69f99ef06d 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp
@@ -12,32 +12,41 @@ namespace snippets {
 
 namespace {
 
-std::vector<std::vector<InputShape>> transposedShape_4D(bool with_dynamic = true) {
-    auto shapes = SNIPPETS_TESTS_STATIC_SHAPES(
-        {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
-        {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
-        {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
-        {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
-        {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
+std::vector<std::vector<InputShape>> transposedShape_4D(bool with_static = true, bool with_dynamic = true) {
+    std::vector<std::vector<InputShape>> shapes;
+    if (with_static) {
+        auto static_shapes =
+            SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
+                                         {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
+                                         {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
+                                         {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
+                                         {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
+        shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end());
+    }
     if (with_dynamic) {
-        std::vector<std::vector<InputShape>> dynamic_shapes = {{
-            {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
-            {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
-            {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
-            {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
-        },
-        {
-            {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
-            {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
-            {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
-            {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
-        },
-        {
-            {PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
-            {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
-            {PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
-            {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
-        }};
+        std::vector<std::vector<InputShape>> dynamic_shapes = {
+            {
+                {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
+                {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
+                {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}},
+                {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}},
+            },
+            {
+                {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}},
+                {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}},
+                {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}},
+                {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}},
+            },
+            {
+                {PartialShape{-1, -1, 12, 64},
+                 {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}},
+                {PartialShape{-1, -1, 12, 64},
+                 {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}},
+                {PartialShape{-1, 12, -1, -1},
+                 {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}},
+                {PartialShape{-1, -1, 12, 64},
+                 {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}},
+            }};
         shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end());
     }
     return shapes;
@@ -74,7 +83,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D,
 
 INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D_WithScalarMul,
                          MHA,
-                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
                                             ::testing::ValuesIn(precision_f32(4)),
                                             ::testing::Values(ov::element::f32),
                                             ::testing::Values(true),
@@ -137,6 +146,80 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
                                            ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
                          MHA::getTestCaseName);
 
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_fp16_if_supported(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::empty_plugin_config)),
+                         MHA::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Static,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
+                                            ::testing::ValuesIn(precision_fp16_if_supported(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({true}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::empty_plugin_config)),
+                         MHA::getTestCaseName);
+// 3 nodes and 2 subgraphs are expected for the dynamic with-Multiply case.
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Dynamic,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)),
+                                            ::testing::ValuesIn(precision_fp16_if_supported(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({true}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(3),
+                                            ::testing::Values(2),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::empty_plugin_config)),
+                         MHA::getTestCaseName);
+
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({false}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
+                         MHA::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({true}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(2),
+                                            ::testing::Values(1),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
+                         MHA::getTestCaseName);
+INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic,
+                         MHA,
+                         ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)),
+                                            ::testing::ValuesIn(precision_f32(4)),
+                                            ::testing::Values(ov::element::f16),
+                                            ::testing::ValuesIn({true}),
+                                            ::testing::Values(MHA::default_thread_count),
+                                            ::testing::Values(3),
+                                            ::testing::Values(2),
+                                            ::testing::Values(ov::test::utils::DEVICE_CPU),
+                                            ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
+                         MHA::getTestCaseName);
 }  // namespace
 }  // namespace snippets
 }  // namespace test
diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
index 6c0d54da973086..6815cdab671cea 100644
--- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
+++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp
@@ -16,6 +16,10 @@ static inline bool is_bf16_supported_by_brgemm() {
     return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16();
 }
 
+static inline bool is_fp16_supported_by_brgemm() {
+    return ov::with_cpu_x86_avx512_core_amx_fp16();
+}
+
 static inline bool is_i8_supported_by_brgemm() {
     return ov::with_cpu_x86_avx512_core_vnni() || ov::with_cpu_x86_avx512_core_amx_int8();
 }
@@ -33,6 +37,13 @@ static inline std::vector<std::vector<element::Type>> precision_bf16_if_supporte
     return prc;
 }
 
+static inline std::vector<std::vector<element::Type>> precision_fp16_if_supported(size_t count) {
+    std::vector<std::vector<element::Type>> prc;
+    if (is_fp16_supported_by_brgemm())
+        prc.emplace_back(std::vector<element::Type>(count, element::f16));
+    return prc;
+}
+
 static inline std::vector<std::vector<element::Type>> quantized_precisions_if_supported() {
     std::vector<std::vector<element::Type>> prc = {};
     // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
index 8d0cb8613bc47e..0a8fcc77717c42 100644
--- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp
+++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp
@@ -65,6 +65,8 @@ void MHABase::SetUp() {
 #endif
     if (inType == ov::element::bf16)
         rel_threshold = 0.05f;
+    if (inType == ov::element::f16)
+        abs_threshold = 2e-2;
 }
 
 std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
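Note: the AMX blocking constraint referenced in transformation_pipeline.cpp ("K % VNNIFactor") can be summarized in isolation. A simplified sketch of the check, not the plugin's actual helper (the function and parameter names are illustrative):

    #include <openvino/core/dimension.hpp>

    // For AMX, 2-byte types (bf16/f16) require a static K divisible by the VNNI
    // factor: 4-byte dot-product granularity / 2-byte element size = 2.
    // int8 requires only a static K; f32 carries no such restriction.
    bool is_k_supported_on_amx(bool is_real16, bool is_int8, const ov::Dimension& K) {
        const int64_t vnni_factor_for_real16 = 2;
        if (is_real16)
            return K.is_static() && (K.get_length() % vnni_factor_for_real16 == 0);
        if (is_int8)
            return K.is_static();
        return true;
    }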