Skip to content

Commit

Permalink
updated MulShareTransformation
Browse files Browse the repository at this point in the history
  • Loading branch information
e-ddykim committed Jan 9, 2025
1 parent 6771211 commit 9a3cda3
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace activations_scaling {
class TRANSFORMATIONS_API ScaleDownSingleLayer;
class TRANSFORMATIONS_API EliminateMultiplyScalar;
class TRANSFORMATIONS_API MulConcatTransformation;
class TRANSFORMATIONS_API NormMulTransformation;
class TRANSFORMATIONS_API MulShareTransformation;
class TRANSFORMATIONS_API MulMulTransformation;

} // namespace activations_scaling
Expand Down Expand Up @@ -49,10 +49,10 @@ class ov::pass::activations_scaling::MulConcatTransformation : public ov::pass::
MulConcatTransformation();
};

class ov::pass::activations_scaling::NormMulTransformation : public ov::pass::MatcherPass {
class ov::pass::activations_scaling::MulShareTransformation : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("NormMulTransformation", "0");
NormMulTransformation();
OPENVINO_MATCHER_PASS_RTTI("MulShareTransformation", "0");
MulShareTransformation();
};

class ov::pass::activations_scaling::MulMulTransformation : public ov::pass::MatcherPass {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ ov::pass::activations_scaling::MulConcatTransformation::MulConcatTransformation(
ov::copy_runtime_info(concat, new_mul);
ov::replace_output_update_name(concat->output(0), new_mul->output(0));

return false;
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_m, "MulConcatTransformation");
Expand All @@ -330,8 +330,8 @@ ov::pass::activations_scaling::MulConcatTransformation::MulConcatTransformation(
// op_a op_b Norm op_b
// |
// op_a
ov::pass::activations_scaling::NormMulTransformation::NormMulTransformation() {
MATCHER_SCOPE(NormMulTransformation);
ov::pass::activations_scaling::MulShareTransformation::MulShareTransformation() {
MATCHER_SCOPE(MulShareTransformation);

auto mvn_m = wrap_type<ov::op::v6::MVN>({any_input(), any_input()});
auto rms_m = wrap_type<ov::op::internal::RMS>({any_input(), any_input()});
Expand All @@ -349,34 +349,32 @@ ov::pass::activations_scaling::NormMulTransformation::NormMulTransformation() {
auto norm = pattern_map.at(norm_m).get_node_shared_ptr();

auto parent_output = norm->get_input_source_output(0);
if (parent_output.get_target_inputs().size() != 2)
if (parent_output.get_target_inputs().size() == 1)
return false;

ov::Node* mul = nullptr;
for (auto& child : parent_output.get_target_inputs()) {
if (child == norm->input(0))
continue;
mul = child.get_node();
}

if (!ov::is_type<ov::op::v1::Multiply>(mul))
return false;
if (ov::is_type<ov::op::v1::Multiply>(child.get_node())) {
ov::Output<ov::Node> const_input;
for (auto input : child.get_node()->input_values()) {
if (input == parent_output)
continue;
const_input = input;
}

ov::Output<ov::Node> const_input;
for (auto input : mul->input_values()) {
if (input == parent_output)
continue;
const_input = input;
if (is_scalar_node(const_input) && !is_non_const_node(const_input)) {
norm->input(0).replace_source_output(child.get_node()->output(0));
return true;
}
}
}

if (!is_scalar_node(const_input) || is_non_const_node(const_input))
return false;

norm->input(0).replace_source_output(mul->output(0));
return true;
return false;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(norm_m, "NormMulTransformation");
auto m = std::make_shared<ov::pass::pattern::Matcher>(norm_m, "ScalarMulShareTransformation");
this->register_matcher(m, callback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
#include "openvino/op/mvn.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/utils/utils.hpp"
#include "low_precision/multiply_partial.hpp"

using namespace ov;
using namespace testing;
Expand Down Expand Up @@ -110,10 +110,10 @@ TEST_F(TransformationTestsF, ConcatTransformationTest) {
}
{
auto input0 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{6, 12, 10, 24});
auto scale_const0 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {10});
auto scale_const0 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {1});
auto mul0 = std::make_shared<ov::op::v1::Multiply>(input0, scale_const0);
auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{6, 12, 10, 24});
auto scale_const1 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {10});
auto scale_const1 = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {1});
auto mul1 = std::make_shared<ov::op::v1::Multiply>(input1, scale_const1);
auto concat = std::make_shared<ov::op::v0::Concat>(OutputVector{mul0, mul1}, 0);
auto new_scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {10});
Expand Down Expand Up @@ -150,3 +150,31 @@ TEST_F(TransformationTestsF, MulMulTransformationTest) {
model_ref = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{input0, input1});
}
}

TEST_F(TransformationTestsF, MulShareTransformationTest) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{6, 12, 10, 24});
auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(input);
auto convert0 = std::make_shared<ov::op::v0::Convert>(shape_of, ov::element::f32);
auto result0 = std::make_shared<ov::op::v0::Result>(convert0);
auto scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {10});
auto mul = std::make_shared<ov::op::v1::Multiply>(input, scale_const);
auto convert1 = std::make_shared<ov::op::v0::Convert>(mul, ov::element::f32);
auto result1 = std::make_shared<ov::op::v0::Result>(convert1);

model = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, ov::ParameterVector{input});
manager.register_pass<ov::pass::activations_scaling::MulShareTransformation>();
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape{6, 12, 10, 24});
auto scale_const = ov::op::v0::Constant::create(ov::element::f16, ov::Shape{1}, {10});
auto mul = std::make_shared<ov::op::v1::Multiply>(input, scale_const);
auto shape_of = std::make_shared<ov::op::v3::ShapeOf>(mul);
auto convert0 = std::make_shared<ov::op::v0::Convert>(shape_of, ov::element::f32);
auto result0 = std::make_shared<ov::op::v0::Result>(convert0);
auto convert1 = std::make_shared<ov::op::v0::Convert>(mul, ov::element::f32);
auto result1 = std::make_shared<ov::op::v0::Result>(convert1);

model_ref = std::make_shared<ov::Model>(ov::ResultVector{result0, result1}, ov::ParameterVector{input});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "intel_gpu/runtime/debug_configuration.hpp"
#include "intel_gpu/runtime/itt.hpp"
#include "low_precision/add.hpp"
#include "low_precision/clamp.hpp"
#include "low_precision/concat.hpp"
#include "low_precision/convolution.hpp"
#include "low_precision/convolution_backprop_data.hpp"
Expand All @@ -31,12 +30,9 @@
#include "low_precision/pull_reshape_through_dequantization.hpp"
#include "low_precision/pull_transpose_through_dequantization.hpp"
#include "low_precision/recurrent_cell.hpp"
#include "low_precision/reshape.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "low_precision/strided_slice.hpp"
#include "low_precision/transpose.hpp"
#include "low_precision/unsqueeze.hpp"
#include "low_precision/variadic_split.hpp"
#include "openvino/core/deprecated.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/core/validation_util.hpp"
Expand Down Expand Up @@ -970,7 +966,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {

// Move up remained scalar-multiply layers
manager.register_pass<ov::pass::EliminateEltwise>();
manager.register_pass<ov::pass::activations_scaling::NormMulTransformation>();
manager.register_pass<ov::pass::activations_scaling::MulShareTransformation>();

const std::vector<DiscreteTypeInfo> allowed_data_movement_ops = {
ov::op::v1::Reshape::get_type_info_static(),
Expand Down

0 comments on commit 9a3cda3

Please sign in to comment.