Skip to content

Commit

Permalink
fixed code style
Browse files Browse the repository at this point in the history
  • Loading branch information
e-ddykim committed Nov 4, 2024
1 parent 5c0a5a6 commit cc4b37f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const auto is_non_const_node = [](const ov::Output<ov::Node>& output) -> bool {
return true;
}
};
}
} // namespace

using namespace ov::pass::activations_scaling;
using namespace ov::pass::pattern;
Expand Down Expand Up @@ -243,7 +243,8 @@ ov::pass::activations_scaling::MulGroupNormTransformation::MulGroupNormTransform
}

if (mul && norm) {
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t activation_index =
ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
norm->input(0).replace_source_output(mul->get_input_source_output(activation_index));
return true;
}
Expand Down Expand Up @@ -286,7 +287,8 @@ ov::pass::activations_scaling::MulMVNTransformation::MulMVNTransformation() {
}

if (mul && norm) {
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t activation_index =
ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
norm->input(0).replace_source_output(mul->get_input_source_output(activation_index));
return true;
}
Expand Down Expand Up @@ -322,7 +324,8 @@ ov::pass::activations_scaling::SplitTransformation::SplitTransformation() {
OPENVINO_ASSERT(pattern_map.count(split_m));

auto mul = std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(mul_m).get_node_shared_ptr());
auto split = std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(pattern_map.at(split_m).get_node_shared_ptr());
auto split =
std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(pattern_map.at(split_m).get_node_shared_ptr());

