Commit fce49e6: [Transformations] Added interchangeable reshape elimination (openvinotoolkit#9691)

* [Transformations] Added interchangeable reshape elimination

* Applied comments #2

* returned Reshape in condition

* applied comments openvinotoolkit#3

* applied comments openvinotoolkit#4

* added comment in plugin with reason about transformation
a-sidorova authored Feb 9, 2022
1 parent a002b26 commit fce49e6
Showing 7 changed files with 124 additions and 24 deletions.
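In short: ReshapeSequenceFusion (and the related NopElimination helper) can now delete a Reshape/Squeeze/Unsqueeze chain outright when static shapes prove its net effect is the identity, instead of always leaving one fused Reshape behind. A minimal before/after sketch of the new behaviour, assuming the ngraph opset8 API and the header path used elsewhere in this diff (this is illustrative code, not code from the commit):

// Before the pass:  Parameter{1,2,3} -> Relu -> Reshape{2,3} -> Reshape{1,2,3} -> Result
// After the pass:   Parameter{1,2,3} -> Relu -> Result
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"

using namespace ngraph;

int main() {
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 2, 3});
    auto relu = std::make_shared<opset8::Relu>(data);
    auto a = std::make_shared<opset8::Reshape>(
        relu, opset8::Constant::create(element::i64, Shape{2}, {2, 3}), false);
    auto b = std::make_shared<opset8::Reshape>(
        a, opset8::Constant::create(element::i64, Shape{3}, {1, 2, 3}), false);
    auto f = std::make_shared<Function>(OutputVector{b}, ParameterVector{data});

    pass::Manager manager;
    // use_shape_for_elimination defaults to true, so the interchangeable
    // pair is erased entirely rather than fused into a single Reshape.
    manager.register_pass<pass::ReshapeSequenceFusion>();
    manager.run_passes(f);
}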
@@ -18,11 +18,11 @@ class TRANSFORMATIONS_API ReshapeSequenceFusion;

/**
 * @ingroup ie_transformation_common_api
- * @brief ReshpaeSequenceFusion fuses sequence of Reshape operation into single Reshape
+ * @brief ReshapeSequenceFusion fuses sequence of Reshape operation into single Reshape or eliminates full redundant sequence
 */

class ngraph::pass::ReshapeSequenceFusion: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
-    ReshapeSequenceFusion();
+    ReshapeSequenceFusion(bool use_shape_for_elimination = true);
};
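The only public API change is the new constructor flag; existing callers keep the old behaviour through the default. A hedged usage sketch, assuming an existing ngraph::pass::Manager named manager (the explicit-false form mirrors how MOCTransformations forwards m_use_shapes in the next file):

// Still fuse chains of Reshapes into the last one, but do not rely on
// static shapes to erase an interchangeable chain entirely (useful when
// shape inference is unavailable or disabled):
manager.register_pass<ngraph::pass::ReshapeSequenceFusion>(/*use_shape_for_elimination=*/false);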
@@ -153,7 +153,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
    common_fusions->add_matcher<ngraph::pass::DivideFusion>();
    common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
    common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
-    common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>();
+    common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
    common_fusions->set_name("ngraph::pass::CommonFusions");

    manager.register_pass<ngraph::pass::BinarizeWeights>();
@@ -102,16 +102,25 @@ static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node) {
    if (ov::as_type_ptr<opset3::Squeeze>(input_node) ||
        ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
        ov::as_type_ptr<opset3::Reshape>(input_node)) {
        if (input_node->get_output_target_inputs(0).size() != 1)
            return false;

        auto shape = node->get_output_shape(0);
-        std::vector<int64_t> vi;
-        vi.assign(shape.begin(), shape.end());
-        auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
-        auto new_reshape =
-            make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
-        new_reshape->set_friendly_name(node->get_friendly_name());
-        copy_runtime_info({input_node, node}, new_reshape);
-        replace_node(node, new_reshape);
-        return true;
+
+        // remove interchangeable nodes
+        if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
+            return replace_output_update_name(node->output(0), input_node->input_value(0));
+        } else {
+            std::vector<int64_t> vi;
+            vi.assign(shape.begin(), shape.end());
+            auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
+            auto new_reshape =
+                make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
+            new_reshape->set_friendly_name(node->get_friendly_name());
+            copy_runtime_info({input_node, node}, new_reshape);
+            replace_node(node, new_reshape);
+            return true;
+        }
    }

    return false;
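To make the new branch concrete: node is the matched Reshape, input_node is the Squeeze/Unsqueeze/Reshape feeding it, and both are bypassed when the producer's input shape already equals the Reshape's output shape. A sketch with illustrative shapes of my own choosing (not taken from the commit), assuming the same opset3/std usings as the surrounding file:

// node == the trailing Reshape, input_node == the Squeeze in the code above
auto arg = std::make_shared<opset3::Parameter>(element::f32, Shape{8, 16, 1, 3});
auto squeeze = std::make_shared<opset3::Squeeze>(
    arg, opset3::Constant::create(element::i64, Shape{1}, {2}));            // -> {8, 16, 3}
auto back = std::make_shared<opset3::Reshape>(
    squeeze, opset3::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3}), false);
// squeeze's input shape {8,16,1,3} equals back's output shape, so
// replace_output_update_name(back->output(0), squeeze->input_value(0))
// unplugs both nodes and rewires consumers straight to arg.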
@@ -55,7 +55,7 @@ bool has_valid_pattern(const ov::Output<ov::Node>& node_out) {
}
} // namespace

-ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
+ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion(bool use_shape_for_elimination) {
    MATCHER_SCOPE(ReshapeSequenceFusion);
    auto reshape_input = pattern::any_input();
    auto reshape_a_pattern = pattern::wrap_type<opset8::Constant>();
@@ -87,9 +87,21 @@ ngraph::pass::ReshapeSequenceFusion::ReshapeSequenceFusion() {
            input = node->input_value(0);
        }

-        reshape->input(0).replace_source_output(input);
-        copy_runtime_info(nodes, reshape);
-        return false;
+        // remove redundant reshapes
+        bool replaced = false;
+        if (use_shape_for_elimination && input.get_partial_shape().is_static() && reshape->get_output_partial_shape(0).is_static() &&
+            input.get_shape() == reshape->get_output_shape(0)) {
+            // in case if elimination is not allowed we still can eliminate all transposes except last one
+            replaced = replace_output_update_name(reshape->output(0), input);
+        }
+
+        if (!replaced) {
+            reshape->input(0).replace_source_output(input);
+            copy_runtime_info(nodes, reshape);
+            return false; // because root node wasn't replaced
+        }
+
+        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_b, matcher_name);
@@ -18,6 +18,7 @@
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "rnn_sequences_optimization.hpp"
#include "transformations/common_optimizations/reshape_sequence_fusion.hpp"

namespace MKLDNNPlugin {

@@ -34,6 +35,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
    if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(nGraphFunc)) {
        manager.register_pass<ReshapeFullyConnectedFusion>();
    }
+    // after transformation "MoveEltwiseUpThroughDataMov" there can be Reshape sequences that should be eliminated or fused
+    manager.register_pass<ngraph::pass::ReshapeSequenceFusion>();
    manager.register_pass<ngraph::pass::ConstantFolding>();
    manager.register_pass<ngraph::pass::ConvertPrecision>(precisions_array {{ ngraph::element::i64, ngraph::element::i32 }});
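For context on the new plugin comment, a hedged illustration of the kind of graph it refers to; the exact pattern MoveEltwiseUpThroughDataMov produces is my assumption, not spelled out in the commit:

// Before hoisting:  ... -> Reshape -> Multiply(const) -> Reshape -> ...
// After MoveEltwiseUpThroughDataMov pulls the eltwise above the Reshape:
//                   ... -> Multiply(const) -> Reshape -> Reshape -> ...
// The now-adjacent Reshapes are what this extra ReshapeSequenceFusion run
// fuses, or removes entirely when their net effect is the identity.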
@@ -140,17 +140,48 @@ TEST(nop_elimination, squeeze_reshape_elimination_check_info) {
    pass_manager.register_pass<pass::NopElimination>();
    pass_manager.run_passes(f);

-    bool reshape_is_missing = true;
+    bool movement_are_missing = true;
    for (auto node : f->get_ops()) {
-        if (node->get_friendly_name() == "reshape") {
-            reshape_is_missing = false;
-            ASSERT_TRUE(std::dynamic_pointer_cast<opset4::Reshape>(node));
-            auto original_names = ngraph::getFusedNamesVector(node);
-            sort(original_names.begin(), original_names.end());
-            ASSERT_EQ(original_names, std::vector<std::string>({"reshape", "squeeze"}));
+        if (node->get_friendly_name() == "reshape" || node->get_friendly_name() == "squeeze") {
+            movement_are_missing = false;
        }
    }
-    ASSERT_FALSE(reshape_is_missing);
+    ASSERT_TRUE(movement_are_missing);
}

+TEST(nop_elimination, squeeze_unsqueeze_elimination) {
+    std::shared_ptr<Function> f;
+    {
+        auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
+
+        auto relu = std::make_shared<opset4::Relu>(arg);
+        relu->set_friendly_name("relu");
+
+        auto squeeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
+        auto squeeze = std::make_shared<opset4::Squeeze>(relu, squeeze_axes);
+        squeeze->set_friendly_name("squeeze");
+
+        auto unsqueeze_axes = opset4::Constant::create(element::i64, Shape{1}, {2});
+        auto unsqueeze = std::make_shared<opset4::Unsqueeze>(squeeze, unsqueeze_axes);
+        unsqueeze->set_friendly_name("unsqueeze");
+
+        auto abs = std::make_shared<opset4::Abs>(unsqueeze);
+
+        f = std::make_shared<Function>(NodeVector{abs}, ParameterVector{arg});
+    }
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::InitNodeInfo>();
+    pass_manager.register_pass<pass::NopElimination>();
+    pass_manager.run_passes(f);
+
+    bool movement_are_missing = true;
+    for (auto node : f->get_ops()) {
+        if (node->get_friendly_name() == "squeeze" || node->get_friendly_name() == "unsqueeze") {
+            movement_are_missing = false;
+        }
+    }
+    ASSERT_TRUE(movement_are_missing);
+}

TEST(nop_elimination, reshape_elimination_v1_dynamic) {
@@ -165,6 +196,33 @@ TEST(nop_elimination, reshape_elimination_v1_dynamic) {
    ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 1);
}

+TEST(nop_elimination, reshape_elimination_v1_check_consumer_count) {
+    std::shared_ptr<Function> f;
+    {
+        auto arg = std::make_shared<opset4::Parameter>(element::f32, PartialShape{8, 16, 1, 3});
+
+        auto reshape_1_shape = opset4::Constant::create(element::i64, Shape{2}, {128, 3});
+        auto reshape_1 = std::make_shared<opset4::Reshape>(arg, reshape_1_shape, false);
+        reshape_1->set_friendly_name("reshape_1");
+
+        auto reshape_2_shape = opset4::Constant::create(element::i64, Shape{4}, {8, 16, 1, 3});
+        auto reshape_2 = std::make_shared<opset4::Reshape>(reshape_1, reshape_2_shape, false);
+        reshape_2->set_friendly_name("reshape_2");
+
+        auto relu = std::make_shared<opset4::Relu>(reshape_1);
+        relu->set_friendly_name("relu");
+
+        f = std::make_shared<Function>(NodeVector{reshape_2, relu}, ParameterVector{arg});
+    }
+
+    pass::Manager pass_manager;
+    pass_manager.register_pass<pass::InitNodeInfo>();
+    pass_manager.register_pass<pass::NopElimination>();
+    pass_manager.run_passes(f);
+
+    ASSERT_TRUE(count_ops_of_type<op::v1::Reshape>(f) == 2);
+}

TEST(nop_elimination, concat_elimination_single_node) {
    int64_t a = 0;
    auto A = make_shared<op::Parameter>(element::f32, Shape{2, 3});
@@ -305,3 +305,21 @@ TEST_F(TransformationTestsF, ReshapeSequenceFusionNeg5_special_zero_false) {
        manager.register_pass<pass::ReshapeSequenceFusion>();
    }
}

+TEST_F(TransformationTestsF, ReshapeSequenceFusionEliminate) {
+    {
+        auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
+        auto relu = std::make_shared<opset6::Relu>(data);
+        auto a = reshape(relu, {2, 3});
+        auto b = reshape(a, {1, 2, 3});
+        function = std::make_shared<Function>(OutputVector{b}, ParameterVector{data});
+
+        manager.register_pass<pass::ReshapeSequenceFusion>();
+    }
+
+    {
+        auto data = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 2, 3});
+        auto relu = std::make_shared<opset6::Relu>(data);
+        function_ref = std::make_shared<Function>(OutputVector{relu}, ParameterVector{data});
+    }
+}
