From 2643652cccf705237dc6bfa25337da6f3e40e827 Mon Sep 17 00:00:00 2001
From: Xiuchuan Zhai
Date: Tue, 24 Sep 2024 17:58:39 +0800
Subject: [PATCH] fix sdpa fusion failed in f16 case

---
 .../intel_cpu/src/transformations/transformation_pipeline.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
index 38649b2906e9e3..52c41209e895e3 100644
--- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
+++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp
@@ -389,8 +389,10 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
         precisions_map fp_convert_precision_map = {{ov::element::f32, ov::element::f16}};
 #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
         type_to_fuse_map fuse_map = {{ov::opset1::FakeQuantize::get_type_info_static(), fuse_type_to_fq}};
+        constexpr bool cvt_input_output_precision = true;
 #else
         type_to_fuse_map fuse_map = {};
+        constexpr bool cvt_input_output_precision = false;
 #endif
         const bool keep_precision_sensitive_in_fp32 = true;
         CPU_REGISTER_PASS_COMMON(manager,
@@ -398,7 +400,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
                                  fp_convert_precision_map,
                                  fuse_map,
                                  keep_precision_sensitive_in_fp32,
-                                 false);
+                                 cvt_input_output_precision);
     }
     CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
     CPU_SET_CALLBACK_COMMON(manager,