Skip to content

Commit

Permalink
Merge branch 'master' into cecilia/causal_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ceciliapeng2011 authored Jan 7, 2025
2 parents 133f3e1 + 26e5fe9 commit 90a6c18
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/utils/utils.hpp"
#include "openvino/opsets/opset8.hpp"

namespace ov {
namespace intel_gpu {
Expand All @@ -42,7 +43,8 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() {
auto gather_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past});
auto beam_idx = wrap_type<ov::op::v0::Parameter>();
auto gather_past = wrap_type<ov::op::v8::Gather>({gather_input, beam_idx, wrap_type<ov::op::v0::Constant>()});
auto concat_past_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past, gather_past});
auto gather_convert = wrap_type<ov::op::v0::Convert>({gather_past});
auto concat_past_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past, convert_past, gather_past, gather_convert});
auto concat = wrap_type<ov::op::v0::Concat>({concat_past_input, any_input()});
auto convert_present = wrap_type<ov::op::v0::Convert>({concat});
auto present_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat, convert_present});
Expand All @@ -63,8 +65,10 @@ KVCacheFusionMatcher::KVCacheFusionMatcher() {
return false;

// TODO: Support conversion internally
if (!concat_node || concat_node->get_output_element_type(0) != past_node->get_output_element_type(0))
return false;
if (ov::is_type<ov::opset8::Gather>(concat_past_input)) {
if (!concat_node || concat_node->get_output_element_type(0) != past_node->get_output_element_type(0))
return false;
}

auto variable = past_node->get_variable();
auto concat_axis = concat_node->get_axis();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,13 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
auto reshape_b_m = wrap_type<ov::op::v1::Reshape>({broadcast_b_m, any_input()}, reshape_predicate);
auto reshape_c_m = wrap_type<ov::op::v1::Reshape>({broadcast_c_m, any_input()}, reshape_predicate);

auto convert_reshape_b_m = wrap_type<ov::op::v0::Convert>({reshape_b_m});
auto reshape_b_m_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape_b_m, convert_reshape_b_m});
auto convert_reshape_c_m = wrap_type<ov::op::v0::Convert>({reshape_c_m});
auto reshape_c_m_input = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{reshape_c_m, convert_reshape_c_m});

auto sdpa_without_attn_mask_m = wrap_type<op::SDPA>({ input_a_m, reshape_b_m, reshape_c_m });
auto sdpa_with_attn_mask_m = wrap_type<op::SDPA>({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask });
auto sdpa_with_attn_mask_m = wrap_type<op::SDPA>({ input_a_m, reshape_b_m_input, reshape_c_m_input, input_attn_mask });
auto sdpa_with_attn_mask_and_scale_m = wrap_type<op::SDPA>({ input_a_m, reshape_b_m, reshape_c_m, input_attn_mask, input_scale });

auto sdpa_m = std::make_shared<Or>(OutputVector{sdpa_without_attn_mask_m, sdpa_with_attn_mask_m, sdpa_with_attn_mask_and_scale_m});
Expand Down

0 comments on commit 90a6c18

Please sign in to comment.