Skip to content

Commit

Permalink
[CPU] [TRANSFORMATIONS] FakeConvert decomposition transformation (#28118
Browse files Browse the repository at this point in the history
)

### Details:
 - *Implement FakeConvert decomposition transformation.*

### Tickets:
 - *[CVS-156963](https://jira.devtools.intel.com/browse/CVS-156963)*

### Prerequisites:
- *#27949
  • Loading branch information
xuchen-intel authored Dec 29, 2024
1 parent 82d553e commit 2ef42d4
Show file tree
Hide file tree
Showing 8 changed files with 426 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API FakeConvertDecomposition;

} // namespace pass
} // namespace ov

/**
* @ingroup ov_transformation_common_api
* @brief FakeConvertDecomposition transformation decomposes FakeConvert layer.
* f8: f8e4m3, f8e5m2
* downconvert: f32->f8, f16->f8, bf16->f8
* upconvert: f8->f32, f8->f16, f8->bf16
* output = (upconvert(downconvert(input * scale - shift)) + shift) / scale
*
*/

class ov::pass::FakeConvertDecomposition : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("FakeConvertDecomposition");
FakeConvertDecomposition();
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/fake_convert_decomposition.hpp"

#include "itt.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/fake_convert.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/subtract.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

ov::pass::FakeConvertDecomposition::FakeConvertDecomposition() {
MATCHER_SCOPE(FakeConvertDecomposition);
auto data = pattern::any_input();

auto fake_convert = ov::pass::pattern::wrap_type<ov::op::v13::FakeConvert>();

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
const auto fake_convert_node =
ov::as_type_ptr<ov::op::v13::FakeConvert>(pattern_to_output.at(fake_convert).get_node_shared_ptr());

if (fake_convert_node == nullptr || transformation_callback(fake_convert_node)) {
return false;
}

Output<Node> data{fake_convert_node->input_value(0)};
const Output<Node> input_scale{fake_convert_node->input_value(1)};
auto input_type = data.get_element_type();

ov::pass::NodeRegistry decomp_ops;
if (input_type != input_scale.get_element_type()) {
input_type = input_scale.get_element_type();
data = std::make_shared<ov::op::v0::Convert>(data, input_type);
data = decomp_ops.add(data.get_node_shared_ptr());
}

std::shared_ptr<Node> result;
const auto scale = decomp_ops.make<ov::op::v1::Multiply>(data, input_scale);
if (fake_convert_node->get_input_size() == 2) {
const auto downconvert =
decomp_ops.make<ov::op::v0::Convert>(scale, fake_convert_node->get_destination_element_type());
const auto upconvert = decomp_ops.make<ov::op::v0::Convert>(downconvert, input_type);

result = decomp_ops.make<ov::op::v1::Divide>(upconvert, input_scale);
} else {
const Output<Node> input_shift{fake_convert_node->input_value(2)};
const auto shift = decomp_ops.make<ov::op::v1::Subtract>(scale, input_shift);

const auto downconvert =
decomp_ops.make<ov::op::v0::Convert>(shift, fake_convert_node->get_destination_element_type());
const auto upconvert = decomp_ops.make<ov::op::v0::Convert>(downconvert, input_type);

const auto deshift = decomp_ops.make<ov::op::v1::Add>(upconvert, input_shift);
result = decomp_ops.make<ov::op::v1::Divide>(deshift, input_scale);
}

if (result->get_output_element_type(0) != fake_convert_node->get_output_element_type(0)) {
result = decomp_ops.make<ov::op::v0::Convert>(result, fake_convert_node->get_output_element_type(0));
}

result->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(fake_convert_node, decomp_ops.get());
ov::replace_node(m.get_match_root(), result);
return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(fake_convert, matcher_name);
register_matcher(m, callback);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/op_conversions/fake_convert_decomposition.hpp"

#include <gtest/gtest.h>

#include "common_test_utils/common_utils.hpp"
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset13.hpp"

using namespace ov;

using FakeConvertDecompositionParams = std::tuple<Shape, // data shape
Shape, // scale shape
Shape, // shift shape
element::Type_t, // input precision
element::Type_t, // destination precision
bool>; // default shift

class FakeConvertDecompositionTest : public ov::test::TestsCommon,
public ::testing::WithParamInterface<FakeConvertDecompositionParams> {
public:
static std::string getTestCaseName(::testing::TestParamInfo<FakeConvertDecompositionParams> obj) {
FakeConvertDecompositionParams params = obj.param;

Shape data_shape, scale_shape, shift_shape;
element::Type_t data_prec, dst_prec;
bool default_shift;
std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params;

std::ostringstream result;
result << "dataShape=" << ov::test::utils::vec2str(data_shape) << "_";
result << "scaleShape=" << ov::test::utils::vec2str(scale_shape) << "_";
result << "shiftShape=" << ov::test::utils::vec2str(shift_shape) << "_";
result << "dataPrecision=" << element::Type(data_prec) << "_";
result << "destinationPrecision=" << element::Type(dst_prec) << "_";
if (default_shift)
result << "defaultShift=true";
else
result << "defaultShift=false";
return result.str();
}
};

TEST_P(FakeConvertDecompositionTest, CompareFunctions) {
FakeConvertDecompositionParams params = this->GetParam();

Shape data_shape, scale_shape, shift_shape;
element::Type_t data_prec, dst_prec;
bool default_shift;
std::tie(data_shape, scale_shape, shift_shape, data_prec, dst_prec, default_shift) = params;

std::shared_ptr<ov::Model> model(nullptr);
{
const auto data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);

const auto fake_convert = default_shift ? std::make_shared<opset13::FakeConvert>(data, scale, dst_prec)
: std::make_shared<opset13::FakeConvert>(data, scale, shift, dst_prec);
model = std::make_shared<ov::Model>(NodeVector{fake_convert}, ParameterVector{data});

pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::FakeConvertDecomposition>();
manager.run_passes(model);

OV_ASSERT_NO_THROW(check_rt_info(model));
}

std::shared_ptr<ov::Model> model_ref(nullptr);
{
const auto input_data = std::make_shared<opset1::Parameter>(data_prec, PartialShape(data_shape));
const auto input_scale = std::make_shared<opset1::Constant>(data_prec, scale_shape);
const auto input_shift = std::make_shared<opset1::Constant>(data_prec, shift_shape);
ParameterVector params;
params.push_back(input_data);
std::shared_ptr<Node> data = input_data;

std::shared_ptr<Node> result;
const auto scale = std::make_shared<ov::op::v1::Multiply>(data, input_scale);
if (default_shift) {
const auto downconvert = std::make_shared<ov::op::v0::Convert>(scale, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

result = std::make_shared<ov::op::v1::Divide>(upconvert, input_scale);
} else {
const auto shift = std::make_shared<ov::op::v1::Subtract>(scale, input_shift);

const auto downconvert = std::make_shared<ov::op::v0::Convert>(shift, dst_prec);
const auto upconvert = std::make_shared<ov::op::v0::Convert>(downconvert, data_prec);

const auto deshift = std::make_shared<ov::op::v1::Add>(upconvert, input_shift);
result = std::make_shared<ov::op::v1::Divide>(deshift, input_scale);
}

model_ref = std::make_shared<ov::Model>(NodeVector{result}, params);
}

const auto res = compare_functions(model, model_ref);
ASSERT_TRUE(res.first) << res.second;
}

const std::vector<element::Type_t> data_precisions = {element::Type_t::f32,
element::Type_t::f16,
element::Type_t::bf16};

const std::vector<element::Type_t> destination_precisions = {element::Type_t::f8e4m3, element::Type_t::f8e5m2};

const std::vector<bool> default_shift = {true, false};

const auto simple_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
::testing::Values(Shape{1}),
::testing::Values(Shape{1}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift));

