[Snippets] Fixed SplitDimensionM pass for Subgraphs with dynamic params (#28280)

### Details:
- *Currently, the `SplitDimensionM` pass supports only Subgraphs with
static inputs, because it inserts `Reshape` ops with constant shapes. In
some cases (such as the one from the ticket), MatMul may have a static
output shape while some Subgraph parameters have dynamic shapes. In that
case `SplitDimensionM` must not call the `split` method. This PR adds a
check with an early `return` to cover such cases.*
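
The guard described above can be sketched in isolation. The snippet below is a simplified illustration and not OpenVINO code: `ToyShape`, `is_dynamic`, and `can_split_m` are hypothetical names, and a dimension of `-1` is assumed to mark a dynamic dimension, as in the test shapes of this commit.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a partial shape: -1 marks a dynamic dimension.
using ToyShape = std::vector<std::int64_t>;

// A shape is dynamic if at least one of its dimensions is unknown.
bool is_dynamic(const ToyShape& shape) {
    return std::any_of(shape.cbegin(), shape.cend(),
                       [](std::int64_t dim) { return dim < 0; });
}

// Mirrors the idea of the early return: splitting is attempted only
// when every parameter shape is fully static.
bool can_split_m(const std::vector<ToyShape>& param_shapes) {
    return std::none_of(param_shapes.cbegin(), param_shapes.cend(), is_dynamic);
}
```

With the parameter shapes from the new test, `{1, 16, 128, -1}` makes `can_split_m` return `false`, which corresponds to the pass returning early without inserting any `Reshape`.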

### Tickets:
 - *159661*
a-sidorova authored Jan 8, 2025
1 parent 3526fa5 commit 6739444
Showing 2 changed files with 15 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/common/snippets/src/pass/split_dimension_m.cpp
@@ -285,6 +285,12 @@ bool SplitDimensionM::run_on_subgraph(const std::shared_ptr<op::Subgraph>& subgr
    if (!subgraph->has_domain_sensitive_ops())
        return false;

    // The pass supports only static shapes on Subgraph inputs due to static `Reshape` insertion around Subgraph.
    const auto& params = subgraph->body_ptr()->get_parameters();
    const auto is_dynamic = [](const std::shared_ptr<ov::Node>& p) { return p->get_output_partial_shape(0).is_dynamic(); };
    if (std::any_of(params.cbegin(), params.cend(), is_dynamic))
        return false;

    if (const auto matmul0 = get_matmul(subgraph)) {
        const auto mm_shape = matmul0->get_shape();
        size_t batch_m_dim, new_m_dim;
9 changes: 9 additions & 0 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
@@ -221,6 +221,15 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) {
    const auto &f = MHAFunction(std::vector<PartialShape>{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 128, -1}, {1, 128, 16, 64}},
                                std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), false, false);
    model = f.getOriginal();
    model_ref = f.getReference();
    config.set_concurrency(32);
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
    const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
                                            std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},