diff --git a/src/common/transformations/include/transformations/utils/print_model.hpp b/src/common/transformations/include/transformations/utils/print_model.hpp
index 53fa7de51c5eca..f9849bc7eeeb41 100644
--- a/src/common/transformations/include/transformations/utils/print_model.hpp
+++ b/src/common/transformations/include/transformations/utils/print_model.hpp
@@ -287,6 +287,7 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr& model) {
         return ss.str();
     };
 
+    int _idx = 0;
     // change name convension
     std::map opname;
     std::map opname_count;
@@ -312,6 +313,7 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr& model) {
             name += std::to_string(idx);
 
         opname[op.get()] = name;
+        opname[op.get()] = op->get_type_name() + std::to_string(_idx++);
     }
 
     for (auto op : f.get_ordered_ops()) {
diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
index a36085c34237a4..d94b174a60991b 100644
--- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
+++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp
@@ -19,8 +19,10 @@
 #include "openvino/op/scaled_dot_product_attention.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
+#include "openvino/op/slice.hpp"
 #include "openvino/op/sqrt.hpp"
 #include "openvino/op/strided_slice.hpp"
+#include "openvino/op/subtract.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/op/variadic_split.hpp"
@@ -155,6 +157,16 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
     alibi_mask = pattern::wrap_type({alibi_mask, pattern::any_input()});
     alibi_mask = pattern::wrap_type({pattern::any_input(), alibi_mask});
 
+    // Baichuan2 13b case
+    // TODO: make it stricter, more conservative.
+    auto sub = pattern::wrap_type({pattern::any_input(), pattern::any_input()});
+    auto select = pattern::wrap_type({pattern::any_input(), pattern::any_input(), sub});
+    auto _alibi = pattern::any_input();
+    auto Unsqueeze140 = pattern::wrap_type({_alibi, pattern::any_input()});
+    auto Add141 = pattern::wrap_type({pattern::any_input(), Unsqueeze140});
+    auto Slice147 = pattern::wrap_type(
+        {Add141, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()});
+
     auto q = pattern::any_input();
     auto scale_input = pattern::any_input();
 
@@ -162,7 +174,8 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         std::make_shared(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
     auto v_to_sdpa =
         std::make_shared(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
-    auto mask_to_sdpa = std::make_shared(OutputVector{sdpa_mask, alibi_mask, pattern::any_input()});
+    auto mask_to_sdpa =
+        std::make_shared(OutputVector{sdpa_mask, alibi_mask, Slice147, pattern::any_input()});
 
     auto sdpa_with_4_inputs =
         pattern::wrap_type({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
@@ -335,7 +348,21 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
         }
 
         std::shared_ptr alibi_slopes;
-        if (pattern_map.find(alibi) != pattern_map.end()) {
+        if (pattern_map.count(_alibi)) {
+            alibi_slopes = pattern_map.at(_alibi).get_node_shared_ptr();
+            auto start = v0::Constant::create(element::i64, Shape{2}, {1, 1});
+            auto stop = v0::Constant::create(element::i64, Shape{2}, {2, 2});
+            auto step = v0::Constant::create(element::i64, Shape{2}, {1, 1});
+            auto axes = v0::Constant::create(element::i64, Shape{2}, {1, 2});
+            // Baichuan2 13b: slice [1:2] on axes 1 and 2 of the matched node's first input, then flatten to rank 1.
+            alibi_slopes = std::make_shared(alibi_slopes->input_value(0), start, stop, step, axes);
+            alibi_slopes =
+                std::make_shared(alibi_slopes, v0::Constant::create(element::i64, Shape{1}, {-1}), false);
+            // Insert the Convert only when the slopes are not already f32.
+            if (alibi_slopes->get_element_type() != element::f32) {
+                alibi_slopes = std::make_shared(alibi_slopes, element::f32);
+            }
+        } else if (pattern_map.find(alibi) != pattern_map.end()) {
             alibi_slopes = std::make_shared(pattern_map.at(alibi),
                                             v0::Constant::create(element::i64, Shape{1}, {-1}),
                                             false);
diff --git a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp
index 840309993c939a..a006964f9c00e5 100644
--- a/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp
+++ b/src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp
@@ -29,6 +29,7 @@
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/transpose.hpp"
 #include "openvino/op/unsqueeze.hpp"
+#include "openvino/pass/visualize_tree.hpp"
 #include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
 #include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
 #include "transformations/utils/gen_pattern.hpp"
@@ -616,3 +617,134 @@ TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) {
     disable_result_friendly_names_check();
     disable_rt_info_check();
 }
+
+TEST(TransformationTests, Baichuan2_13b_alibi_slopes) {
+    auto beam_idx = makeOP({}, {{"shape", PartialShape{DYN}}, {"element_type", "i32"}});
+    auto attention_mask = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, {"element_type", "i64"}});
+    auto input_ids = makeOP({}, {{"shape", PartialShape{DYN, DYN}}, {"element_type", "i64"}});
+    ParameterVector params = nodes_to_params({input_ids, beam_idx, attention_mask});
+
+    auto input_ids_shape = makeOP({input_ids}, {{"output_type", "i64"}});
+    auto batch_dim = makeOP({input_ids_shape, {0}, 0}, {{"batch_dims", 0}});
+
+    auto target_shape = makeOP({batch_dim, {40ll}, {0ll}, {128ll}}, {{"axis", 0}});
+    auto init_to_read_value = makeOP({0.000000f, target_shape}, {{"mode", "numpy"}});
+
+    auto ReadValue63 = makeOP(
+        {init_to_read_value},
+        {{"variable_id", "ID2"}, {"variable_type", "f32"}, {"variable_shape", PartialShape{DYN, 40, DYN, 128}}});
+    auto Gather65 = makeOP({ReadValue63, beam_idx, 0}, {{"batch_dims", 0}});
+
+    auto Constant83 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {1.000000f});
+    auto Convert85 = makeOP({attention_mask}, {{"destination_type", "f32"}});
+    auto Unsqueeze86 = makeOP({Convert85, 2});
+    auto Unsqueeze87 = makeOP({Convert85, 1});
+    auto Multiply88 = makeOP({Unsqueeze86, Unsqueeze87}, {{"auto_broadcast", "numpy"}});
+    auto Constant89 = makeConst(element::f32, ov::Shape({1, 1, 1}), {0.000000f});
+    auto Greater90 = makeOP({Multiply88, Constant89}, {{"auto_broadcast", "numpy"}});
+    auto ShapeOf91 = makeOP({Greater90}, {{"output_type", "i32"}});
+    auto Gather94 = makeOP({ShapeOf91, 1, 0}, {{"batch_dims", 0}});
+    auto Range96 = makeOP({0, Gather94, 1}, {{"output_type", "i32"}});
+    auto Unsqueeze97 = makeOP({Range96, 0});
+    auto Unsqueeze98 = makeOP({Range96, 1});
+    auto LessEqual99 = makeOP({Unsqueeze97, Unsqueeze98}, {{"auto_broadcast", "numpy"}});
+    auto Constant100 = makeConst(element::boolean, ov::Shape({}), {0});
+    auto Select101 = makeOP({LessEqual99, Greater90, Constant100}, {{"auto_broadcast", "numpy"}});
+    auto Subtract102 = makeOP({Unsqueeze86, Unsqueeze87}, {{"auto_broadcast", "numpy"}});
+    auto Constant103 = makeConst(element::f32, ov::Shape({1, 1, 1}), {0.000000f});
+    auto Equal104 = makeOP({Subtract102, Constant103}, {{"auto_broadcast", "numpy"}});
+    auto LogicalAnd105 = makeOP({Select101, Equal104}, {{"auto_broadcast", "numpy"}});
+    auto Unsqueeze106 = makeOP({LogicalAnd105, 1});
+
+    auto ShapeOf107 = makeOP({input_ids}, {{"output_type", "i64"}});
+    auto batch_dim_ = makeOP({ShapeOf107, {0}, 0}, {{"batch_dims", 0}});
+
+    auto ALIBI_CONST = makeConst(element::f32, ov::Shape({40, 4096, 4096}), MOCK_VALUE);
+    auto seq_len = makeOP({ShapeOf107, {1}, 0}, {{"batch_dims", 0}});
+    auto ShapeOf117 = makeOP({Gather65}, {{"output_type", "i64"}});
+    auto Gather120 = makeOP({ShapeOf117, {2}, 0}, {{"batch_dims", 0}});
+    auto Add121 = makeOP({seq_len, Gather120}, {{"auto_broadcast", "numpy"}});
+    auto Broadcast123 = makeOP({Add121, {2}}, {{"mode", "numpy"}});
+
+    // this slice expected to be replaced with Slice(alibi_const, start {1, 1}, stop {2, 2}, step {1, 1}, axes{1, 2});
+    // ALIBI_CONST content:
+    /*
+     * >>> line1 = np.reshape(line1, (40, 4096, 4096))
+    >>> print(line1[0][:][:])
+    [['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.839844' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.839844' '1.67969' ... '-inf' '-inf' '-inf']
+    ...
+    ['0' '0.839844' '1.67969' ... '3440' '-inf' '-inf']
+    ['0' '0.839844' '1.67969' ... '3440' '3440' '-inf']
+    ['0' '0.839844' '1.67969' ... '3440' '3440' '3440']]
+    >>> print(line1[1][:][:])
+    [['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.707031' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.707031' '1.41406' ... '-inf' '-inf' '-inf']
+    ...
+    ['0' '0.707031' '1.41406' ... '2896' '-inf' '-inf']
+    ['0' '0.707031' '1.41406' ... '2896' '2896' '-inf']
+    ['0' '0.707031' '1.41406' ... '2896' '2896' '2896']]
+    >>> print(line1[3][:][:])
+    [['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.5' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.5' '1' ... '-inf' '-inf' '-inf']
+    ...
+    ['0' '0.5' '1' ... '2048' '-inf' '-inf']
+    ['0' '0.5' '1' ... '2048' '2048' '-inf']
+    ['0' '0.5' '1' ... '2048' '2048' '2048']]
+    >>> print(line1[4][:][:])
+    [['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.419922' '-inf' ... '-inf' '-inf' '-inf']
+    ['0' '0.419922' '0.839844' ... '-inf' '-inf' '-inf']
+    ...
+    ['0' '0.419922' '0.839844' ... '1720' '-inf' '-inf']
+    ['0' '0.419922' '0.839844' ... '1720' '1720' '-inf']
+    ['0' '0.419922' '0.839844' ... '1720' '1720' '1720']]
+
+    Slicing from {1, 1} to {2, 2} gives us the expected alibi slope constant to pass it to PagedAttention:
+    >>> print(line1[5][1][1])
+    0.353516
+    >>> print(line1[4][1][1])
+    0.419922
+    >>> print(line1[3][1][1])
+    0.5
+
+    ALiBi slope const shape [40, 4096, 4096]
+    Slicing means that we take only 1 value from each 4096 x 4096 matrix here
+    The resulting constant will be [40, 1, 1]
+    After that we need to insert Reshape to get the expected rank = 1 (shape [40])
+    */
+    auto SLICED_BY_TOTAL_LEN =
+        makeOP({ALIBI_CONST, {0, 0}, Broadcast123, {1, 1}, {1, 2}});  // alibi??? 0...total_seq_len
+
+    auto ShapeOf127 = makeOP({SLICED_BY_TOTAL_LEN}, {{"output_type", "i64"}});
+    auto Gather130 = makeOP({ShapeOf127, {1, 2}, 0}, {{"batch_dims", 0}});
+    auto Concat131 = makeOP({batch_dim_, {1ll}, Gather130}, {{"axis", 0}});
+    auto Broadcast132 = makeOP({Unsqueeze106, Concat131}, {{"mode", "bidirectional"}});
+    auto Convert133 = makeOP({Broadcast132}, {{"destination_type", "f32"}});  // alibi??
+
+    auto Constant134 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {1.000000f});
+    auto Multiply135 = makeOP({Convert133, Constant134}, {{"auto_broadcast", "numpy"}});
+    auto Subtract136 = makeOP({Constant83, Multiply135}, {{"auto_broadcast", "numpy"}});  // alibi???
+    auto Convert137 = makeOP({Subtract136}, {{"destination_type", "boolean"}});
+
+    auto Select139 = makeOP({Convert137, -FLT_MAX, Subtract136}, {{"auto_broadcast", "numpy"}});
+    auto Unsqueeze140 = makeOP({SLICED_BY_TOTAL_LEN, 0});
+    auto Add141 = makeOP({Select139, Unsqueeze140}, {{"auto_broadcast", "numpy"}});
+
+    auto Multiply143 =
+        makeOP({seq_len, {-1ll}}, {{"auto_broadcast", "numpy"}});  // attn_heads,maxpos,
+
+    auto SLICED_BY_CURRENT_LEN =
+        makeOP({Add141, Multiply143, {LLONG_MAX}, {1}, {2}});  // tensor: 1, 40, current, total
+    auto res = std::make_shared(SLICED_BY_CURRENT_LEN);
+    auto mode = std::make_shared(ResultVector{res}, ParameterVector{params});
+
+    /*ov::pass::Manager manager;
+    manager.register_pass("alibi_slope_test.svg");
+    manager.run_passes(mode);*/
+    // auto ScaledDotProductAttention148 = makeOP({Transpose62, Concat72, Concat82,
+    // Slice147}, {{"causal", false}});
+}
diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp
index e6fc744bb5ef4f..a2a953eb2855fb 100644
--- a/src/core/src/pass/sdpa_to_paged_attention.cpp
+++ b/src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -12,10 +12,12 @@
 #include "openvino/op/subtract.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "openvino/pass/manager.hpp"
+#include "openvino/pass/visualize_tree.hpp"
 #include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"
 #include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
 #include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp"
 #include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
+#include "transformations/utils/print_model.hpp"
 #include "transformations/utils/utils.hpp"
 
 using namespace ov::op;
@@ -36,6 +38,10 @@ static std::shared_ptr setName(std::shared_ptr nod
 
 bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr& model) {
     RUN_ON_MODEL_SCOPE(SDPAToPagedAttention);
+    /* ov::pass::Manager manager3;
+     manager3.register_pass("before_pa_baichuan.svg");
+     manager3.register_pass("before_pa_baichuan.txt");
+     manager3.run_passes(model);*/
     OPENVINO_ASSERT(ov::op::util::has_op_with_type(model),
                     "No ScaledDotProductAttention operation observed in the graph, cannot perform"
                     "the SDPAToPagedAttention transformation.");
@@ -186,5 +192,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr
     model->add_parameters(model_remaining_params);
     model->add_parameters({std::move(max_context_len)});
     model->validate_nodes_and_infer_types();
+
+    /* ov::pass::Manager manager_2;
+     manager_2.register_pass("after_pa_baichuan.svg");
+     manager_2.run_passes(model);*/
     return true;
 }