From 136e59ce85fd7d6839c1a9ddae29bc7b56a88e1d Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 14 Oct 2024 13:05:24 +0400 Subject: [PATCH] [PT FE] Support conversion of TorchScript and ExportedProgram from disk (#26877) **Details:** Support conversion of TorchScript and ExportedProgram from disk. Examples: Let us have ExportedProgram `exported_model.pt2` model saved on disk. ```sh ovc exported_model.pt2 ``` **Note:** TorchScript model from disk is not converted using `ovc` since `input` option does not accept input types. ```python from openvino import convert_model import torch convert_model(input_model='torch_scripted_model.pt', example_input=torch.rand(1, 10)) convert_model(input_model='exported_model.pt2') ``` **Ticket:** 103215 --------- Signed-off-by: Kazantsev, Roman --- .../openvino/frontend/pytorch/fx_decoder.py | 5 +- .../common/mo_convert_test_class.py | 41 ++++- .../ovc_python_api_tests/test_pytorch.py | 169 +++++++++++++++++- tools/ovc/openvino/tools/ovc/convert_impl.py | 23 ++- tools/ovc/openvino/tools/ovc/help.py | 3 +- .../moc_frontend/pytorch_frontend_utils.py | 70 +++++++- 6 files changed, 291 insertions(+), 20 deletions(-) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py index 03b91ba545dd51..c448571f1ac17a 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py @@ -24,7 +24,8 @@ def __init__(self, data) -> None: class TorchFXPythonDecoder (Decoder): - def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, input_shapes=[], input_types=[]): + def __init__(self, pt_module, fx_gm=None, nodes=None, + mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False): Decoder.__init__(self) self.mark_node_callback = mark_node_callback # We store every decoder created by this decoder so that all them are not deleted until the first decoder is deleted @@ -67,7 +68,7 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, i if shape is not None: new_shape = [] for dim in range(0, len(shape)): - if (type(shape[dim]).__name__ == "SymInt"): + if (dynamic_shapes or type(shape[dim]).__name__ == "SymInt"): new_shape.append(-1) else: new_shape.append(shape[dim]) diff --git a/tests/layer_tests/common/mo_convert_test_class.py b/tests/layer_tests/common/mo_convert_test_class.py index e800e76ed98a88..6a57339cedf111 100644 --- a/tests/layer_tests/common/mo_convert_test_class.py +++ b/tests/layer_tests/common/mo_convert_test_class.py @@ -1,10 +1,33 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import sys from pathlib import Path from common.utils.common_utils import generate_ir +from common.utils.common_utils import shell + from openvino.test_utils import compare_functions +from openvino.tools.ovc import ovc + + +def generate_ir_ovc(coverage=False, **kwargs): + # Get OVC file directory + ovc_path = Path(ovc.__file__).parent + + ovc_runner = ovc_path.joinpath('main.py').as_posix() + if coverage: + params = [sys.executable, '-m', 'coverage', 'run', '-p', '--source={}'.format(ovc_runner.parent), + '--omit=*_test.py', ovc_runner] + else: + params = [sys.executable, ovc_runner] + for key, value in kwargs.items(): + if key == "input_model": + params.append((str(value))) + else: + params.extend(("--{}".format(key), str(value))) + exit_code, stdout, stderr = shell(params) + return exit_code, stderr class CommonMOConvertTest: @@ -54,16 +77,26 @@ def _test(self, temp_dir, test_params, ref_params): flag, msg = compare_functions(ir_test, ir_ref) assert flag, msg - def _test_by_ref_graph(self, temp_dir, test_params, ref_graph, compare_tensor_names=True, compare_layout=True): + def _test_by_ref_graph(self, temp_dir, test_params, ref_graph, compare_tensor_names=True, + compare_layout=True, ovc=False): """ Generates IR using MO Python API, reads it and compares with reference graph. """ from openvino import Core core = Core() - test_params.update({"model_name": 'model_test', "output_dir": temp_dir}) - self.generate_ir_python_api(**test_params) - ir_test = core.read_model(Path(temp_dir, 'model_test.xml')) + if ovc: + ir_file_name = Path(temp_dir, 'model_test.xml') + test_params.update({"output_model": ir_file_name}) + exit_code, stderr = generate_ir_ovc(coverage=False, **test_params) + assert not exit_code, stderr + else: + test_params.update({"model_name": 'model_test', "output_dir": temp_dir}) + ir_file_name = Path(temp_dir, 'model_test.xml') + self.generate_ir_python_api(**test_params) + + ir_test = core.read_model(ir_file_name) + flag, msg = compare_functions(ir_test, ref_graph, compare_tensor_names=compare_tensor_names) assert flag, msg diff --git a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py index 3fbe25ee130e69..7dc40e310330cf 100644 --- a/tests/layer_tests/ovc_python_api_tests/test_pytorch.py +++ b/tests/layer_tests/ovc_python_api_tests/test_pytorch.py @@ -1,14 +1,17 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import os +import tempfile import unittest from typing import Tuple, List import numpy as np -import openvino.runtime as ov import pytest import torch from common.mo_convert_test_class import CommonMOConvertTest + +import openvino.runtime as ov from openvino.runtime import PartialShape, Dimension, Model, Type @@ -1408,3 +1411,167 @@ def test_conversion_params(self, params, ie_device, precision, ir_version, test_params.update({'input_model': fw_model}) self._test_by_ref_graph(temp_dir, test_params, ref_model, compare_tensor_names=False) + + +def pytorch_nn_module_with_enabled_compression(tmp_dir): + import torch + + class NeuralNetwork(torch.nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.y = torch.arange(10, dtype=torch.float16) + + def forward(self, x, z): + return (x + self.y.to(torch.float32)) * z + + param_1 = ov.opset13.parameter([10], dtype=np.float32) + param_2 = ov.opset13.parameter([10], dtype=np.float32) + const_1 = ov.opset13.constant(np.arange(10), dtype=np.float16) + convert_1 = ov.opset13.convert(const_1, np.float32) + add_1 = ov.opset13.add(param_1, convert_1) + mul_1 = ov.opset13.multiply(add_1, param_2) + + ov_model_ref = Model([mul_1], [param_1, param_2], "test") + fw_model = NeuralNetwork() + return fw_model, ov_model_ref, {'input': [([10], np.float32), ([10], np.float32)], + 'example_input': (torch.zeros(10), torch.zeros(10))} + + +def pytorch_nn_module_with_disabled_compression(tmp_dir): + import torch + + class NeuralNetwork(torch.nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.y = torch.arange(10, dtype=torch.float32) + + def forward(self, x, z): + return (x + self.y) * z + + param_1 = ov.opset13.parameter([-1], dtype=np.float32) + param_2 = ov.opset13.parameter([-1], dtype=np.float32) + const_1 = ov.opset13.constant(np.arange(10), dtype=np.float32) + add_1 = ov.opset13.add(param_1, const_1) + mul_1 = ov.opset13.multiply(add_1, param_2) + + ov_model_ref = Model([mul_1], [param_1, param_2], "test") + fw_model = NeuralNetwork() + return fw_model, ov_model_ref, {'example_input': (torch.zeros(10), torch.zeros(10)), + 'compress_to_fp16': 'False'} + + +class TestConvertModelForPyTorchModelOnDisk(CommonMOConvertTest): + test_data = [ + 'create_pytorch_nn_module_case1', + 'create_pytorch_nn_module_case2', + 'create_pytorch_nn_module_case3', + 'create_pytorch_nn_module_sample_input_int32_two_inputs', + 'pytorch_nn_module_with_enabled_compression' + ] + + @pytest.mark.parametrize('create_model', test_data) + @pytest.mark.parametrize('model_format', ['exported_program', 'torch_script']) + @pytest.mark.nightly + @pytest.mark.precommit + def test_convert_model_for_pytorch_model_on_disk(self, create_model, model_format, + ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): + fw_model, graph_ref, ovc_params = eval(create_model)(temp_dir) + + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + if model_format == 'torch_script': + scripted_model = torch.jit.script(fw_model) + scripted_model.save(tmpfile.name) + test_params = {'input_model': tmpfile.name} + if ovc_params is not None: + test_params.update(ovc_params) + else: + example_input = ovc_params['example_input'] + exported_program = torch.export.export(fw_model, example_input) + torch.export.save(exported_program, tmpfile.name) + test_params = {'input_model': tmpfile.name} + if ovc_params is not None: + test_params.update(ovc_params) + + self._test_by_ref_graph(temp_dir, test_params, + graph_ref, compare_tensor_names=False) + os.remove(tmpfile.name) + + +def ovc_case1(tmp_dir): + pt_model = make_pt_model_two_inputs() + ref_model = make_ref_pt_model_two_inputs([1, 3, 10, 10]) + + sample_input1 = torch.zeros(1, 3, 10, 10) + sample_input2 = torch.zeros(1, 3, 10, 10) + sample_input = sample_input1, sample_input2 + + return pt_model, ref_model, {'example_input': sample_input} + + +def pytorch_nn_module_case2(tmp_dir): + pt_model = make_pt_model_two_inputs() + ref_model = make_ref_pt_model_two_inputs([-1, 3, -1, -1]) + + sample_input1 = torch.zeros(1, 3, 10, 10) + sample_input2 = torch.zeros(1, 3, 10, 10) + sample_input = sample_input1, sample_input2 + + return pt_model, ref_model, {'input': '[-1,3,-1,-1],[-1,3,-1,-1]', + 'example_input': sample_input} + + +def nested_dict_input_ovc_case2(tmp_dir): + class PTModel(torch.nn.Module): + def forward(self, a, b): + return a["1"] * a["2"] + b + + net = PTModel() + a1 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32) + a2 = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32) + b = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32) + mul = ov.opset10.multiply(a1, a2) + add = ov.opset10.add(mul, b) + ref_model = Model([add], [a1, a2, b], "test") + example_input = ( + { + "1": torch.tensor([1, 2], dtype=torch.float32), + "2": torch.tensor([3, 4], dtype=torch.float32) + }, + torch.tensor([5, 6], dtype=torch.float32) + ) + return net, ref_model, {'example_input': example_input} + + +class TestOVCForExportedProgramOnDisk(CommonMOConvertTest): + test_data = [ + 'create_pytorch_nn_module_case1', + 'pytorch_nn_module_case2', + 'nested_dict_input_ovc_case2', + 'pytorch_nn_module_with_disabled_compression' + ] + + @pytest.mark.parametrize('create_model', test_data) + @pytest.mark.nightly + @pytest.mark.precommit + def test_ovc_for_exported_program_on_disk(self, create_model, + ie_device, precision, ir_version, + temp_dir, use_legacy_frontend): + fw_model, graph_ref, ovc_params = eval(create_model)(temp_dir) + example_input = ovc_params['example_input'] + del ovc_params['example_input'] + + ep_file_name = None + with tempfile.NamedTemporaryFile(delete=False) as tmpfile: + exported_program = torch.export.export(fw_model, tuple(example_input)) + torch.export.save(exported_program, tmpfile.name) + ep_file_name = tmpfile.name + + test_params = {'input_model': ep_file_name} + if ovc_params is not None: + test_params.update(ovc_params) + + self._test_by_ref_graph(temp_dir, test_params, + graph_ref, compare_tensor_names=False, + ovc=True) + os.remove(ep_file_name) diff --git a/tools/ovc/openvino/tools/ovc/convert_impl.py b/tools/ovc/openvino/tools/ovc/convert_impl.py index 2eb2f2adc133f1..dc0694f0a405b5 100644 --- a/tools/ovc/openvino/tools/ovc/convert_impl.py +++ b/tools/ovc/openvino/tools/ovc/convert_impl.py @@ -12,7 +12,6 @@ from pathlib import Path from typing import Iterable, Callable - try: import openvino_telemetry as tm from openvino_telemetry.backend import backend_ga4 @@ -34,8 +33,10 @@ from openvino.tools.ovc.logger import init_logger from openvino.tools.ovc.telemetry_utils import send_params_info, send_conversion_result, \ init_mo_telemetry -from openvino.tools.ovc.moc_frontend.pytorch_frontend_utils import get_pytorch_decoder, extract_input_info_from_example +from openvino.tools.ovc.moc_frontend.pytorch_frontend_utils import get_pytorch_decoder, \ + extract_input_info_from_example, get_pytorch_decoder_for_model_on_disk from openvino.tools.ovc.moc_frontend.paddle_frontend_utils import paddle_frontend_converter + try: from openvino.tools.ovc.moc_frontend.jax_frontend_utils import get_jax_decoder except: @@ -232,7 +233,7 @@ def check_model_object(argv): paddle.fluid.dygraph.layers.Layer) or isinstance( model, paddle.fluid.executor.Executor): return "paddle" - + if 'jax' in sys.modules: import jax if isinstance(model, (jax.core.Jaxpr, jax.core.ClosedJaxpr)): @@ -475,9 +476,9 @@ def _convert(cli_parser: argparse.ArgumentParser, args, python_api_used): get_jax_decoder(args['input_model'], args) else: raise Error("JAX Frontend is not available.") - argv = pack_params_to_args_namespace(args, cli_parser, python_api_used) + argv.framework = model_framework argv.is_python_object = inp_model_is_object @@ -491,8 +492,22 @@ def _convert(cli_parser: argparse.ArgumentParser, args, python_api_used): argv.framework = model_framework + orig_input_model = argv.input_model + pytorch_model_on_disk = False + if argv.framework is None and get_pytorch_decoder_for_model_on_disk(argv, args): + # try to load a model from disk as TorchScript or ExportedProgram + # TorchScriptPythonDecoder or TorchFXPythonDecoder object will be assigned to argv.input_model + # saved TorchScript and ExportedModel model can be passed to both ovc tool and Python convert_model + pytorch_model_on_disk = True + ov_model = driver(argv, {"conversion_parameters": non_default_params}) + if pytorch_model_on_disk: + # release memory allocated for temporal object + del argv.input_model + # restore original model name in arguments for tool reporting + argv.input_model = orig_input_model + if inp_model_is_object and model_framework == "paddle": if paddle_runtime_converter: paddle_runtime_converter.destroy() diff --git a/tools/ovc/openvino/tools/ovc/help.py b/tools/ovc/openvino/tools/ovc/help.py index 1ef914c0f48143..e09102be39419e 100644 --- a/tools/ovc/openvino/tools/ovc/help.py +++ b/tools/ovc/openvino/tools/ovc/help.py @@ -7,7 +7,8 @@ def get_convert_model_help_specifics(): return { 'input_model': {'description': - 'Input model file(s) from TensorFlow, ONNX, PaddlePaddle. ' + 'Input model file(s) from PyTorch (ExportedProgram saved on a disk), ' + 'TensorFlow, ONNX, PaddlePaddle. ' 'Use openvino.convert_model in Python to convert models from PyTorch.' '', 'action': CanonicalizePathCheckExistenceAction, diff --git a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py index dfe25f27d13d7d..0119a541494cb9 100644 --- a/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py +++ b/tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py @@ -2,14 +2,23 @@ # SPDX-License-Identifier: Apache-2.0 import logging as log +import pathlib import sys import numpy as np + # pylint: disable=no-name-in-module,import-error from openvino.runtime import Tensor, PartialShape -from openvino.tools.ovc.error import Error from openvino.tools.ovc.cli_parser import single_input_to_input_cut_info, _InputCutInfo +from openvino.tools.ovc.error import Error + +def extract_module_extensions(args): + from openvino.frontend.pytorch.module_extension import ModuleExtension + extensions = args.get('extension', []) or [] + if not isinstance(extensions, (list, tuple)): + extensions = [extensions] + return {extension.module: extension for extension in extensions if isinstance(extension, ModuleExtension)} def get_pytorch_decoder(model, example_inputs, args): @@ -21,12 +30,6 @@ def get_pytorch_decoder(model, example_inputs, args): except Exception as e: log.error("PyTorch frontend loading failed") raise e - - def extract_module_extensions(args): - extensions = args.get('extension', []) or [] - if not isinstance(extensions, (list, tuple)): - extensions = [extensions] - return {extension.module: extension for extension in extensions if isinstance(extension, ModuleExtension)} if 'nncf' in sys.modules: is_good_version = True @@ -54,7 +57,7 @@ def extract_module_extensions(args): model = model.run_decompositions(decomp_table=decomp) gm = model.module() log.debug(gm.code) - decoder = TorchFXPythonDecoder(gm) + decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True) else: decoder = TorchScriptPythonDecoder( model, @@ -69,6 +72,57 @@ def extract_module_extensions(args): return args +def get_pytorch_decoder_for_model_on_disk(argv, args): + try: + from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder + from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder + import torch + except: + return False + + example_inputs = None + if 'example_input' in args and args['example_input'] is not None: + example_inputs = args['example_input'] + + if isinstance(argv.input_model, (tuple, list)) and len(argv.input_model) == 1: + input_model = argv.input_model[0] + else: + input_model = argv.input_model + + if isinstance(input_model, (str, pathlib.Path)): + # attempt to load scripted model + try: + inputs = prepare_torch_inputs(example_inputs) + model = torch.jit.load(input_model) + model.eval() + decoder = TorchScriptPythonDecoder( + model, + example_input=inputs, + shared_memory=args.get("share_weights", True), + module_extensions=extract_module_extensions(args)) + argv.input_model = decoder + argv.framework = 'pytorch' + return True + except: + pass + if isinstance(input_model, (str, pathlib.Path)): + # attempt to load exported model + try: + exported_program = torch.export.load(input_model) + if hasattr(torch, "export") and isinstance(exported_program, (torch.export.ExportedProgram)): + from packaging import version + if version.parse(torch.__version__) >= version.parse("2.2"): + exported_program = exported_program.run_decompositions() + gm = exported_program.module() + decoder = TorchFXPythonDecoder(gm, dynamic_shapes=True) + argv.input_model = decoder + argv.framework = 'pytorch' + return True + except: + pass + return False + + def update_list_or_dict(container, name, idx, value): if isinstance(container, dict): if name is None: