diff --git a/docs/articles_en/documentation/openvino-extensibility/frontend-extensions.rst b/docs/articles_en/documentation/openvino-extensibility/frontend-extensions.rst index 5073a789d7e789..72faf29077b985 100644 --- a/docs/articles_en/documentation/openvino-extensibility/frontend-extensions.rst +++ b/docs/articles_en/documentation/openvino-extensibility/frontend-extensions.rst @@ -50,7 +50,7 @@ class that works well if all the following conditions are satisfied: .. note:: - ``OpExtension`` class is currently available for ONNX and TensorFlow frontends. + ``OpExtension`` class is currently available for ONNX, TensorFlow and PyTorch frontends. PaddlePaddle frontend has named inputs and outputs for operation (not indexed) therefore OpExtension mapping is not applicable for this case. diff --git a/docs/snippets/ov_extensions.cpp b/docs/snippets/ov_extensions.cpp index 3fd5db793bd490..1e1044d12e8ce8 100644 --- a/docs/snippets/ov_extensions.cpp +++ b/docs/snippets/ov_extensions.cpp @@ -261,7 +261,7 @@ core.add_extension("openvino_template_extension.so"); { //! [frontend_extension_framework_map_macro_add_extension] ov::Core core; -core.add_extension(ov::frontend::OpExtension()); +core.add_extension(ov::OpExtension()); //! [frontend_extension_framework_map_macro_add_extension] } return 0; diff --git a/src/core/include/openvino/core/op_extension.hpp b/src/core/include/openvino/core/op_extension.hpp index e9a756ea2d32c1..f76ecc43f4418c 100644 --- a/src/core/include/openvino/core/op_extension.hpp +++ b/src/core/include/openvino/core/op_extension.hpp @@ -62,6 +62,7 @@ namespace detail { OV_COLLECT_ATTACHED_EXTENSIONS(onnx) OV_COLLECT_ATTACHED_EXTENSIONS(paddle) OV_COLLECT_ATTACHED_EXTENSIONS(tensorflow) +OV_COLLECT_ATTACHED_EXTENSIONS(pytorch) } // namespace detail /** @@ -98,6 +99,7 @@ class OpExtension : public BaseOpExtension { detail::collect_attached_extensions_onnx(res); detail::collect_attached_extensions_paddle(res); detail::collect_attached_extensions_tensorflow(res); + detail::collect_attached_extensions_pytorch(res); return res; } }; diff --git a/src/frontends/onnx/frontend/src/frontend.cpp b/src/frontends/onnx/frontend/src/frontend.cpp index 27ccd18a5f757c..7a02349b1e0d46 100644 --- a/src/frontends/onnx/frontend/src/frontend.cpp +++ b/src/frontends/onnx/frontend/src/frontend.cpp @@ -207,5 +207,9 @@ void FrontEnd::add_extension(const std::shared_ptr& extension) { m_extensions.conversions.push_back(onnx_conv_ext); } else if (auto progress_reporter = std::dynamic_pointer_cast(extension)) { m_extensions.progress_reporter = progress_reporter; + } else if (auto op_base_ext = std::dynamic_pointer_cast(extension)) { + for (const auto& attached_ext : op_base_ext->get_attached_extensions()) { + add_extension(attached_ext); + } } } diff --git a/src/frontends/paddle/src/frontend.cpp b/src/frontends/paddle/src/frontend.cpp index 1330e3100da407..9bb81b30279def 100644 --- a/src/frontends/paddle/src/frontend.cpp +++ b/src/frontends/paddle/src/frontend.cpp @@ -559,6 +559,10 @@ void FrontEnd::add_extension(const std::shared_ptr& extension) { m_op_translators[paddle_conv_ext->get_op_type()] = [=](const NodeContext& context) { return paddle_conv_ext->get_converter()(context); }; + } else if (auto op_base_ext = std::dynamic_pointer_cast(extension)) { + for (const auto& attached_ext : op_base_ext->get_attached_extensions()) { + add_extension(attached_ext); + } } } diff --git a/src/frontends/pytorch/src/frontend.cpp b/src/frontends/pytorch/src/frontend.cpp index 03835b72935327..ea4b916f83d485 100644 --- a/src/frontends/pytorch/src/frontend.cpp +++ b/src/frontends/pytorch/src/frontend.cpp @@ -258,6 +258,10 @@ void FrontEnd::add_extension(const std::shared_ptr& extension) { m_extensions.push_back(so_ext); } else if (const auto& telemetry = std::dynamic_pointer_cast(extension)) { m_telemetry = telemetry; + } else if (auto op_base_ext = std::dynamic_pointer_cast(extension)) { + for (const auto& attached_ext : op_base_ext->get_attached_extensions()) { + add_extension(attached_ext); + } } } diff --git a/src/frontends/tensorflow/src/frontend.cpp b/src/frontends/tensorflow/src/frontend.cpp index 6e1b2d49e316a6..4cdca6a4f25733 100644 --- a/src/frontends/tensorflow/src/frontend.cpp +++ b/src/frontends/tensorflow/src/frontend.cpp @@ -559,5 +559,9 @@ void FrontEnd::add_extension(const std::shared_ptr& extension) { std::dynamic_pointer_cast(extension)) { m_conversion_extensions.push_back(tensorflow_conv_ext); m_op_translators[tensorflow_conv_ext->get_op_type()] = tensorflow_conv_ext->get_converter(); + } else if (auto op_base_ext = std::dynamic_pointer_cast(extension)) { + for (const auto& attached_ext : op_base_ext->get_attached_extensions()) { + add_extension(attached_ext); + } } } diff --git a/src/frontends/tensorflow_lite/src/frontend.cpp b/src/frontends/tensorflow_lite/src/frontend.cpp index 5f58977743650a..f21d2808532a7e 100644 --- a/src/frontends/tensorflow_lite/src/frontend.cpp +++ b/src/frontends/tensorflow_lite/src/frontend.cpp @@ -315,5 +315,9 @@ void FrontEnd::add_extension(const std::shared_ptr& extension) { m_op_translators[tensorflow_conv_ext->get_op_type()] = [=](const NodeContext& context) { return tensorflow_conv_ext->get_converter()(context); }; + } else if (auto op_base_ext = std::dynamic_pointer_cast(extension)) { + for (const auto& attached_ext : op_base_ext->get_attached_extensions()) { + add_extension(attached_ext); + } } } diff --git a/src/frontends/tests/frontend/shared/test_builtin_extensions/builtin_extensions.cpp b/src/frontends/tests/frontend/shared/test_builtin_extensions/builtin_extensions.cpp index fa7f996d3bb1c4..6e3efd0b2e5429 100644 --- a/src/frontends/tests/frontend/shared/test_builtin_extensions/builtin_extensions.cpp +++ b/src/frontends/tests/frontend/shared/test_builtin_extensions/builtin_extensions.cpp @@ -147,14 +147,23 @@ class CustomElu : public ov::op::Op { }; #ifdef ENABLE_OV_PYTORCH_FRONTEND +# include # include # include -# define PT_EXT \ - std::make_shared>( \ - "aten::elu", \ - std::map{{"m_alpha", 1}}, \ - std::map{{"m_beta", 1.0f}}), \ - std::make_shared("Relu", ReluToSwishTranslator), +# include +class ReluCustom : public ov::op::v0::Relu { +public: + OPENVINO_OP("ReluCustom"); + OPENVINO_FRAMEWORK_MAP(pytorch, "aten::relu"); +}; +# define PT_EXT \ + std::make_shared>( \ + "aten::elu", \ + std::map{{"m_alpha", 1}}, \ + std::map{{"m_beta", 1.0f}}), \ + std::make_shared("Relu", ReluToSwishTranslator), \ + std::make_shared>(), + #else # define PT_EXT #endif diff --git a/tests/layer_tests/mo_python_api_tests/test_mo_convert_extensions.py b/tests/layer_tests/mo_python_api_tests/test_mo_convert_extensions.py index cdad324909ccdd..9a66477aa0695f 100644 --- a/tests/layer_tests/mo_python_api_tests/test_mo_convert_extensions.py +++ b/tests/layer_tests/mo_python_api_tests/test_mo_convert_extensions.py @@ -11,7 +11,7 @@ from openvino.runtime import PartialShape, Model -class TestExtensions(CommonMOConvertTest): +class TestONNXExtensions(CommonMOConvertTest): def create_onnx_model(self, tmp_dir): # # Create ONNX model @@ -23,8 +23,10 @@ def create_onnx_model(self, tmp_dir): shape = [2, 3, 4] - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, shape) - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, shape) + input = helper.make_tensor_value_info( + 'input', TensorProto.FLOAT, shape) + output = helper.make_tensor_value_info( + 'output', TensorProto.FLOAT, shape) node_def = onnx.helper.make_node( 'LeakyRelu', @@ -57,7 +59,7 @@ def create_custom_extension_leaky_relu_to_relu(): # replaces LeakyRelu with Relu from openvino.frontend import ConversionExtension from openvino.frontend import NodeContext - import openvino.runtime.opset8 as ops + import openvino.runtime.opset14 as ops def custom_converter(node: NodeContext): input = node.get_input(0) @@ -66,11 +68,17 @@ def custom_converter(node: NodeContext): return ConversionExtension("LeakyRelu", custom_converter) + def create_custom_op_extension_leaky_relu_to_relu(): + # replaces LeakyRelu with Relu + from openvino.frontend import OpExtension + + return OpExtension("Relu", "LeakyRelu") + def create_custom_extension_elu_to_sigmoid(): # replaces Elu with Sigmoid from openvino.frontend import ConversionExtension from openvino.frontend import NodeContext - import openvino.runtime.opset8 as ops + import openvino.runtime.opset14 as ops def custom_converter(node: NodeContext): input = node.get_input(0) @@ -81,22 +89,22 @@ def custom_converter(node: NodeContext): def create_ref_graph1(): shape = PartialShape([2, 3, 4]) - param = ov.opset8.parameter(shape, dtype=np.float32) + param = ov.opset14.parameter(shape, dtype=np.float32) param.get_output_tensor(0).set_names({"input"}) - relu = ov.opset8.relu(param) + relu = ov.opset14.relu(param) relu.get_output_tensor(0).set_names({"LeakyRelu_data"}) - elu = ov.opset8.elu(relu, alpha=0.1) + elu = ov.opset14.elu(relu, alpha=0.1) elu.get_output_tensor(0).set_names({"output"}) return Model([elu], [param], "test") def create_ref_graph2(): shape = PartialShape([2, 3, 4]) - param = ov.opset8.parameter(shape, dtype=np.float32) + param = ov.opset14.parameter(shape, dtype=np.float32) param.get_output_tensor(0).set_names({"input"}) - relu = ov.opset8.relu(param) + relu = ov.opset14.relu(param) relu.get_output_tensor(0).set_names({"LeakyRelu_data"}) - sigmoid = ov.opset8.sigmoid(relu) + sigmoid = ov.opset14.sigmoid(relu) sigmoid.get_output_tensor(0).set_names({"output"}) return Model([sigmoid], [param], "test") @@ -104,6 +112,8 @@ def create_ref_graph2(): test_data = [ {'params_test': {'extensions': create_custom_extension_leaky_relu_to_relu()}, 'ref_graph': create_ref_graph1()}, + {'params_test': {'extensions': create_custom_op_extension_leaky_relu_to_relu()}, + 'ref_graph': create_ref_graph1()}, {'params_test': {'extensions': [create_custom_extension_leaky_relu_to_relu(), create_custom_extension_elu_to_sigmoid()]}, 'ref_graph': create_ref_graph2()} @@ -112,11 +122,134 @@ def create_ref_graph2(): @pytest.mark.parametrize("params", test_data) @pytest.mark.nightly @pytest.mark.precommit - def test_mo_convert_extensions(self, params, ie_device, precision, ir_version, - temp_dir, use_legacy_frontend): + def test_onnx_mo_convert_extensions(self, params, ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): onnx_net_path = self.create_onnx_model(temp_dir) test_params = params['params_test'] test_params.update({'input_model': onnx_net_path}) test_params.update({'use_convert_model_from_mo': True}) self._test_by_ref_graph(temp_dir, test_params, params['ref_graph']) + + +class TestPyTorchExtensions(CommonMOConvertTest): + def create_model(self, tmp_dir): + import torch + + class CosModel(torch.nn.Module): + def __init__(self): + super(CosModel, self).__init__() + + def forward(self, x): + return torch.cos(x.to(torch.float32)) + + return CosModel() + + def create_custom_extension_cos_to_sin(): + from openvino.frontend import ConversionExtension + from openvino.frontend import NodeContext + import openvino.runtime.opset14 as ops + + def custom_converter(node: NodeContext): + input = node.get_input(0) + sin = ops.sin(input) + return sin.outputs() + + return ConversionExtension("aten::cos", custom_converter) + + def create_custom_op_extension_cos_to_sin(): + from openvino.frontend import OpExtension + + return OpExtension("Sin", "aten::cos") + + def create_ref_graph(): + shape = PartialShape.dynamic() + param = ov.opset14.parameter(shape, dtype=ov.Type.dynamic) + param.get_output_tensor(0).set_names({"x"}) + convert = ov.opset14.convert(param, ov.Type.f32) + convert.get_output_tensor(0).set_names({"5"}) + sin = ov.opset14.sin(convert) + + return Model([sin], [param], "test") + + test_data = [ + {'params_test': {'extension': create_custom_extension_cos_to_sin()}, + 'ref_graph': create_ref_graph()}, + {'params_test': {'extension': create_custom_op_extension_cos_to_sin()}, + 'ref_graph': create_ref_graph()}, + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.nightly + @pytest.mark.precommit + def test_pt_mo_convert_extensions(self, params, ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): + model = self.create_model(temp_dir) + + test_params = params['params_test'] + test_params.update({'input_model': model}) + self._test_by_ref_graph(temp_dir, test_params, params['ref_graph']) + + +class TestTfExtensions(CommonMOConvertTest): + def create_keras_model(self, temp_dir): + import tensorflow as tf + + tf.keras.backend.clear_session() + tf.compat.v1.reset_default_graph() + + input_name = "Input1" + input_shape = [1, 2, 3] + + x = tf.keras.Input(shape=input_shape, name=input_name) + y = tf.cos(x) + keras_net = tf.keras.Model(inputs=[x], outputs=[y]) + tf.keras.backend.clear_session() + + return keras_net + + def create_custom_extension_cos_to_sin(): + from openvino.frontend import ConversionExtension + from openvino.frontend import NodeContext + import openvino.runtime.opset14 as ops + + def custom_converter(node: NodeContext): + input = node.get_input(0) + sin = ops.sin(input) + return sin.outputs() + + return ConversionExtension("Cos", custom_converter) + + def create_custom_op_extension_cos_to_sin(): + from openvino.frontend import OpExtension + + return OpExtension("Sin", "Cos") + + def create_ref_graph(): + shape = PartialShape([-1, 1, 2, 3]) + param = ov.opset14.parameter(shape, dtype=np.float32) + param.get_output_tensor(0).set_names({"Input1"}) + y = ov.opset14.sin(param) + y.get_output_tensor(0).set_names({"tf.math.cos/Cos:0"}) + + parameter_list = [param] + + return Model([y], parameter_list, "test") + + test_data = [ + {'params_test': {'extension': create_custom_extension_cos_to_sin()}, + 'ref_graph': create_ref_graph()}, + {'params_test': {'extension': create_custom_op_extension_cos_to_sin()}, + 'ref_graph': create_ref_graph()}, + ] + + @pytest.mark.parametrize("params", test_data) + @pytest.mark.nightly + @pytest.mark.precommit + def test_tf_mo_convert_extensions(self, params, ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): + model = self.create_keras_model(temp_dir) + + test_params = params['params_test'] + test_params.update({'input_model': model}) + self._test_by_ref_graph(temp_dir, test_params, params['ref_graph']) diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py index d3b724e84d83ae..63fb7869af55a1 100644 --- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py +++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py @@ -4,7 +4,7 @@ import torch import numpy as np -from openvino.frontend import FrontEndManager, ConversionExtension, NodeContext, OpExtension +from openvino.frontend import FrontEndManager, ConversionExtension, NodeContext from openvino.runtime import PartialShape, Type import openvino.runtime.opset10 as ops @@ -227,7 +227,7 @@ def get_builtin_extensions_path(): raise RuntimeError("Unable to find test_builtin_extensions") -def test_op_extension(): +def test_so_extension(): from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder class Elu(torch.nn.Module): @@ -258,6 +258,100 @@ def forward(self, inp): assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ "Parameter", "CustomElu", "Result"] +def test_framework_map_macros(): + from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder + + class Relu(torch.nn.Module): + def __init__(self): + super(Relu, self).__init__() + + def forward(self, x): + return torch.nn.functional.relu(x) + + model = Relu() + decoder = TorchScriptPythonDecoder(get_scripted_model(model)) + + fem = FrontEndManager() + fe = fem.load_by_framework(framework="pytorch") + assert fe + + input_model = fe.load(decoder) + assert input_model + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Relu", "Result"] + + fe.add_extension(get_builtin_extensions_path()) + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "ReluCustom", "Result"] + + +def test_op_extension(): + from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder + from openvino.frontend.pytorch import OpExtension + + class CosModel(torch.nn.Module): + def __init__(self): + super(CosModel, self).__init__() + + def forward(self, x): + return torch.cos(x.to(torch.float32)) + + model = CosModel() + decoder = TorchScriptPythonDecoder(get_scripted_model(model)) + + fem = FrontEndManager() + fe = fem.load_by_framework(framework="pytorch") + assert fe + + input_model = fe.load(decoder) + assert input_model + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Convert", "Cos", "Result"] + + fe.add_extension(OpExtension("Sin", "aten::cos")) + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Convert", "Sin", "Result"] + + +def test_op_extension_generic(): + from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder + from openvino.frontend import OpExtension + + class CosModel(torch.nn.Module): + def __init__(self): + super(CosModel, self).__init__() + + def forward(self, x): + return torch.cos(x.to(torch.float32)) + + model = CosModel() + decoder = TorchScriptPythonDecoder(get_scripted_model(model)) + + fem = FrontEndManager() + fe = fem.load_by_framework(framework="pytorch") + assert fe + + input_model = fe.load(decoder) + assert input_model + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Convert", "Cos", "Result"] + + fe.add_extension(OpExtension("Sin", "aten::cos")) + converted_model = fe.convert(input_model) + assert converted_model + assert [n.get_type_name() for n in converted_model.get_ordered_ops()] == [ + "Parameter", "Convert", "Sin", "Result"] + def test_pytorch_telemetry(): from openvino.frontend import TelemetryExtension @@ -318,20 +412,24 @@ def test_shared_consts_reused(): from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder model = ShareWeghtsConvAndShareLinearModel() - decoder = TorchScriptPythonDecoder(model, example_input=(torch.rand(model.INPUT_SIZE),)) + decoder = TorchScriptPythonDecoder( + model, example_input=(torch.rand(model.INPUT_SIZE),)) fe_manager = FrontEndManager() fe = fe_manager.load_by_framework("pytorch") im = fe.load(decoder) om = fe.convert(im) - const_names = ["self.conv.weight", "self.linear.weight", "self.linear.bias"] + const_names = ["self.conv.weight", + "self.linear.weight", "self.linear.bias"] # self.conv.bias is not reused because of ConstantFolding for n in om.get_ops(): if "Constant" in n.get_type_name(): for name in n.output(0).names: if name in const_names: const_names.remove(name) - assert len(n.output(0).get_target_inputs()) == 2, f"Constant {n} is not reused" - assert len(const_names) == 0, f"Not all constants were found: {const_names}" + assert len(n.output(0).get_target_inputs() + ) == 2, f"Constant {n} is not reused" + assert len( + const_names) == 0, f"Not all constants were found: {const_names}" class TestModel1(torch.nn.Module): @@ -377,4 +475,5 @@ def test_output_tuple_names(): fe = fe_manager.load_by_framework("pytorch") im = fe.load(decoder) om = fe.convert(im) - assert len(om.outputs[0].names) == 0 and len(om.outputs[1].names) == 0, "Output tuple names must be empty" + assert len(om.outputs[0].names) == 0 and len( + om.outputs[1].names) == 0, "Output tuple names must be empty"