const auto broadcast_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
::testing::Values(Shape{2, 3, 1, 1}),
::testing::Values(Shape{2, 3, 1, 1}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift));

const auto elementwise_fake_convert_params = ::testing::Combine(::testing::Values(Shape{2, 3, 4, 5}),
::testing::Values(Shape{2, 3, 4, 5}),
::testing::Values(Shape{2, 3, 4, 5}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift));

INSTANTIATE_TEST_SUITE_P(SimpleFakeConvert_Decomposition,
FakeConvertDecompositionTest,
simple_fake_convert_params,
FakeConvertDecompositionTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(BroadcastFakeConvert_Decomposition,
FakeConvertDecompositionTest,
broadcast_fake_convert_params,
FakeConvertDecompositionTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(ElementwiseFakeConvert_Decomposition,
FakeConvertDecompositionTest,
elementwise_fake_convert_params,
FakeConvertDecompositionTest::getTestCaseName);
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
#include "transformations/op_conversions/detection_output_downgrade.hpp"
#include "transformations/op_conversions/detection_output_upgrade.hpp"
#include "transformations/op_conversions/eye_decomposition.hpp"
#include "transformations/op_conversions/fake_convert_decomposition.hpp"
#include "transformations/op_conversions/fq_decomposition.hpp"
#include "transformations/op_conversions/gelu7_downgrade.hpp"
#include "transformations/op_conversions/group_normalization_decomposition.hpp"
Expand Down Expand Up @@ -1293,6 +1294,7 @@ void Transformations::PostSnippets(void) {
return node::FakeQuantize::isSupportedOperation(node, errMsg);
},
ov::pass::FakeQuantizeDecomposition);
CPU_REGISTER_PASS_COMMON(postSnippetsManager, ov::pass::FakeConvertDecomposition);
CPU_REGISTER_PASS_COMMON(postSnippetsManager, ov::pass::ConstantFolding);
postSnippetsManager.run_passes(model);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "single_op_tests/fake_convert.hpp"

namespace {
using ov::test::FakeConvertLayerTest;

const std::vector<std::vector<ov::Shape>> shapes = {{{2, 3, 4, 5}}};

const std::vector<ov::element::Type> data_precisions = {ov::element::f32, ov::element::f16, ov::element::bf16};

const std::vector<ov::element::Type> destination_precisions = {ov::element::f8e4m3, ov::element::f8e5m2};

const std::vector<bool> default_shift = {true, false};

const auto simple_fake_convert_params =
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
::testing::Values(ov::Shape{1}),
::testing::Values(ov::Shape{1}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift),
::testing::Values(ov::test::utils::DEVICE_CPU));

const auto broadcast_fake_convert_params =
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
::testing::Values(ov::Shape{2, 3, 1, 1}),
::testing::Values(ov::Shape{2, 3, 1, 1}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift),
::testing::Values(ov::test::utils::DEVICE_CPU));