if (transformation_callback(split)) {
return false;
Expand All @@ -337,17 +340,17 @@ ov::pass::activations_scaling::SplitTransformation::SplitTransformation() {
target_inputs[i] = split->get_output_target_inputs(i);
}

size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t activation_index =
ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t const_index = (activation_index == 1) ? 0 : 1;
split->input(0).replace_source_output(mul->input(activation_index).get_source_output());

for (size_t i = 0; i < num_split_outputs; i++) {
auto new_mul = register_new_node<ov::op::v1::Multiply>(
split->output(i),
mul->input(const_index).get_source_output());
auto new_mul = register_new_node<ov::op::v1::Multiply>(split->output(i),
mul->input(const_index).get_source_output());
new_mul->set_friendly_name(mul->get_friendly_name() + "_" + std::to_string(i));
ov::copy_runtime_info(mul, new_mul);

for (auto& in : target_inputs[i]) {
in.replace_source_output(new_mul);
}
Expand Down Expand Up @@ -382,7 +385,8 @@ ov::pass::activations_scaling::ReshapeTransformation::ReshapeTransformation() {
OPENVINO_ASSERT(pattern_map.count(mul_m));
OPENVINO_ASSERT(pattern_map.count(reshape_m));

auto scale_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(scale_const_m).get_node_shared_ptr());
auto scale_const =
std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(scale_const_m).get_node_shared_ptr());
auto mul = std::dynamic_pointer_cast<ov::op::v1::Multiply>(pattern_map.at(mul_m).get_node_shared_ptr());
auto reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(pattern_map.at(reshape_m).get_node_shared_ptr());

Expand All @@ -392,13 +396,14 @@ ov::pass::activations_scaling::ReshapeTransformation::ReshapeTransformation() {

if (scale_const && mul && reshape) {
auto target_inputs = reshape->get_output_target_inputs(0);
size_t activation_index = ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
size_t activation_index =
ov::is_type<ov::op::v0::Constant>(mul->get_input_source_output(1).get_node()) ? 0 : 1;
reshape->input(0).replace_source_output(mul->input(activation_index).get_source_output());

auto new_mul = register_new_node<ov::op::v1::Multiply>(reshape, scale_const);
new_mul->set_friendly_name(mul->get_friendly_name() + "_r");
ov::copy_runtime_info(mul, new_mul);

for (auto& in : target_inputs) {
in.replace_source_output(new_mul);
}
Expand Down Expand Up @@ -461,7 +466,7 @@ ov::pass::activations_scaling::MulMulMulTransformation::MulMulMulTransformation(

mul2->input(0).replace_source_output(mul0->get_input_source_output((const0_index == 0) ? 1 : 0));
mul2->input(1).replace_source_output(mul1->get_input_source_output((const1_index == 0) ? 1 : 0));

auto new_mul = register_new_node<ov::op::v1::Multiply>(
mul2,
ov::op::util::eltwise_fold<ov::op::v1::Multiply>(scale_const0, scale_const1));
Expand All @@ -485,7 +490,7 @@ ov::pass::activations_scaling::MulMulMulTransformation::MulMulMulTransformation(
// \ | /
// \ | /
// ---------- Concat ------------
// ==>
// ==>
// (const_a (const_b (const_c
// input_a /const_c) input_b /const_c) input_c /const_c)
// \ / \ / \ /
Expand All @@ -494,7 +499,7 @@ ov::pass::activations_scaling::MulMulMulTransformation::MulMulMulTransformation(
// \ | /
// ---------- Concat ------------
// | const_c
// | /
// | /
// Multiply
ov::pass::activations_scaling::ConcatTransformation::ConcatTransformation() {
MATCHER_SCOPE(ConcatTransformation);
Expand All @@ -514,33 +519,38 @@ ov::pass::activations_scaling::ConcatTransformation::ConcatTransformation() {

// check if all inputs are Multiply with scalar operand
ov::Output<ov::Node> last_dep_const;
for (auto &input : concat->inputs()) {
auto dep_node = std::dynamic_pointer_cast<ov::op::v1::Multiply>(input.get_source_output().get_node_shared_ptr());
for (auto& input : concat->inputs()) {
auto dep_node =
std::dynamic_pointer_cast<ov::op::v1::Multiply>(input.get_source_output().get_node_shared_ptr());
if (!dep_node) {
return false;
}
auto dep_const0 = std::dynamic_pointer_cast<ov::op::v0::Constant>(dep_node->input(0).get_source_output().get_node_shared_ptr());
auto dep_const1 = std::dynamic_pointer_cast<ov::op::v0::Constant>(dep_node->input(1).get_source_output().get_node_shared_ptr());
auto dep_const0 = std::dynamic_pointer_cast<ov::op::v0::Constant>(
dep_node->input(0).get_source_output().get_node_shared_ptr());
auto dep_const1 = std::dynamic_pointer_cast<ov::op::v0::Constant>(
dep_node->input(1).get_source_output().get_node_shared_ptr());
if (!dep_const0 && !dep_const1) {
return false;
}
last_dep_const = dep_const0 ? dep_node->input(0).get_source_output() : dep_node->input(1).get_source_output();
last_dep_const =
dep_const0 ? dep_node->input(0).get_source_output() : dep_node->input(1).get_source_output();
if (!is_scalar_node(last_dep_const)) {
return false;
}
}

auto target_inputs = concat->get_output_target_inputs(0);

for (auto &input : concat->inputs()) {
for (auto& input : concat->inputs()) {
auto dep_node = input.get_source_output().get_node_shared_ptr();
auto dep_input0 = dep_node->input(0).get_source_output().get_node();
size_t const_index = ov::is_type<ov::op::v0::Constant>(dep_input0) ? 0 : 1;
size_t activation_index = ov::is_type<ov::op::v0::Constant>(dep_input0) ? 1 : 0;

auto new_mul = register_new_node<ov::op::v1::Multiply>(
dep_node->input(activation_index).get_source_output(),
ov::op::util::eltwise_fold<ov::op::v1::Divide>(dep_node->input(const_index).get_source_output(), last_dep_const));
ov::op::util::eltwise_fold<ov::op::v1::Divide>(dep_node->input(const_index).get_source_output(),
last_dep_const));
new_mul->set_friendly_name(dep_node->get_friendly_name() + "_c");
ov::copy_runtime_info(dep_node, new_mul);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ TEST_F(TransformationTestsF, SplitTransformationTest) {
auto convert2 = std::make_shared<ov::op::v0::Convert>(mul2, ov::element::f32);
auto result2 = std::make_shared<ov::op::v0::Result>(convert2);

model_ref = std::make_shared<ov::Model>(ov::ResultVector{result0, result1, result2}, ov::ParameterVector{input});
model_ref =
std::make_shared<ov::Model>(ov::ResultVector{result0, result1, result2}, ov::ParameterVector{input});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
return static_cast<int32_t>((gamma_shape.back() / vec_size)) > static_cast<int32_t>(device_info.max_work_group_size);
});

// manager.register_pass<ov::pass::RMSFusion>();
manager.register_pass<ov::pass::RMSFusion>();
manager.register_pass<ov::intel_gpu::KVCacheFusion>();
manager.register_pass<ov::intel_gpu::FullyConnectedConvertFusion>();
manager.register_pass<ov::intel_gpu::TransposeFusion>(device_info.supports_immad);
Expand Down

0 comments on commit cc4b37f

Please sign in to comment.