Skip to content

Commit

Permalink
Add SDPA Traspose fusing to the pipeline for half precision
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Nov 29, 2024
1 parent d959369 commit 51008b0
Showing 1 changed file with 22 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -881,11 +881,28 @@ void Transformations::PostLpt() {
// If the SDPA patterns haven't been fused into the special CPU optimized SDPA nodes, we have to decompose
// these layers and run some auxilary transformation passes to let the snippets handle SDPA ops

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ScaledDotProductAttentionDecomposition);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConvertConvertLike);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConstantFolding);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::MoveEltwiseUpThroughDataMov);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::TransposeMatMul);
if (!one_of(config.inferencePrecision, element::bf16, element::f16)) { // So far Snippets don't support AMX MHA
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ScaledDotProductAttentionDecomposition);
CPU_SET_CALLBACK_COMMON(postLPTPassManager, [](const_node_ptr& node) {
// So far Snippets don't support AMX MHA
constexpr size_t QKV_inpt_number = 3ul;
for (size_t i = 0; i < QKV_inpt_number; ++i) {
if (one_of(node->get_input_element_type(i), element::bf16, element::f16)) {
return true;
}
}
return false;
},
ov::pass::ScaledDotProductAttentionDecomposition);

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConvertConvertLike);
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConstantFolding);
CPU_REGISTER_PASS_COMMON(postLPTPassManager,
ov::pass::MoveEltwiseUpThroughDataMovScalar,
std::vector<DiscreteTypeInfo>{ov::op::v1::Transpose::get_type_info_static()});
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::TransposeMatMul);
}
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::SDPAFuseTransposeReshape);

CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RMSFusion, false);
CPU_REGISTER_PASS_X64(postLPTPassManager, ov::intel_cpu::DecomposeRMSNorm);
Expand Down

0 comments on commit 51008b0

Please sign in to comment.