Add ALiBi slopes pattern for Baichuan2 13B
itikhono committed Dec 28, 2024
1 parent b25413c commit 232a7c9
Showing 4 changed files with 171 additions and 2 deletions.
@@ -287,6 +287,7 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr<ov::Model>& model) {
return ss.str();
};

int _idx = 0;
// change name convention
std::map<ov::Node*, std::string> opname;
std::map<std::string, int> opname_count;
@@ -312,6 +313,7 @@ void dump_cpp_style(std::ostream& os, const std::shared_ptr<ov::Model>& model) {
name += std::to_string(idx);

opname[op.get()] = name;
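// NB: the next line overrides the name just computed with a plain "<type><index>" debug naming scheme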
opname[op.get()] = op->get_type_name() + std::to_string(_idx++);
}

for (auto op : f.get_ordered_ops()) {
@@ -19,8 +19,10 @@
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/op/strided_slice.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/op/variadic_split.hpp"
@@ -155,14 +157,25 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
alibi_mask = pattern::wrap_type<v0::Unsqueeze>({alibi_mask, pattern::any_input()});
alibi_mask = pattern::wrap_type<v1::Add>({pattern::any_input(), alibi_mask});

// Baichuan2 13b case
// TODO: make it stricter, more conservative.
auto sub = pattern::wrap_type<v1::Subtract>({pattern::any_input(), pattern::any_input()});
auto select = pattern::wrap_type<v1::Select>({pattern::any_input(), pattern::any_input(), sub});
auto _alibi = pattern::any_input();
auto Unsqueeze140 = pattern::wrap_type<v0::Unsqueeze>({_alibi, pattern::any_input()});
auto Add141 = pattern::wrap_type<v1::Add>({pattern::any_input(), Unsqueeze140});
auto Slice147 = pattern::wrap_type<v8::Slice>(
{Add141, pattern::any_input(), pattern::any_input(), pattern::any_input(), pattern::any_input()});

auto q = pattern::any_input();
auto scale_input = pattern::any_input();

auto k_to_sdpa =
std::make_shared<pattern::op::Or>(OutputVector{k_concat, k_shaped, k_shaped_transposed, k_simply_shaped});
auto v_to_sdpa =
std::make_shared<pattern::op::Or>(OutputVector{v_concat, v_shaped, v_shaped_transposed, v_simply_shaped});
-auto mask_to_sdpa = std::make_shared<pattern::op::Or>(OutputVector{sdpa_mask, alibi_mask, pattern::any_input()});
+auto mask_to_sdpa =
+    std::make_shared<pattern::op::Or>(OutputVector{sdpa_mask, alibi_mask, Slice147, pattern::any_input()});

auto sdpa_with_4_inputs =
pattern::wrap_type<v13::ScaledDotProductAttention>({q, k_to_sdpa, v_to_sdpa, mask_to_sdpa});
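Editor's note: for readers new to the pattern API used in this hunk, here is a minimal, hypothetical sketch (not part of the commit) of how the Unsqueeze -> Add -> Slice chain could be registered as a standalone MatcherPass and how the node captured by `_alibi` is recovered from the pattern map. The class name and callback body are illustrative assumptions, not the commit's code.

#include "openvino/op/add.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

// Illustrative sketch only: matches Slice(Add(any, Unsqueeze(alibi, any)), ...)
class BaichuanAlibiMatcher : public ov::pass::MatcherPass {
public:
    BaichuanAlibiMatcher() {
        using namespace ov::pass::pattern;
        auto alibi = any_input();
        auto unsqueeze = wrap_type<ov::op::v0::Unsqueeze>({alibi, any_input()});
        auto add = wrap_type<ov::op::v1::Add>({any_input(), unsqueeze});
        auto slice = wrap_type<ov::op::v8::Slice>(
            {add, any_input(), any_input(), any_input(), any_input()});

        ov::matcher_pass_callback callback = [=](Matcher& m) {
            const auto& pm = m.get_pattern_value_map();
            // Recover the producer matched by the `alibi` placeholder.
            auto alibi_source = pm.at(alibi).get_node_shared_ptr();
            (void)alibi_source;  // a real pass would rewrite the graph here
            return false;        // false: nothing was changed
        };
        register_matcher(std::make_shared<Matcher>(slice, "BaichuanAlibiMatcher"), callback);
    }
};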
@@ -335,7 +348,19 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
}

std::shared_ptr<Node> alibi_slopes;
-if (pattern_map.find(alibi) != pattern_map.end()) {
+if (pattern_map.count(_alibi)) {
alibi_slopes = pattern_map.at(_alibi).get_node_shared_ptr();
auto start = v0::Constant::create(element::i64, Shape{2}, {1, 1});
auto stop = v0::Constant::create(element::i64, Shape{2}, {2, 2});
auto step = v0::Constant::create(element::i64, Shape{2}, {1, 1});
auto axes = v0::Constant::create(element::i64, Shape{2}, {1, 2});
alibi_slopes = std::make_shared<v8::Slice>(alibi_slopes->input_value(0), start, stop, step, axes);
alibi_slopes =
std::make_shared<v1::Reshape>(alibi_slopes, v0::Constant::create(element::i64, Shape{1}, {-1}), false);
// PagedAttention expects f32 slopes; convert only when the type differs
if (alibi_slopes->get_element_type() != element::f32) {
alibi_slopes = std::make_shared<v0::Convert>(alibi_slopes, element::f32);
}
} else if (pattern_map.find(alibi) != pattern_map.end()) {
alibi_slopes = std::make_shared<v1::Reshape>(pattern_map.at(alibi),
v0::Constant::create(element::i64, Shape{1}, {-1}),
false);
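Editor's note: a standalone cross-check sketch (not part of the commit), assuming the standard ALiBi slope schedule from Press et al. For 40 heads the base is taken from the nearest power of two (32), which reproduces the per-head constants that the Slice above extracts from the [40, 4096, 4096] bias table (0.840896 ~ 0.839844, 0.707107 ~ 0.707031, 0.5, ...).

#include <cmath>
#include <cstdio>
#include <vector>

// Standard ALiBi slope schedule; purely illustrative, not the commit's code.
static std::vector<double> alibi_slopes(int n_heads) {
    int pow2 = 1 << static_cast<int>(std::floor(std::log2(n_heads)));
    double base = std::pow(2.0, -8.0 / pow2);
    std::vector<double> slopes;
    for (int i = 1; i <= pow2; ++i)
        slopes.push_back(std::pow(base, i));
    // Extra heads (here 40 - 32 = 8) interleave with half the exponent step.
    double extra = std::pow(2.0, -4.0 / pow2);
    for (int i = 1; slopes.size() < static_cast<size_t>(n_heads); i += 2)
        slopes.push_back(std::pow(extra, i));
    return slopes;
}

int main() {
    for (double s : alibi_slopes(40))
        std::printf("%.6f\n", s);  // 0.840896, 0.707107, 0.594604, 0.500000, ...
}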
@@ -29,6 +29,7 @@
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
#include "transformations/utils/gen_pattern.hpp"
@@ -616,3 +617,134 @@ TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) {
disable_result_friendly_names_check();
disable_rt_info_check();
}

TEST(TransformationTests, Baichuan2_13b_alibi_slopes) {
auto beam_idx = makeOP<opset1::Parameter>({}, {{"shape", PartialShape{DYN}}, {"element_type", "i32"}});
auto attention_mask = makeOP<opset1::Parameter>({}, {{"shape", PartialShape{DYN, DYN}}, {"element_type", "i64"}});
auto input_ids = makeOP<opset1::Parameter>({}, {{"shape", PartialShape{DYN, DYN}}, {"element_type", "i64"}});
ParameterVector params = nodes_to_params({input_ids, beam_idx, attention_mask});

auto input_ids_shape = makeOP<opset3::ShapeOf>({input_ids}, {{"output_type", "i64"}});
auto batch_dim = makeOP<opset8::Gather>({input_ids_shape, {0}, 0}, {{"batch_dims", 0}});

auto target_shape = makeOP<opset1::Concat>({batch_dim, {40ll}, {0ll}, {128ll}}, {{"axis", 0}});
auto init_to_read_value = makeOP<opset3::Broadcast>({0.000000f, target_shape}, {{"mode", "numpy"}});

auto ReadValue63 = makeOP<opset6::ReadValue>(
{init_to_read_value},
{{"variable_id", "ID2"}, {"variable_type", "f32"}, {"variable_shape", PartialShape{DYN, 40, DYN, 128}}});
auto Gather65 = makeOP<opset8::Gather>({ReadValue63, beam_idx, 0}, {{"batch_dims", 0}});
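// ReadValue + Gather model the past key/value cache state, reordered along the batch axis by beam_idx.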

auto Constant83 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {1.000000f});
auto Convert85 = makeOP<opset1::Convert>({attention_mask}, {{"destination_type", "f32"}});
auto Unsqueeze86 = makeOP<opset1::Unsqueeze>({Convert85, 2});
auto Unsqueeze87 = makeOP<opset1::Unsqueeze>({Convert85, 1});
auto Multiply88 = makeOP<opset1::Multiply>({Unsqueeze86, Unsqueeze87}, {{"auto_broadcast", "numpy"}});
auto Constant89 = makeConst(element::f32, ov::Shape({1, 1, 1}), {0.000000f});
auto Greater90 = makeOP<opset1::Greater>({Multiply88, Constant89}, {{"auto_broadcast", "numpy"}});
auto ShapeOf91 = makeOP<opset3::ShapeOf>({Greater90}, {{"output_type", "i32"}});
auto Gather94 = makeOP<opset8::Gather>({ShapeOf91, 1, 0}, {{"batch_dims", 0}});
auto Range96 = makeOP<opset4::Range>({0, Gather94, 1}, {{"output_type", "i32"}});
auto Unsqueeze97 = makeOP<opset1::Unsqueeze>({Range96, 0});
auto Unsqueeze98 = makeOP<opset1::Unsqueeze>({Range96, 1});
auto LessEqual99 = makeOP<opset1::LessEqual>({Unsqueeze97, Unsqueeze98}, {{"auto_broadcast", "numpy"}});
auto Constant100 = makeConst(element::boolean, ov::Shape({}), {0});
auto Select101 = makeOP<opset1::Select>({LessEqual99, Greater90, Constant100}, {{"auto_broadcast", "numpy"}});
auto Subtract102 = makeOP<opset1::Subtract>({Unsqueeze86, Unsqueeze87}, {{"auto_broadcast", "numpy"}});
auto Constant103 = makeConst(element::f32, ov::Shape({1, 1, 1}), {0.000000f});
auto Equal104 = makeOP<opset1::Equal>({Subtract102, Constant103}, {{"auto_broadcast", "numpy"}});
auto LogicalAnd105 = makeOP<opset1::LogicalAnd>({Select101, Equal104}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze106 = makeOP<opset1::Unsqueeze>({LogicalAnd105, 1});

auto ShapeOf107 = makeOP<opset3::ShapeOf>({input_ids}, {{"output_type", "i64"}});
auto batch_dim_ = makeOP<opset8::Gather>({ShapeOf107, {0}, 0}, {{"batch_dims", 0}});

auto ALIBI_CONST = makeConst(element::f32, ov::Shape({40, 4096, 4096}), MOCK_VALUE);
auto seq_len = makeOP<opset8::Gather>({ShapeOf107, {1}, 0}, {{"batch_dims", 0}});
auto ShapeOf117 = makeOP<opset3::ShapeOf>({Gather65}, {{"output_type", "i64"}});
auto Gather120 = makeOP<opset8::Gather>({ShapeOf117, {2}, 0}, {{"batch_dims", 0}});
auto Add121 = makeOP<opset1::Add>({seq_len, Gather120}, {{"auto_broadcast", "numpy"}});
auto Broadcast123 = makeOP<opset3::Broadcast>({Add121, {2}}, {{"mode", "numpy"}});
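// total_seq_len = current seq_len + past KV length; broadcast to two elements to serve as the 2-axis Slice stop below.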

// This Slice is expected to be replaced with Slice(alibi_const, start {1, 1}, stop {2, 2}, step {1, 1}, axes {1, 2}).
// ALIBI_CONST content:
/*
>>> line1 = np.reshape(line1, (40, 4096, 4096))
>>> print(line1[0][:][:])
[['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.839844' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.839844' '1.67969' ... '-inf' '-inf' '-inf']
...
['0' '0.839844' '1.67969' ... '3440' '-inf' '-inf']
['0' '0.839844' '1.67969' ... '3440' '3440' '-inf']
['0' '0.839844' '1.67969' ... '3440' '3440' '3440']]
>>> print(line1[1][:][:])
[['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.707031' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.707031' '1.41406' ... '-inf' '-inf' '-inf']
...
['0' '0.707031' '1.41406' ... '2896' '-inf' '-inf']
['0' '0.707031' '1.41406' ... '2896' '2896' '-inf']
['0' '0.707031' '1.41406' ... '2896' '2896' '2896']]
>>> print(line1[3][:][:])
[['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.5' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.5' '1' ... '-inf' '-inf' '-inf']
...
['0' '0.5' '1' ... '2048' '-inf' '-inf']
['0' '0.5' '1' ... '2048' '2048' '-inf']
['0' '0.5' '1' ... '2048' '2048' '2048']]
>>> print(line1[4][:][:])
[['0' '-inf' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.419922' '-inf' ... '-inf' '-inf' '-inf']
['0' '0.419922' '0.839844' ... '-inf' '-inf' '-inf']
...
['0' '0.419922' '0.839844' ... '1720' '-inf' '-inf']
['0' '0.419922' '0.839844' ... '1720' '1720' '-inf']
['0' '0.419922' '0.839844' ... '1720' '1720' '1720']]
Each row of head h is [0, slope_h, 2*slope_h, ...], i.e. bias[h][i][j] = slope_h * j for j <= i,
so the element at [h][1][1] is exactly slope_h. Slicing from {1, 1} to {2, 2} therefore gives us
the per-head slope constants to pass to PagedAttention:
>>> print(line1[5][1][1])
0.353516
>>> print(line1[4][1][1])
0.419922
>>> print(line1[3][1][1])
0.5
The ALiBi bias constant has shape [40, 4096, 4096]; the slice takes a single value from each
4096 x 4096 matrix, so the resulting constant has shape [40, 1, 1].
After that, a Reshape is inserted to get the expected rank of 1 (shape [40]).
*/
auto SLICED_BY_TOTAL_LEN =
makeOP<opset8::Slice>({ALIBI_CONST, {0, 0}, Broadcast123, {1, 1}, {1, 2}});  // crop bias to [0:total_seq_len, 0:total_seq_len] on axes {1, 2}

auto ShapeOf127 = makeOP<opset3::ShapeOf>({SLICED_BY_TOTAL_LEN}, {{"output_type", "i64"}});
auto Gather130 = makeOP<opset8::Gather>({ShapeOf127, {1, 2}, 0}, {{"batch_dims", 0}});
auto Concat131 = makeOP<opset1::Concat>({batch_dim_, {1ll}, Gather130}, {{"axis", 0}});
auto Broadcast132 = makeOP<opset3::Broadcast>({Unsqueeze106, Concat131}, {{"mode", "bidirectional"}});
auto Convert133 = makeOP<opset1::Convert>({Broadcast132}, {{"destination_type", "f32"}});  // boolean attention mask broadcast to the bias shape, back to f32

auto Constant134 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {1.000000f});
auto Multiply135 = makeOP<opset1::Multiply>({Convert133, Constant134}, {{"auto_broadcast", "numpy"}});
auto Subtract136 = makeOP<opset1::Subtract>({Constant83, Multiply135}, {{"auto_broadcast", "numpy"}});  // 1.0 - mask: 0 where attended, 1 where masked
auto Convert137 = makeOP<opset1::Convert>({Subtract136}, {{"destination_type", "boolean"}});

auto Select139 = makeOP<opset1::Select>({Convert137, -FLT_MAX, Subtract136}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze140 = makeOP<opset1::Unsqueeze>({SLICED_BY_TOTAL_LEN, 0});
auto Add141 = makeOP<opset1::Add>({Select139, Unsqueeze140}, {{"auto_broadcast", "numpy"}});

auto Multiply143 =
makeOP<opset1::Multiply>({seq_len, {-1ll}}, {{"auto_broadcast", "numpy"}});  // -seq_len: start offset selecting the last seq_len query rows

auto SLICED_BY_CURRENT_LEN =
makeOP<opset8::Slice>({Add141, Multiply143, {LLONG_MAX}, {1}, {2}});  // result: [1, 40, current_seq_len, total_seq_len]
auto res = std::make_shared<op::v0::Result>(SLICED_BY_CURRENT_LEN);
auto model = std::make_shared<ov::Model>(ResultVector{res}, params);

// In the full model, the sliced mask feeds ScaledDotProductAttention:
// auto ScaledDotProductAttention148 = makeOP<v13::ScaledDotProductAttention>(
//     {Transpose62, Concat72, Concat82, Slice147}, {{"causal", false}});
}
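Editor's note: the test above only builds the graph and makes no assertions. Below is a hedged sketch (not part of the commit) of how it could be exercised, assuming the graph were first extended with the ScaledDotProductAttention consumer noted in the comment, since the pass rejects models without SDPA; the helper name is illustrative.

#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"

// Illustrative only: run the transformation over a model that contains SDPA.
static void run_sdpa_to_pa(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>();
    manager.run_passes(model);
    // Expected outcome for Baichuan2-13B: the mask subgraph above is matched,
    // and PagedAttention receives rank-1 f32 alibi slopes of shape [40].
}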
10 changes: 10 additions & 0 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -12,10 +12,12 @@
#include "openvino/op/subtract.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/sdpa_to_paged_attention/position_ids_replacer.hpp"
#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
#include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp"
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
#include "transformations/utils/print_model.hpp"
#include "transformations/utils/utils.hpp"

using namespace ov::op;
@@ -36,6 +38,10 @@ static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> nod
bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_MODEL_SCOPE(SDPAToPagedAttention);

OPENVINO_ASSERT(ov::op::util::has_op_with_type<ov::op::v13::ScaledDotProductAttention>(model),
"No ScaledDotProductAttention operation observed in the graph, cannot perform"
"the SDPAToPagedAttention transformation.");
@@ -186,5 +192,9 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
model->add_parameters(model_remaining_params);
model->add_parameters({std::move(max_context_len)});
model->validate_nodes_and_infer_types();

return true;
}
