From 7a80fe83ef07651ba60557fe5022bd9277615a5f Mon Sep 17 00:00:00 2001
From: Taylor Yeonbok Lee
Date: Mon, 28 Oct 2024 02:59:41 -0700
Subject: [PATCH] [GPU] Support GeLU Tanh for Phi-2 (#27213)

### Details:
- Previously, the GeLU Tanh fusion matched only the x * (0.5 * (1 + tanh)) pattern.
- Support the (x * 0.5) * (1 + tanh) pattern as well.

### Tickets:
- 155576
---
 .../common_optimizations/gelu_fusion.cpp      | 36 ++++++++++--------
 .../common_optimizations/gelu_fusion.cpp      | 38 +++++++++++++++++++
 2 files changed, 59 insertions(+), 15 deletions(-)

diff --git a/src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp
index 221484c75cccde..8d075f4a727758 100644
--- a/src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp
+++ b/src/common/transformations/src/transformations/common_optimizations/gelu_fusion.cpp
@@ -22,6 +22,7 @@
 #include "openvino/op/parameter.hpp"
 #include "openvino/op/power.hpp"
 #include "openvino/op/tanh.hpp"
+#include "openvino/pass/pattern/op/or.hpp"
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "transformations/utils/utils.hpp"
 
@@ -280,9 +281,16 @@ ov::pass::GeluFusionWithTanh::GeluFusionWithTanh() {
     auto add_1 = ov::pass::pattern::wrap_type<ov::op::v1::Add>({tanh, add_1_constant});
 
     auto mul_2_constant = ov::pass::pattern::wrap_type<ov::op::v0::Constant>();
-    auto mul_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
 
-    auto mul_3 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2});
+    // x * (0.5 * (1 + tanh))
+    auto mul_2_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_constant});
+    auto mul_3_1 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_1});
+
+    // (x * 0.5) * (1 + tanh)
+    auto mul_2_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({input, mul_2_constant});
+    auto mul_3_2 = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({add_1, mul_2_2});
+
+    auto mul_3 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_3_1, mul_3_2});
 
     ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         auto& pattern_to_output = m.get_pattern_value_map();
@@ -298,7 +306,6 @@
             ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(mul_2_constant).get_node_shared_ptr());
         auto add_1_constant_value =
             ov::as_type_ptr<ov::op::v0::Constant>(pattern_to_output.at(add_1_constant).get_node_shared_ptr());
-
         if (!pow_constant_value || !add_1_constant_value || !mul_0_constant_value || !mul_1_constant_value ||
             !mul_2_constant_value) {
             return false;
         }
@@ -318,18 +325,17 @@
         auto gelu = std::make_shared<ov::op::v7::Gelu>(x_output, op::GeluApproximationMode::TANH);
 
         gelu->set_friendly_name(m.get_match_root()->get_friendly_name());
-        ov::copy_runtime_info(
-            {
-                pattern_to_output.at(pow).get_node_shared_ptr(),
-                pattern_to_output.at(mul_0).get_node_shared_ptr(),
-                pattern_to_output.at(mul_1).get_node_shared_ptr(),
-                pattern_to_output.at(mul_2).get_node_shared_ptr(),
-                pattern_to_output.at(mul_3).get_node_shared_ptr(),
-                pattern_to_output.at(tanh).get_node_shared_ptr(),
-                pattern_to_output.at(add_0).get_node_shared_ptr(),
-                pattern_to_output.at(add_1).get_node_shared_ptr(),
-            },
-            gelu);
+
+        std::vector<std::shared_ptr<ov::Node>> pattern_nodes =
+            {pow, mul_0, mul_1, tanh, add_0, add_1, mul_2_1, mul_2_2, mul_3_1, mul_3_2};
+        std::vector<std::shared_ptr<ov::Node>> cp_rt_info_nodes;
+        for (const auto& pattern_node : pattern_nodes) {
+            if (pattern_to_output.count(pattern_node)) {
+                cp_rt_info_nodes.push_back(pattern_to_output.at(pattern_node).get_node_shared_ptr());
+            }
+        }
+        ov::copy_runtime_info(cp_rt_info_nodes, gelu);
+
         ov::replace_node(m.get_match_root(), gelu);
         return true;
     };
 
diff --git a/src/common/transformations/tests/common_optimizations/gelu_fusion.cpp b/src/common/transformations/tests/common_optimizations/gelu_fusion.cpp
index 837d2ba6d4597e..dbc54f5492bffa 100644
--- a/src/common/transformations/tests/common_optimizations/gelu_fusion.cpp
+++ b/src/common/transformations/tests/common_optimizations/gelu_fusion.cpp
@@ -388,6 +388,44 @@ TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value) {
     }
 }
 
+TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_epsilon_pow_value_2) {
+    {
+        auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
+        auto pow_constant =
+            std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{3.0f + 1.0e-8f});
+        auto pow = std::make_shared<ov::op::v1::Power>(input, pow_constant);
+        auto mul_0_constant =
+            std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{0.044715f});
+        auto mul_0 = std::make_shared<ov::op::v1::Multiply>(pow, mul_0_constant);
+        auto add_0 = std::make_shared<ov::op::v1::Add>(input, mul_0);
+
+        auto mul_1_constant =
+            std::make_shared<ov::op::v0::Constant>(element::f32,
+                                                   Shape{1},
+                                                   std::vector<float>{static_cast<float>(std::sqrt(2.0 / M_PI))});
+        auto mul_1 = std::make_shared<ov::op::v1::Multiply>(add_0, mul_1_constant);
+
+        auto tanh = std::make_shared<ov::op::v0::Tanh>(mul_1);
+
+        auto add_1_constant = std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{1.0f});
+        auto add_1 = std::make_shared<ov::op::v1::Add>(tanh, add_1_constant);
+
+        auto mul_2_constant = std::make_shared<ov::op::v0::Constant>(element::f32, Shape{1}, std::vector<float>{0.5f});
+        auto mul_2 = std::make_shared<ov::op::v1::Multiply>(input, mul_2_constant);
+
+        auto mul_3 = std::make_shared<ov::op::v1::Multiply>(add_1, mul_2);
+
+        model = std::make_shared<ov::Model>(NodeVector{mul_3}, ParameterVector{input});
+        manager.register_pass<ov::pass::GeluFusionWithTanh>();
+    }
+
+    {
+        auto data = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
+        auto gelu = std::make_shared<ov::op::v7::Gelu>(data, op::GeluApproximationMode::TANH);
+        model_ref = std::make_shared<ov::Model>(NodeVector{gelu}, ParameterVector{data});
+    }
+}
+
 TEST_F(TransformationTestsF, GeluFusionTanhWithTanh_wrong_pow_value) {
     {
         auto input = std::make_shared<ov::op::v0::Parameter>(element::f32, Shape{2, 2});
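For context (not part of the patch): a minimal standalone C++ sketch illustrating why the two factorizations that the widened pattern now matches are the same tanh-approximation GELU, differing only in the order of the multiplications. The function names, tolerance, and sample range below are illustrative assumptions and do not come from the OpenVINO sources.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

namespace {
constexpr double kPi = 3.14159265358979323846;

// Tanh-approximation GELU factored as x * (0.5 * (1 + tanh(u))),
// the form the fusion matched before this patch.
double gelu_tanh_form_a(double x) {
    const double u = std::sqrt(2.0 / kPi) * (x + 0.044715 * x * x * x);
    return x * (0.5 * (1.0 + std::tanh(u)));
}

// The same approximation factored as (x * 0.5) * (1 + tanh(u)),
// the multiply ordering produced by Phi-2 that the pattern now also matches.
double gelu_tanh_form_b(double x) {
    const double u = std::sqrt(2.0 / kPi) * (x + 0.044715 * x * x * x);
    return (x * 0.5) * (1.0 + std::tanh(u));
}
}  // namespace

int main() {
    // The two factorizations agree up to floating-point rounding at every point,
    // so both graph layouts can safely lower to a single Gelu (tanh mode) node.
    for (double x = -4.0; x <= 4.0; x += 0.25) {
        const double a = gelu_tanh_form_a(x);
        const double b = gelu_tanh_form_b(x);
        assert(std::fabs(a - b) < 1e-12);
        std::printf("x=% .2f  gelu=% .6f\n", x, a);
    }
    return 0;
}
```

Because the only difference is associativity of the final multiplies, matching either ordering with a pattern `Or` (as the patch does) changes which subgraphs are recognized, not the numerical result of the fused op.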