From 51008b0ca2aab97a24aec4f10fbf4f499e437f12 Mon Sep 17 00:00:00 2001 From: Maksim Kutakov Date: Fri, 29 Nov 2024 19:46:27 +0100 Subject: [PATCH] Add SDPA Transpose fusing to the pipeline for half precision --- .../transformation_pipeline.cpp | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 836a5f5e25b531..4c15bc95f568b2 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -881,11 +881,28 @@ void Transformations::PostLpt() { // If the SDPA patterns haven't been fused into the special CPU optimized SDPA nodes, we have to decompose // these layers and run some auxilary transformation passes to let the snippets handle SDPA ops - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ScaledDotProductAttentionDecomposition); - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConvertConvertLike); - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConstantFolding); - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::MoveEltwiseUpThroughDataMov); - CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::TransposeMatMul); + if (!one_of(config.inferencePrecision, element::bf16, element::f16)) { // So far Snippets don't support AMX MHA + CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ScaledDotProductAttentionDecomposition); + CPU_SET_CALLBACK_COMMON(postLPTPassManager, [](const_node_ptr& node) { + // So far Snippets don't support AMX MHA + constexpr size_t QKV_inpt_number = 3ul; + for (size_t i = 0; i < QKV_inpt_number; ++i) { + if (one_of(node->get_input_element_type(i), element::bf16, element::f16)) { + return true; + } + } + return false; + }, + ov::pass::ScaledDotProductAttentionDecomposition); + + CPU_REGISTER_PASS_COMMON(postLPTPassManager, 
ov::pass::ConvertConvertLike); + CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConstantFolding); + CPU_REGISTER_PASS_COMMON(postLPTPassManager, + ov::pass::MoveEltwiseUpThroughDataMovScalar, + std::vector{ov::op::v1::Transpose::get_type_info_static()}); + CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::TransposeMatMul); + } + CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false); CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);