const auto elementwise_fake_convert_params =
::testing::Combine(::testing::ValuesIn(ov::test::static_shapes_to_test_representation(shapes)),
::testing::Values(ov::Shape{2, 3, 4, 5}),
::testing::Values(ov::Shape{2, 3, 4, 5}),
::testing::ValuesIn(data_precisions),
::testing::ValuesIn(destination_precisions),
::testing::ValuesIn(default_shift),
::testing::Values(ov::test::utils::DEVICE_CPU));

INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_simple,
FakeConvertLayerTest,
simple_fake_convert_params,
FakeConvertLayerTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_broadcast,
FakeConvertLayerTest,
broadcast_fake_convert_params,
FakeConvertLayerTest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_FakeConvert_elementwise,
FakeConvertLayerTest,
elementwise_fake_convert_params,
FakeConvertLayerTest::getTestCaseName);
} // namespace
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/single_op/fake_convert.hpp"

namespace ov {
namespace test {

TEST_P(FakeConvertLayerTest, Inference) {
run();
}
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {
using FakeConvertParams = std::tuple<std::vector<InputShape>, // Data shape
Shape, // Scale shape
Shape, // Shift shape
ov::element::Type, // Input precision
ov::element::Type, // Ddestination precision
bool, // Default shift
std::string>; // Device name

class FakeConvertLayerTest : public testing::WithParamInterface<FakeConvertParams>,
virtual public ov::test::SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<FakeConvertParams>& obj);

protected:
void SetUp() override;
};
} // namespace test
} // namespace ov
Loading

0 comments on commit 2ef42d4

Please sign in to comment.