Skip to content

Commit

Permalink
[TRANSFORMATIONS] Derive 'scale' from hidden_dim directly in SDPAToPA (
Browse files Browse the repository at this point in the history
…openvinotoolkit#28091)

Currently 'scale' is obtained using a ShapeOf expression as the
hidden_dim may be dynamic in some cases and not propagated, so we can't
use it directly to create a 'scale' Constant.

Check if hidden_dim is static and use it to calculate 'scale' directly
omitting the ShapeOf expression.

Ticket:
* [CVS-158394](https://jira.devtools.intel.com/browse/CVS-158394)

Signed-off-by: Andrii Staikov <[email protected]>
  • Loading branch information
CuriousPanCake authored Dec 19, 2024
1 parent 60d7264 commit 6acc929
Showing 1 changed file with 18 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -310,20 +310,28 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
auto v_reshape =
std::make_shared<v1::Reshape>(v_target_layout, v0::Constant::create(element::i64, Shape{2}, {0, -1}), true);

auto hidden_shape = std::make_shared<v3::ShapeOf>(real_q);
auto hidden_dim = std::make_shared<v8::Gather>(hidden_shape,
v0::Constant::create(element::i64, Shape{}, {-1}),
v0::Constant::create(element::i64, Shape{}, {0}));
std::shared_ptr<ov::Node> scale;
if (pattern_map.count(scale_input)) {
scale = pattern_map.at(scale_input).get_node_shared_ptr();
} else {
// most likely `scale` below will always be a constant in real inference, but dynamic dimension
// propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
// instead of just a constant node representing one of the dimensions.
scale = std::make_shared<v1::Divide>(
v0::Constant::create(element::f32, Shape{}, {1}),
std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
auto real_q_ps = real_q.get_partial_shape();

bool rank_is_static = real_q_ps.rank().is_static();
if (rank_is_static && real_q_ps[real_q_ps.rank().get_length() - 1].is_static()) {
auto hidden_dim_len = static_cast<float>(real_q_ps[real_q_ps.rank().get_length() - 1].get_length());
scale = v0::Constant::create(element::f32, Shape{}, {1.0 / std::sqrt(hidden_dim_len)});
} else {
// most likely `scale` below will always be a constant in real inference, but dynamic dimension
// propagation may not always derive it as a constant. That's why a sub-graph computing `scale` is built
// instead of just a constant node representing one of the dimensions.
auto hidden_shape = std::make_shared<v3::ShapeOf>(real_q);
auto hidden_dim = std::make_shared<v8::Gather>(hidden_shape,
v0::Constant::create(element::i64, Shape{}, {-1}),
v0::Constant::create(element::i64, Shape{}, {0}));
scale = std::make_shared<v1::Divide>(
v0::Constant::create(element::f32, Shape{}, {1}),
std::make_shared<v0::Sqrt>(std::make_shared<v0::Convert>(hidden_dim, element::f32)));
}
}

std::shared_ptr<Node> alibi_slopes;
Expand Down

0 comments on commit 6acc929

Please sign in to comment.