From d67400177e8395820f4b40f71dd0d9d5c362c4b0 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Thu, 30 May 2024 12:44:13 +0200
Subject: [PATCH 1/3] Create format.yml

---
 .github/workflows/format.yml | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)
 create mode 100644 .github/workflows/format.yml

diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml
new file mode 100644
index 0000000..dc66e7b
--- /dev/null
+++ b/.github/workflows/format.yml
@@ -0,0 +1,29 @@
+# Ultralytics 🚀 - AGPL-3.0 license
+# Ultralytics Actions https://github.com/ultralytics/actions
+# This workflow automatically formats code and documentation in PRs to official Ultralytics standards
+
+name: Ultralytics Actions
+
+on:
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+    types: [opened, closed, synchronize]
+
+jobs:
+  format:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Run Ultralytics Formatting
+        uses: ultralytics/actions@main
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }} # automatically generated, do not modify
+          python: true # format Python code and docstrings
+          markdown: true # format Markdown
+          prettier: true # format YAML
+          spelling: true # check spelling
+          links: false # check broken links
+          # summary: true # print PR summary with GPT4 (requires 'openai_api_key' or 'openai_azure_api_key' and 'openai_azure_endpoint')
+          # openai_azure_api_key: ${{ secrets.OPENAI_AZURE_API_KEY }}
+          # openai_azure_endpoint: ${{ secrets.OPENAI_AZURE_ENDPOINT }}

From 101c08fcd885bda231b6cb5b587bb79ec0dac68b Mon Sep 17 00:00:00 2001
From: UltralyticsAssistant
Date: Thu, 30 May 2024 10:44:36 +0000
Subject: [PATCH 2/3] Auto-format by https://ultralytics.com/actions

---
 README.md                                    |  23 +-
 .../README.md                                |   7 +-
 examples/input_shape_modification/README.md  |   4 +-
 examples/model_inspect/README.md             |   4 +-
 examples/output_modification/README.md       |   4 +-
 onnxslim/__init__.py                         |   5 +-
 onnxslim/cli/_main.py                        |  19 +-
 onnxslim/core/optimizer.py                   | 137 +--
 onnxslim/core/slim.py                        |  62 +-
 onnxslim/core/symbolic_shape_infer.py        | 861 ++++--------------
 .../exporters/onnx_exporter.py               |  38 +-
 .../graph_pattern/graph_pattern.py           |  91 +-
 .../importers/onnx_importer.py               |  79 +-
 onnxslim/onnx_graphsurgeon/ir/function.py    |  65 +-
 onnxslim/onnx_graphsurgeon/ir/graph.py       | 289 ++----
 onnxslim/onnx_graphsurgeon/ir/node.py        |  31 +-
 onnxslim/onnx_graphsurgeon/ir/tensor.py      |  81 +-
 onnxslim/onnx_graphsurgeon/logger/logger.py  |  25 +-
 onnxslim/onnx_graphsurgeon/util/exception.py |   4 +-
 onnxslim/onnx_graphsurgeon/util/misc.py      |   4 +-
 onnxslim/utils/tabulate.py                   | 295 ++---
 onnxslim/utils/utils.py                      |  66 +-
 tests/test_onnx_nets.py                      |   1 -
 tests/test_onnxslim.py                       |  26 +-
 tests/utils.py                               |  49 +-
 25 files changed, 591 insertions(+), 1679 deletions(-)

diff --git a/README.md b/README.md
index c60bdb4..4e3053e 100644
--- a/README.md
+++ b/README.md
@@ -12,21 +12,24 @@
 OnnxSlim can help you slim your ONNX model, with fewer operators but the same accuracy and better inference speed.
 
 - 🚀 OnnxSlim is merged into [mnn-llm](https://github.com/wangzhaode/mnn-llm), performance increased by 5%
-- 🚀 Rank 1st in the [AICAS 2024 LLM inference optimiztion challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head
-
+- 🚀 Rank 1st in the [AICAS 2024 LLM inference optimization challenge](https://tianchi.aliyun.com/competition/entrance/532170/customize440) held by Arm and T-head
 
 # Installation
+
+## Using Prebuilt
+
 ```bash
 pip install onnxslim
 ```
+
+## Build From Source
+
 ```
 pip install . 
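# Editor's note (an addition, hedged, not part of the original README): for
# development, an editable install is a common alternative to the plain
# source install above:
# pip install -e .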
 ```
-
 # How to use
+
 ```
 onnxslim your_onnx_model slimmed_onnx_model
 ```
@@ -36,12 +39,14 @@ onnxslim your_onnx_model slimmed_onnx_model
 
 For more usage, see onnxslim -h or refer to our [examples](./examples)
 
 # References
-> * [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon)
-> * [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy)
-> * [onnx-simplifier](https://github.com/daquexian/onnx-simplifier)
-> * [tabulate](https://github.com/astanin/python-tabulate)
-> * [onnxruntime](https://github.com/microsoft/onnxruntime)
+
+> - [onnx-graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon)
+> - [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy/polygraphy)
+> - [onnx-simplifier](https://github.com/daquexian/onnx-simplifier)
+> - [tabulate](https://github.com/astanin/python-tabulate)
+> - [onnxruntime](https://github.com/microsoft/onnxruntime)
 
 # Contact
+
-Discord: https://discord.gg/nRw2Fd3VUS
+
+Discord: https://discord.gg/nRw2Fd3VUS\
 QQ Group: 873569894
diff --git a/examples/common_subexpression_elimination/README.md b/examples/common_subexpression_elimination/README.md
index 0ac6641..ee63ebf 100644
--- a/examples/common_subexpression_elimination/README.md
+++ b/examples/common_subexpression_elimination/README.md
@@ -1,16 +1,20 @@
 # Common SubExpression Elimination
 
 ## Introduction
+
 Common Subexpression Elimination (CSE) is a powerful optimization technique commonly employed in compilers to improve the efficiency of code execution. It targets redundant computations within a program by identifying and removing duplicate expressions, thus reducing both computational overhead and memory usage. By eliminating redundant computations, CSE enhances the overall performance of the slimmed ONNX model.
 
 ## How CSE Works
+
 In many programs, certain expressions are computed multiple times within a given scope, even though their results remain constant across these computations. Common subexpressions refer to these redundant expressions. CSE identifies such common subexpressions and replaces subsequent occurrences with references to the original computation result. This process effectively reduces the number of computations required during program execution.
 
 For example, consider the following code snippet:
+
 ```
 int a = b + c;
 int x = b + c;
 ```
+
 In this code, b + c is a common subexpression computed twice. With CSE, the redundant computation of b + c would be eliminated, and x would directly reuse the result already computed for a.
 
 ## Running the example
 
@@ -31,7 +35,6 @@ After onnxslim, the output will look like this:
 
 ![../../image/after_cse.png](../../images/after_cse.png)
 
-
 and the summary is as follows:
 
-![../../image/cse.png](../../images/cse.png)
\ No newline at end of file
+![../../image/cse.png](../../images/cse.png)
diff --git a/examples/input_shape_modification/README.md b/examples/input_shape_modification/README.md
index eb63d31..c2968c4 100644
--- a/examples/input_shape_modification/README.md
+++ b/examples/input_shape_modification/README.md
@@ -1,11 +1,13 @@
 # Input Shape Modification
 
 ## Introduction
+
 OnnxSlim includes an exploration of essential input shape modification techniques for ONNX models. This concise guide unveils techniques for seamlessly adjusting input tensor dimensions, ensuring optimal compatibility and performance within the dynamic landscape of neural network architectures.
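As a rough sketch of what the `--input_shapes` flag demonstrated below boils down to, the snippet pins an input's dimensions with the plain `onnx` Python API. This is a minimal approximation, not OnnxSlim's actual implementation; the file names and the input name `cc` come from the example command below, and the rank-3 shape is an assumption.

```python
# Minimal sketch, assuming a model whose input "cc" should become [1, 1, 768].
# OnnxSlim may additionally re-run shape inference afterwards; that step is omitted here.
import onnx

model = onnx.load("UNetModel-fp16.onnx")
for graph_input in model.graph.input:
    if graph_input.name == "cc":
        dims = graph_input.type.tensor_type.shape.dim
        for dim, size in zip(dims, (1, 1, 768)):
            dim.ClearField("dim_param")  # drop the symbolic dimension, if any
            dim.dim_value = size  # pin the concrete size
onnx.save(model, "slim.onnx")
```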
## Running the example + Change the input model by running: ```bash @@ -14,4 +16,4 @@ onnxslim UNetModel-fp16.onnx slim.onnx --input_shapes cc:1,1,768 The slimmed model will look like this: -![../../image/input_shape_modification.jpg](../../images/input_shape_modification.jpg) \ No newline at end of file +![../../image/input_shape_modification.jpg](../../images/input_shape_modification.jpg) diff --git a/examples/model_inspect/README.md b/examples/model_inspect/README.md index 456d19f..3b07a9e 100644 --- a/examples/model_inspect/README.md +++ b/examples/model_inspect/README.md @@ -1,9 +1,11 @@ # Model Inspect ## Introduction + Dive deep into the intricacies of your ONNX model using the powerful --inspect argument with OnnxSlim. This feature provides detailed insights into various aspects of your model, including input and output details, operator information, opset version, and more. ## Running the example + Unveil the secrets of your ONNX model by executing the following command: ```bash @@ -12,4 +14,4 @@ onnxslim --inspect UNetModel-fp16.onnx The output will look like this: -![../../image/model_inspect.jpg](../../images/model_inspect.jpg) \ No newline at end of file +![../../image/model_inspect.jpg](../../images/model_inspect.jpg) diff --git a/examples/output_modification/README.md b/examples/output_modification/README.md index 792d5c6..cac7f2c 100644 --- a/examples/output_modification/README.md +++ b/examples/output_modification/README.md @@ -1,11 +1,13 @@ # Output Modification ## Introduction + OnnxSlim provides capabilities for modifying the output specifications of ONNX models. This section explores techniques to customize the outputs, allowing for flexibility in handling diverse model requirements. ## Running the example + Change the output of one model by running: ```bash @@ -14,4 +16,4 @@ onnxslim yolov5m.onnx slim.onnx --outputs 591 739 443 The slimmed model will look like this: -![../../image/output_modification.jpg](../../images/output_modification.jpg) \ No newline at end of file +![../../image/output_modification.jpg](../../images/output_modification.jpg) diff --git a/onnxslim/__init__.py b/onnxslim/__init__.py index 56dc54c..0520295 100644 --- a/onnxslim/__init__.py +++ b/onnxslim/__init__.py @@ -5,10 +5,7 @@ from .core.optimizer import DEFAULT_FUSION_PATTERNS from .version import __version__ - -if os.path.dirname(os.path.realpath(__file__)) == os.path.join( - os.path.realpath(os.getcwd()), "onnxslim" -): +if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"): message = ( "You are importing onnxslim within its own root folder ({}). " "This is not expected to work and may give errors. 
Please exit the " diff --git a/onnxslim/cli/_main.py b/onnxslim/cli/_main.py index cc0ba3d..432c226 100644 --- a/onnxslim/cli/_main.py +++ b/onnxslim/cli/_main.py @@ -1,6 +1,7 @@ from typing import Union import onnx + from onnxslim.utils.utils import logger @@ -83,11 +84,7 @@ def slim( init_logging(verbose) - MAX_ITER = ( - 10 - if not os.getenv("ONNXSLIM_MAX_ITER") - else int(os.getenv("ONNXSLIM_MAX_ITER")) - ) + MAX_ITER = 10 if not os.getenv("ONNXSLIM_MAX_ITER") else int(os.getenv("ONNXSLIM_MAX_ITER")) if isinstance(model, str): model_name = Path(model).name @@ -175,14 +172,10 @@ def main(): formatter_class=argparse.RawDescriptionHelpFormatter, ) parser.add_argument("input_model", help="input onnx model") - parser.add_argument( - "output_model", nargs="?", default=None, help="output onnx model" - ) + parser.add_argument("output_model", nargs="?", default=None, help="output onnx model") parser.add_argument("--model_check", action="store_true", help="enable model check") - parser.add_argument( - "-v", "--version", action="version", version=onnxslim.__version__ - ) + parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__) # Input Shape Modification parser.add_argument( @@ -259,9 +252,7 @@ def main(): ) # Verbose - parser.add_argument( - "--verbose", action="store_true", help="verbose mode, default False." - ) + parser.add_argument("--verbose", action="store_true", help="verbose mode, default False.") args, unknown = parser.parse_known_args() diff --git a/onnxslim/core/optimizer.py b/onnxslim/core/optimizer.py index f57f1ba..2721588 100644 --- a/onnxslim/core/optimizer.py +++ b/onnxslim/core/optimizer.py @@ -1,18 +1,15 @@ import contextlib from collections import Counter, OrderedDict - from typing import List, Union import numpy as np - import onnx -from onnxslim.utils.utils import logger import onnxslim.onnx_graphsurgeon as gs from onnxslim.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx from onnxslim.onnx_graphsurgeon.ir.graph import Graph from onnxslim.onnx_graphsurgeon.ir.tensor import Constant, Variable - +from onnxslim.utils.utils import logger DEFAULT_FUSION_PATTERNS = OrderedDict() @@ -106,9 +103,7 @@ def graph_constant_fold_inplace(graph): elif node.op == "Pad": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): pad_value = node.inputs[1].values.tolist() - pad_value = ( - [pad_value] if not isinstance(pad_value, list) else pad_value - ) + pad_value = [pad_value] if not isinstance(pad_value, list) else pad_value if all([value == 0 for value in pad_value]): delete_node(node) elif node.op == "Cast": @@ -123,10 +118,7 @@ def graph_constant_fold_inplace(graph): else: node_output_shape = node.outputs[0].shape if node_output_shape and check_shape(node_output_shape): - shapes = [ - shape if isinstance(shape, int) else -1 - for shape in node_output_shape - ] + shapes = [shape if isinstance(shape, int) else -1 for shape in node_output_shape] reshape_const = gs.Constant( node.inputs[1].name + "_", values=np.array(shapes, dtype=np.int64), @@ -134,24 +126,16 @@ def graph_constant_fold_inplace(graph): node.inputs.pop(1) node.inputs.insert(1, reshape_const) elif node.op == "Mul": - if ( - isinstance(node.inputs[1], Constant) - and isinstance(node.inputs[0], Variable) - ) or ( - isinstance(node.inputs[0], Constant) - and isinstance(node.inputs[1], Variable) + if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( + isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) ): idx, 
constant_variable = get_constant_variable(node, return_idx=True) if np.all(constant_variable.values == 1): var_idx = 0 if idx == 1 else 1 delete_node(node, var_idx) elif node.op == "Add": - if ( - isinstance(node.inputs[1], Constant) - and isinstance(node.inputs[0], Variable) - ) or ( - isinstance(node.inputs[0], Constant) - and isinstance(node.inputs[1], Variable) + if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( + isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable) ): idx, constant_variable = get_constant_variable(node, return_idx=True) if np.all(constant_variable.values == 0): @@ -201,10 +185,7 @@ def find_conv_nodes(node, opset): len_conv_pads = int(len(conv_pads) / 2) len_pads = int(len(pad_value) / 2) - pads = ( - pad_value[len_pads - len_conv_pads : len_pads] - + pad_value[len_pads + len_conv_pads :] - ) + pads = pad_value[len_pads - len_conv_pads : len_pads] + pad_value[len_pads + len_conv_pads :] pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] attrs["pads"] = pads @@ -227,13 +208,7 @@ def find_conv_nodes(node, opset): @register_fusion_pattern("FusionConvBN") def find_conv_transpose_nodes(node, opset): # fmt: off - ''' - x - | - Conv/ConvTranspose - | - BatchNormalization - ''' + """X | Conv/ConvTranspose | BatchNormalization.""" # fmt: on match = {} if node.op == "BatchNormalization": @@ -260,12 +235,8 @@ def find_conv_transpose_nodes(node, opset): shape[0] = -1 else: shape[1] = -1 - conv_w = conv_transpose_weight * (bn_scale * bn_var_rsqrt).reshape( - shape - ) - conv_b = ( - conv_transpose_bias - bn_running_mean - ) * bn_var_rsqrt * bn_scale + bn_bias + conv_w = conv_transpose_weight * (bn_scale * bn_var_rsqrt).reshape(shape) + conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt * bn_scale + bn_bias inputs = [] inputs.append(list(conv_transpose_node.inputs)[0]) @@ -314,19 +285,11 @@ def find_slice_nodes(node, opset): if node.i(0).op == "Slice": first_slice_node = node.i(0) first_slice_node_inputs = list(first_slice_node.inputs) - if all( - [isinstance(input, Constant) for input in first_slice_node_inputs[1:]] - ): + if all([isinstance(input, Constant) for input in first_slice_node_inputs[1:]]): first_slice_node_users = get_node_users(first_slice_node) if all( [ - user.op == "Slice" - and all( - [ - isinstance(input, Constant) - for input in list(user.inputs)[1:] - ] - ) + user.op == "Slice" and all([isinstance(input, Constant) for input in list(user.inputs)[1:]]) for user in first_slice_node_users ] ): @@ -338,18 +301,10 @@ def find_slice_nodes(node, opset): for user_node in first_slice_node_users: second_slice_node = user_node second_slice_node_inputs = list(second_slice_node.inputs) - second_slice_node_starts = second_slice_node_inputs[ - 1 - ].values.tolist() - second_slice_node_ends = second_slice_node_inputs[ - 2 - ].values.tolist() - second_slice_node_axes = second_slice_node_inputs[ - 3 - ].values.tolist() - second_slice_node_steps = second_slice_node_inputs[ - 4 - ].values.tolist() + second_slice_node_starts = second_slice_node_inputs[1].values.tolist() + second_slice_node_ends = second_slice_node_inputs[2].values.tolist() + second_slice_node_axes = second_slice_node_inputs[3].values.tolist() + second_slice_node_steps = second_slice_node_inputs[4].values.tolist() new_starts = first_slice_node_starts + second_slice_node_starts new_ends = first_slice_node_ends + second_slice_node_ends @@ -454,9 +409,7 @@ def check_constant_mergeable(reshape_node): return False return True - if 
check_constant_mergeable( - first_reshape_node - ) and check_constant_mergeable(second_reshape_node): + if check_constant_mergeable(first_reshape_node) and check_constant_mergeable(second_reshape_node): inputs = [] inputs.append(first_reshape_node_inputs[0]) inputs.append(second_reshape_node.inputs[1]) @@ -509,9 +462,7 @@ def find_slice_nodes(node, opset): ) last_node.inputs.pop(3) last_node.inputs.insert(3, slice_axis) - previous_transpose_node_variable = previous_transpose_node.outputs[ - 0 - ] # pad output variable + previous_transpose_node_variable = previous_transpose_node.outputs[0] # pad output variable previous_transpose_node_variable.outputs.remove(last_node) last_node.inputs.insert(0, previous_transpose_node.inputs[0]) for node in previous_nodes: @@ -539,14 +490,10 @@ def find_matmul_add_nodes(node, opset): if (isinstance(node.inputs[1], Constant) and node.i(0).op == "MatMul") or ( isinstance(node.inputs[0], Constant) and node.i(1).op == "MatMul" ): - matmul_node = ( - node.i(0) if isinstance(node.inputs[1], Constant) else node.i(1) - ) + matmul_node = node.i(0) if isinstance(node.inputs[1], Constant) else node.i(1) matmul_bias_variable = get_constant_variable(matmul_node) input_variable = ( - matmul_node.inputs[0] - if isinstance(matmul_node.inputs[1], Constant) - else matmul_node.inputs[1] + matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], Constant) else matmul_node.inputs[1] ) users = get_node_users(matmul_node) if len(users) == 1 and matmul_bias_variable: @@ -557,9 +504,7 @@ def find_matmul_add_nodes(node, opset): ): pre_reshape_const = gs.Constant( matmul_node.name + "_pre_reshape_in", - values=np.array( - [-1, matmul_bias_variable.values.shape[0]], dtype=np.int64 - ), + values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64), ) inputs = [] inputs.append(input_variable) @@ -573,8 +518,7 @@ def find_matmul_add_nodes(node, opset): match.update( { - matmul_node.name - + "_pre_reshape": { + matmul_node.name + "_pre_reshape": { "op": "Reshape", "inputs": inputs, "outputs": outputs, @@ -599,9 +543,7 @@ def find_matmul_add_nodes(node, opset): inputs.append(matmul_bias_transpose_constant) inputs.append(add_bias_variable) - gemm_out_variable = gs.Variable( - matmul_node.name + "_gemm_out", dtype=output_variable.dtype - ) + gemm_out_variable = gs.Variable(matmul_node.name + "_gemm_out", dtype=output_variable.dtype) outputs = [gemm_out_variable] match.update( @@ -622,9 +564,7 @@ def find_matmul_add_nodes(node, opset): } ) - values = input_variable.shape[:-1] + [ - matmul_bias_variable.values.shape[-1] - ] + values = input_variable.shape[:-1] + [matmul_bias_variable.values.shape[-1]] post_reshape_const = gs.Constant( matmul_node.name + "_post_reshape_in", values=np.array(values, dtype=np.int64), @@ -641,8 +581,7 @@ def find_matmul_add_nodes(node, opset): match.update( { - matmul_node.name - + "_post_reshape": { + matmul_node.name + "_post_reshape": { "op": "Reshape", "inputs": inputs, "outputs": outputs, @@ -829,13 +768,7 @@ def find_matches(graph: Graph, fusion_patterns: dict): if "op" not in match: match.update({"op": layer_type}) if "name" not in match: - match.update( - { - "name": "{}_{}".format( - layer_type.lower(), counter[layer_type] - ) - } - ) + match.update({"name": "{}_{}".format(layer_type.lower(), counter[layer_type])}) counter.update([layer_type]) match_map.update(matches) @@ -880,19 +813,13 @@ def replace_node_references(existing_node, to_be_removed_node): if keep_nodes[i]: for j in range(i + 1, len(bucketed_nodes)): if keep_nodes[j]: - 
logger.debug( - f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}" - ) + logger.debug(f"node.op {bucketed_nodes[0].op} idx i: {i}, idx j: {j}") if can_be_replaced(node, bucketed_nodes[j]): keep_nodes[j] = False existing_node = node to_be_removed_node = bucketed_nodes[j] - replace_node_references( - existing_node, to_be_removed_node - ) - logger.debug( - f"Node {to_be_removed_node.name} can be replaced by {existing_node.name}" - ) + replace_node_references(existing_node, to_be_removed_node) + logger.debug(f"Node {to_be_removed_node.name} can be replaced by {existing_node.name}") def sequences_equal(seq1, seq2): @@ -927,9 +854,7 @@ def subexpression_elimination(graph): find_and_remove_replaceable_nodes(nodes) -def optimize_model( - model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None -) -> onnx.ModelProto: +def optimize_model(model: Union[onnx.ModelProto, gs.Graph], skip_fusion_patterns: str = None) -> onnx.ModelProto: if isinstance(model, gs.Graph): graph = model else: diff --git a/onnxslim/core/slim.py b/onnxslim/core/slim.py index 3b0ef96..472ad99 100644 --- a/onnxslim/core/slim.py +++ b/onnxslim/core/slim.py @@ -1,5 +1,4 @@ import logging - import os import sys import tempfile @@ -16,20 +15,15 @@ from ..utils.utils import ( dump_model_info_to_disk, gen_onnxruntime_input_data, + logger, onnxruntime_inference, print_model_info_as_table, - logger ) - from .optimizer import delete_node, optimize_model from .symbolic_shape_infer import SymbolicShapeInference DEBUG = bool(os.getenv("ONNXSLIM_DEBUG")) -AUTO_MERGE = ( - True - if os.getenv("ONNXSLIM_AUTO_MERGE") is None - else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) -) +AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE"))) def init_logging(verbose=False): @@ -38,14 +32,18 @@ def init_logging(verbose=False): logging.root.removeHandler(handler) if verbose: # DEBUG - logging.basicConfig(level=logging.DEBUG, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler(sys.stderr)]) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stderr)], + ) G_LOGGER.severity = logging.DEBUG else: # ERROR - logging.basicConfig(level=logging.ERROR, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - handlers=[logging.StreamHandler(sys.stderr)]) + logging.basicConfig( + level=logging.ERROR, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stderr)], + ) G_LOGGER.severity = logging.ERROR G_LOGGER.colors = False @@ -77,9 +75,7 @@ def summarize_model(model: onnx.ModelProto) -> Dict: op_type_counts = {} def get_tensor_dtype_shape(tensor): - type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get( - tensor.type.tensor_type.elem_type, "Unknown" - ) + type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown") shape = None if tensor.type.tensor_type.HasField("shape"): shape = [] @@ -104,9 +100,7 @@ def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]: return op_shape_info - value_info_dict = { - value_info.name: value_info for value_info in model.graph.value_info - } + value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info} for node in model.graph.node: op_type = node.op_type @@ -148,9 +142,7 @@ def model_save_as_external_data(model: onnx.ModelProto, model_path: str): ) -def input_shape_modification( - model: 
onnx.ModelProto, input_shapes: str -) -> onnx.ModelProto: +def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto: if not input_shapes: return @@ -162,9 +154,7 @@ def input_shape_modification( key, values = input_shape.rsplit(":", 1) values_list = [int(value) for value in values.split(",")] if key not in input_names: - raise Exception( - f"Input name {key} not found in model, available keys: {' '.join(input_names)}" - ) + raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}") tensors[key].shape = values_list for _, tensor in tensors.items(): @@ -187,15 +177,11 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto if len(values) == 1: key = values[0] if key not in tensors.keys(): - raise Exception( - f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}" - ) + raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}") dtype = tensors[key].dtype if dtype == None: dtype = np.float32 - logger.warning( - f"Output layer {key} has no dtype, set to default {dtype}" - ) + logger.warning(f"Output layer {key} has no dtype, set to default {dtype}") else: key, dtype = values if dtype == "fp16": @@ -207,13 +193,9 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto elif dtype == "bool": dtype = bool else: - raise Exception( - f"Output layer {key} assigned unsupported dtype {dtype}" - ) + raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}") - graph.outputs.append( - tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape) - ) + graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape)) graph.cleanup(remove_unused_graph_inputs=True).toposort() model = gs.export_onnx(graph) @@ -341,9 +323,7 @@ def check_result(raw_onnx_output, slimmed_onnx_output): if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()): logger.warning("Model output mismatch after slimming.") logger.warning("Raw model output keys: {}".format(raw_onnx_output.keys())) - logger.warning( - "Slimmed model output keys: {}".format(slimmed_onnx_output.keys()) - ) + logger.warning("Slimmed model output keys: {}".format(slimmed_onnx_output.keys())) logger.warning("Please check the model carefully.") return else: diff --git a/onnxslim/core/symbolic_shape_infer.py b/onnxslim/core/symbolic_shape_infer.py index 1cf79ff..e80a278 100644 --- a/onnxslim/core/symbolic_shape_infer.py +++ b/onnxslim/core/symbolic_shape_infer.py @@ -24,11 +24,7 @@ def get_attribute(node, attr_name, default_value=None): def get_dim_from_proto(dim): - return ( - getattr(dim, dim.WhichOneof("value")) - if type(dim.WhichOneof("value")) is str - else None - ) # noqa: E721 + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None # noqa: E721 def is_sequence(type_proto): @@ -72,16 +68,11 @@ def make_named_value_info(name): def get_shape_from_sympy_shape(sympy_shape): - return [ - None if i is None else (int(i) if is_literal(i) else str(i)) - for i in sympy_shape - ] + return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape] def is_literal(dim): - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or ( - hasattr(dim, "is_number") and dim.is_number - ) + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) def handle_negative_axis(axis, rank): @@ -263,12 +254,7 @@ def 
__init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): self.prefix_ = prefix def _add_suggested_merge(self, symbols, apply=False): - assert all( - [ - (type(s) == str and s in self.symbolic_dims_) or is_literal(s) - for s in symbols - ] - ) # noqa: E721 + assert all([(type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols]) # noqa: E721 symbols = set(symbols) for k, v in self.suggested_merge_.items(): if k in symbols: @@ -294,11 +280,7 @@ def _add_suggested_merge(self, symbols, apply=False): # when nothing to map to, use the shorter one if map_to is None: if self.verbose_ > 0: - logger.warning( - "Potential unsafe merge between symbolic expressions: ({})".format( - ",".join(symbols) - ) - ) + logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) symbols_list = list(symbols) lens = [len(s) for s in symbols_list] map_to = symbols_list[lens.index(min(lens))] @@ -319,9 +301,7 @@ def _add_suggested_merge(self, symbols, apply=False): def _apply_suggested_merge(self, graph_input_only=False): if not self.suggested_merge_: return - for i in list(self.out_mp_.graph.input) + ( - [] if graph_input_only else list(self.out_mp_.graph.value_info) - ): + for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)): for d in i.type.tensor_type.shape.dim: if d.dim_param in self.suggested_merge_: v = self.suggested_merge_[d.dim_param] @@ -348,9 +328,7 @@ def _merge_symbols(self, dims): if self.auto_merge_: unique_dims = list(set(dims)) is_int = [is_literal(d) for d in unique_dims] - assert ( - sum(is_int) <= 1 - ) # if there are more than 1 unique ints, something is wrong + assert sum(is_int) <= 1 # if there are more than 1 unique ints, something is wrong if sum(is_int) == 1: int_dim = is_int.index(1) if self.verbose_ > 0: @@ -364,17 +342,13 @@ def _merge_symbols(self, dims): return unique_dims[int_dim] else: if self.verbose_ > 0: - logger.debug( - f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}" - ) + logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}") return dims[0] else: return None if all([d == dims[0] for d in dims]): return dims[0] - merged = [ - self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims - ] + merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] if all([d == merged[0] for d in merged]): assert merged[0] in self.symbolic_dims_ return merged[0] @@ -403,12 +377,7 @@ def _broadcast_shapes(self, shape1, shape2): if self.auto_merge_: self._add_suggested_merge([dim1, dim2], apply=True) else: - logger.warning( - "unsupported broadcast between " - + str(dim1) - + " " - + str(dim2) - ) + logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) new_shape = [new_dim, *new_shape] return new_shape @@ -452,11 +421,7 @@ def _get_sympy_shape(self, node, idx): def _get_value(self, node, idx): name = node.input[idx] assert name in self.sympy_data_ or name in self.initializers_ - return ( - self.sympy_data_[name] - if name in self.sympy_data_ - else numpy_helper.to_array(self.initializers_[name]) - ) + return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name]) def _try_get_value(self, node, idx): if idx >= len(node.input): @@ -473,9 +438,7 @@ def _update_computed_dims(self, new_sympy_shape): if str_dim in self.suggested_merge_: if is_literal(self.suggested_merge_[str_dim]): continue # no need to 
create dim for literals - new_sympy_shape[i] = self.symbolic_dims_[ - self.suggested_merge_[str_dim] - ] + new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]] else: # add new_dim if it's a computational expression if str(new_dim) not in self.symbolic_dims_: @@ -545,23 +508,11 @@ def _onnx_infer_single_node(self, node): if node.output[0] in self.known_vi_: vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) - in_shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] + in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] for d in range( - out_rank - - ( - 2 - if node.op_type - in ["MatMul", "MatMulInteger", "MatMulInteger16"] - else 0 - ) + out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0) ): - in_dims = [ - s[len(s) - out_rank + d] - for s in in_shapes - if len(s) + d >= out_rank - ] + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -587,36 +538,22 @@ def _onnx_infer_single_node(self, node): vi.name = o self.known_vi_[o] = vi - def _onnx_infer_subgraph( - self, node, subgraph, use_node_input=True, inc_subgraph_id=True - ): + def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): if self.verbose_ > 2: - logger.debug( - f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}" - ) + logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}") # node inputs are not passed directly to the subgraph # it's up to the node dispatcher to prepare subgraph input # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape # besides, inputs in subgraph could shadow implicit inputs - subgraph_inputs = { - i.name for i in list(subgraph.initializer) + list(subgraph.input) - } - subgraph_implicit_input = { - name for name in self.known_vi_ if name not in subgraph_inputs - } + subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)} + subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs} tmp_graph = helper.make_graph( list(subgraph.node), "tmp", list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], [make_named_value_info(i.name) for i in subgraph.output], ) - tmp_graph.initializer.extend( - [ - i - for i in self.out_mp_.graph.initializer - if i.name in subgraph_implicit_input - ] - ) + tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input]) tmp_graph.initializer.extend(subgraph.initializer) self.tmp_mp_.graph.CopyFrom(tmp_graph) @@ -638,9 +575,7 @@ def _onnx_infer_subgraph( if use_node_input: # if subgraph uses node input, it needs to update to merged dims subgraph.ClearField("input") - subgraph.input.extend( - symbolic_shape_inference.out_mp_.graph.input[: len(node.input)] - ) + subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)]) subgraph.ClearField("output") subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) subgraph.ClearField("value_info") @@ -648,10 +583,7 @@ def _onnx_infer_subgraph( subgraph.ClearField("node") subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) # for new symbolic dims from subgraph output, add to main graph symbolic dims - subgraph_shapes = [ - get_shape_from_value_info(o) - for o in 
symbolic_shape_inference.out_mp_.graph.output - ] + subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output] subgraph_new_symbolic_dims = { d for s in subgraph_shapes @@ -707,12 +639,10 @@ def _compute_on_sympy_data(self, node, op_func): assert len(node.output) == 1 # Before mul & div operations - # cast inputs into interger might lose decimal part and reduce precision + # cast inputs into integer might lose decimal part and reduce precision # keep them as float, finish the operation, then cast the result into integer if node.op_type in ["Mul", "Div"]: - values = self._get_int_or_float_values( - node, broadcast=True, allow_float_values=True - ) + values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True) else: values = self._get_int_or_float_values(node, broadcast=True) @@ -764,9 +694,7 @@ def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): ) def _new_symbolic_shape(self, rank, node, out_idx=0): - return [ - self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank) - ] + return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] def _compute_conv_pool_shape(self, node, channels_last=False): sympy_shape = self._get_sympy_shape(node, 0) @@ -783,9 +711,7 @@ def _compute_conv_pool_shape(self, node, channels_last=False): assert len(sympy_shape) == rank + 2 # only need to symbolic shape inference if input has symbolic dims in spatial axes - spatial_shape = ( - sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] - ) + spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:] is_symbolic_dims = [not is_literal(i) for i in spatial_shape] if not any(is_symbolic_dims): @@ -793,34 +719,26 @@ def _compute_conv_pool_shape(self, node, channels_last=False): if len(shape) > 0: assert len(sympy_shape) == len(shape) if channels_last: - sympy_shape[-rank - 1 : -1] = [ - sympy.Integer(d) for d in shape[-rank - 1 : -1] - ] + sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]] else: sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] return sympy_shape dilations = get_attribute(node, "dilations", [1] * rank) strides = get_attribute(node, "strides", [1] * rank) - effective_kernel_shape = [ - (k - 1) * d + 1 for k, d in zip(kernel_shape, dilations) - ] + effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] pads = get_attribute(node, "pads") if pads is None: pads = [0] * (2 * rank) auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") if auto_pad != "VALID" and auto_pad != "NOTSET": try: - residual = [ - sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides) - ] + residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] total_pads = [ max(0, (k - s) if r == 0 else (k - r)) for k, s, r in zip(effective_kernel_shape, strides, residual) ] - except ( - TypeError - ): # sympy may throw TypeError: cannot determine truth value of Relational + except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational total_pads = [ max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) ] # assuming no residual if sympy throws error @@ -842,12 +760,8 @@ def _compute_conv_pool_shape(self, node, channels_last=False): (effective_input_size - effective_kernel_shape[i]) / strides[i] ) else: - strided_kernel_positions = ( - effective_input_size - effective_kernel_shape[i] - ) // strides[i] - sympy_shape[-rank + i + (-1 if 
channels_last else 0)] = ( - strided_kernel_positions + 1 - ) + strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] + sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1 return sympy_shape def _check_merged_dims(self, dims, allow_broadcast=True): @@ -889,23 +803,15 @@ def _compute_matmul_shape(self, node, output_dtype=None): # infer output_dtype from input type when not specified output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - """ - update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches - """ + """Update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches.""" dst_tensor_type = ( - dst_type.sequence_type.elem_type.tensor_type - if is_sequence(dst_type) - else dst_type.tensor_type + dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type ) src_tensor_type = ( - src_type.sequence_type.elem_type.tensor_type - if is_sequence(src_type) - else src_type.tensor_type + src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type ) if dst_tensor_type.elem_type != src_tensor_type.elem_type: node_id = node.name if node.name else node.op_type @@ -915,17 +821,13 @@ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" ) if dst_tensor_type.HasField("shape"): - for di, ds in enumerate( - zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim) - ): + for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): if ds[0] != ds[1]: # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type # for sequence_type, clear the dimension new_dim = onnx.TensorShapeProto.Dimension() if not is_sequence(dst_type): - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, out_idx, di) - ) + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di)) dst_tensor_type.shape.dim[di].CopyFrom(new_dim) else: dst_tensor_type.CopyFrom(src_tensor_type) @@ -955,24 +857,14 @@ def _infer_symbolic_compute_ops(self, node): "Max": lambda l: ( l[1] # noqa: E741 if is_literal(l[0]) and int(l[0]) < -self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) < -self.int_max_ - else sympy.Max(l[0], l[1]) - ) + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) ), "Min": lambda l: ( l[1] # noqa: E741 if is_literal(l[0]) and int(l[0]) > self.int_max_ - else ( - l[0] - if is_literal(l[1]) and int(l[1]) > self.int_max_ - else sympy.Min(l[0], l[1]) - ) + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) ), - "Mul": lambda l: ( - int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1] - ), # noqa: E741 + "Mul": lambda l: (int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1]), # noqa: E741 "Sub": lambda l: l[0] - l[1], # noqa: E741 "Where": lambda l: l[1] if l[0] else l[2], # noqa: E741 "Neg": lambda l: -l[0], # noqa: E741 @@ -990,11 +882,7 @@ def _infer_CategoryMapper(self, node): # noqa: N802 else: output_type = onnx.TensorProto.STRING vi = 
self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_type, self._get_shape(node, 0) - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0))) def _infer_Compress(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) @@ -1040,11 +928,7 @@ def _infer_Concat(self, node): # noqa: N802 for d in range(len(sympy_shape)): if d == axis: continue - dims = [ - self._get_shape(node, i_idx)[d] - for i_idx in range(len(node.input)) - if self._get_shape(node, i_idx) - ] + dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)] if all([d == dims[0] for d in dims]): continue merged = self._merge_symbols(dims) @@ -1064,9 +948,7 @@ def _infer_Concat(self, node): # noqa: N802 def _infer_ConcatFromSequence(self, node): # noqa: N802 seq_shape = self._get_shape(node, 0) new_axis = 1 if get_attribute(node, "new_axis") else 0 - axis = handle_negative_axis( - get_attribute(node, "axis"), len(seq_shape) + new_axis - ) + axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) new_shape = seq_shape if new_axis: @@ -1077,9 +959,7 @@ def _infer_ConcatFromSequence(self, node): # noqa: N802 vi.CopyFrom( helper.make_tensor_value_info( node.output[0], - self.known_vi_[ - node.input[0] - ].type.sequence_type.elem_type.tensor_type.elem_type, + self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, new_shape, ) ) @@ -1096,9 +976,7 @@ def _infer_ConstantOfShape(self, node): # noqa: N802 sympy_shape = [sympy_shape] self._update_computed_dims(sympy_shape) # update sympy data if output type is int, and shape is known - if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all( - [is_literal(x) for x in sympy_shape] - ): + if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): self.sympy_data_[node.output[0]] = np.ones( [int(x) for x in sympy_shape], dtype=np.int64 ) * numpy_helper.to_array(get_attribute(node, "value", 0)) @@ -1147,9 +1025,7 @@ def _infer_DequantizeLinear(self, node): # noqa: N802 output_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_QuantizeLinear(self, node): # noqa: N802 # Get the output data type from the zero-point input (index 2, optional). 
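# (Editor's note, hedged) Per the ONNX operator spec, QuantizeLinear's output
# element type follows the optional y_zero_point input when it is present
# (e.g. INT8 or UINT8); when that input is absent, the spec defaults the
# output to UINT8, which is the fallback this type inference relies on.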
@@ -1162,9 +1038,7 @@ def _infer_QuantizeLinear(self, node): # noqa: N802 output_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_Einsum(self, node): # noqa: N802 # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 @@ -1226,9 +1100,7 @@ def _infer_Einsum(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape)) def _infer_Expand(self, node): # noqa: N802 expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True) @@ -1236,9 +1108,7 @@ def _infer_Expand(self, node): # noqa: N802 # new_shape's dim can come from shape value self._update_computed_dims(expand_to_shape) shape = self._get_shape(node, 0) - new_shape = self._broadcast_shapes( - shape, get_shape_from_sympy_shape(expand_to_shape) - ) + new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape)) vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( @@ -1261,11 +1131,7 @@ def _infer_Gather(self, node): # noqa: N802 ) ) # for 1D input, do some sympy compute - if ( - node.input[0] in self.sympy_data_ - and len(data_shape) == 1 - and get_attribute(node, "axis", 0) == 0 - ): + if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0: idx = self._try_get_value(node, 1) if idx is not None: data = self.sympy_data_[node.input[0]] @@ -1320,32 +1186,24 @@ def _infer_If(self, node): # noqa: N802 subgraphs[0].CopyFrom(subgraphs[1]) for i_sub, subgraph in enumerate(subgraphs): - subgraph_infer = self._onnx_infer_subgraph( - node, subgraph, use_node_input=False - ) + subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False) for i_out in range(len(node.output)): vi = self.known_vi_[node.output[i_out]] if i_sub == 0: vi.CopyFrom(subgraph.output[i_out]) vi.name = node.output[i_out] else: - self._fuse_tensor_type( - node, i_out, vi.type, subgraph.output[i_out].type - ) + self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type) # pass on sympy data from subgraph, if cond is constant if cond is not None and i_sub == (0 if as_scalar(cond) > 0 else 1): if subgraph.output[i_out].name in subgraph_infer.sympy_data_: - self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[ - subgraph.output[i_out].name - ] + self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name] def _infer_Loop(self, node): # noqa: N802 subgraph = get_attribute(node, "body") assert len(subgraph.input) == len(node.input) - num_loop_carried = ( - len(node.input) - 2 - ) # minus the length and initial loop condition + num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition # when sequence_type is used as loop carried input # needs to run subgraph infer twice if the tensor shape in sequence contains None for i, si in enumerate(subgraph.input): @@ -1367,9 +1225,7 @@ def _infer_Loop(self, node): # noqa: N802 # copy shape from output to input # note that loop input is [loop_len, cond, input_0, input_1, ...] # while loop output is [cond, output_0, output_1, ...] 
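# (Editor's note, hedged) An ONNX Loop body graph takes
# (iteration_num, cond, loop_carried...) as inputs and returns
# (cond, loop_carried..., scan_outputs...), so body output i_out pairs with
# body input i_out + 1; that is why the code below indexes i_out + 1.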
- subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom( - so.type.sequence_type.elem_type - ) + subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type) need_second_infer = True else: si = subgraph.input[i_out + 1] @@ -1377,9 +1233,7 @@ def _infer_Loop(self, node): # noqa: N802 for di, dims in enumerate(zip(si_shape, so_shape)): if dims[0] != dims[1]: new_dim = onnx.TensorShapeProto.Dimension() - new_dim.dim_param = str( - self._new_symbolic_dim_from_output(node, i_out, di) - ) + new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di)) si.type.tensor_type.shape.dim[di].CopyFrom(new_dim) so.type.tensor_type.shape.dim[di].CopyFrom(new_dim) need_second_infer = True @@ -1397,13 +1251,9 @@ def _infer_Loop(self, node): # noqa: N802 loop_iter_dim = str(self._new_symbolic_dim_from_output(node)) for i in range(len(node.output)): vi = self.known_vi_[node.output[i]] - vi.CopyFrom( - subgraph.output[i + 1] - ) # first subgraph output is condition, not in node output + vi.CopyFrom(subgraph.output[i + 1]) # first subgraph output is condition, not in node output if i >= num_loop_carried: - assert not is_sequence( - vi.type - ) # TODO: handle loop accumulation in sequence_type + assert not is_sequence(vi.type) # TODO: handle loop accumulation in sequence_type subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim vi.type.tensor_type.shape.ClearField("dim") vi_dim = vi.type.tensor_type.shape.dim @@ -1420,22 +1270,14 @@ def _infer_MatMulInteger(self, node): # noqa: N802 def _infer_NonMaxSuppression(self, node): # noqa: N802 selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, [selected, 3] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3])) def _infer_NonZero(self, node): # noqa: N802 input_rank = self._get_shape_rank(node, 0) # create a new symbolic dimension for NonZero output nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1)) vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len])) def _infer_OneHot(self, node): # noqa: N802 sympy_shape = self._get_sympy_shape(node, 0) @@ -1444,13 +1286,7 @@ def _infer_OneHot(self, node): # noqa: N802 axis = handle_negative_axis(axis, len(sympy_shape) + 1) new_shape = get_shape_from_sympy_shape( sympy_shape[:axis] - + [ - ( - self._new_symbolic_dim_from_output(node) - if not is_literal(depth) - else depth - ) - ] + + [(self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth)] + sympy_shape[axis:] ) vi = self.known_vi_[node.output[0]] @@ -1474,8 +1310,7 @@ def _infer_Pad(self, node): # noqa: N802 if pads is not None: assert len(pads) == 2 * rank new_sympy_shape = [ - d + pad_up + pad_down - for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) + d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:]) ] self._update_computed_dims(new_sympy_shape) else: @@ -1485,9 +1320,7 @@ def _infer_Pad(self, node): # noqa: N802 vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape) - ) + 
helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)) ) def _infer_Pool(self, node): # noqa: N802 @@ -1511,11 +1344,7 @@ def _infer_aten_bitwise_or(self, node): new_shape = self._broadcast_shapes(shape0, shape1) t0 = self.known_vi_[node.input[0]] vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], t0.type.tensor_type.elem_type, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape)) def _infer_aten_diagonal(self, node): sympy_shape = self._get_sympy_shape(node, 0) @@ -1557,11 +1386,7 @@ def _infer_aten_multinomial(self, node): assert rank in [1, 2] num_samples = self._try_get_value(node, 1) di = rank - 1 - last_dim = ( - num_samples - if num_samples - else str(self._new_symbolic_dim_from_output(node, 0, di)) - ) + last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di)) output_shape = sympy_shape[:-1] + [last_dim] vi = self.known_vi_[node.output[0]] vi.CopyFrom( @@ -1575,24 +1400,14 @@ def _infer_aten_multinomial(self, node): def _infer_aten_pool2d(self, node): sympy_shape = self._get_sympy_shape(node, 0) assert len(sympy_shape) == 4 - sympy_shape[-2:] = [ - self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3] - ] + sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in [2, 3]] self._update_computed_dims(sympy_shape) for i, o in enumerate(node.output): if not o: continue vi = self.known_vi_[o] - elem_type = ( - onnx.TensorProto.INT64 - if i == 1 - else self.known_vi_[node.input[0]].type.tensor_type.elem_type - ) - vi.CopyFrom( - helper.make_tensor_value_info( - o, elem_type, get_shape_from_sympy_shape(sympy_shape) - ) - ) + elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape))) def _infer_aten_minmax(self, node): vi = self.known_vi_[node.output[0]] @@ -1611,9 +1426,7 @@ def _infer_aten_minmax(self, node): dim = self._try_get_value(node, 1) if dim is None: rank = self._get_shape_rank(node, 0) - output_shape = self._new_symbolic_shape( - rank if keepdim else rank - 1, node - ) + output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) else: shape = self._get_sympy_shape(node, 0) dim = handle_negative_axis(dim, len(shape)) @@ -1631,11 +1444,7 @@ def _infer_aten_minmax(self, node): ) ) vi1 = self.known_vi_[node.output[1]] - vi1.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT64, output_shape - ) - ) + vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape)) def _infer_aten_unfold(self, node): sympy_shape = self._get_sympy_shape(node, 0) @@ -1678,27 +1487,17 @@ def _infer_aten_argmax(self, node): del sympy_shape[dim] else: rank = len(sympy_shape) - sympy_shape = self._new_symbolic_shape( - rank if keepdim else rank - 1, node - ) + sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node) self._update_computed_dims(sympy_shape) new_shape = get_shape_from_sympy_shape(sympy_shape) if node.output[0] and new_shape is not None: vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], onnx.TensorProto.INT64, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) def _infer_aten_group_norm(self, node): 
self._propagate_shape_and_type(node) input_shape = self._get_shape(node, 0) - N = ( - input_shape[0] - if input_shape is not None and len(input_shape) != 0 - else None - ) # noqa: N806 + N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None # noqa: N806 group = self._try_get_value(node, 6) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type for i in [1, 2]: @@ -1709,11 +1508,7 @@ def _infer_aten_group_norm(self, node): node.output[i], output_dtype, [ - ( - N - if N is not None - else str(self._new_symbolic_dim_from_output(node, i, 0)) - ), + (N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0))), ( as_scalar(group) if group is not None @@ -1730,22 +1525,14 @@ def _infer_aten_upsample(self, node): new_shape = input_shape[:2] output_size = self._try_get_value(node, 1) if output_size is not None: - new_shape += [ - dim_size.item() if type(dim_size) == np.int64 else dim_size - for dim_size in output_size - ] + new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size] else: rank = len(input_shape) - new_shape += [ - str(self._new_symbolic_dim_from_output(node, 0, i)) - for i in range(2, rank) - ] + new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)] if node.output[0] and new_shape is not None: output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _infer_BatchNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -1787,11 +1574,7 @@ def _infer_ReduceSum(self, node): # noqa: N802 helper.make_tensor_value_info( node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) - ), + get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)), ) ) else: @@ -1831,9 +1614,7 @@ def _infer_RelativePositionBias(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, new_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _infer_Reshape(self, node): # noqa: N802 shape_value = self._try_get_value(node, 1) @@ -1847,9 +1628,7 @@ def _infer_Reshape(self, node): # noqa: N802 helper.make_tensor_value_info( node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape( - self._new_symbolic_shape(shape_rank, node) - ), + get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), ) ) else: @@ -1895,10 +1674,7 @@ def _infer_Resize(self, node): # noqa: N802 if get_opset(self.out_mp_) <= 10: scales = self._try_get_value(node, 1) if scales is not None: - new_sympy_shape = [ - sympy.simplify(sympy.floor(d * s)) - for d, s in zip(input_sympy_shape, scales) - ] + new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] self._update_computed_dims(new_sympy_shape) vi.CopyFrom( helper.make_tensor_value_info( @@ -1916,10 +1692,7 @@ def _infer_Resize(self, node): # noqa: N802 self._update_computed_dims(new_sympy_shape) elif scales is not None: rank = len(scales) - if ( - get_attribute(node, 
"coordinate_transformation_mode") - == "tf_crop_and_resize" - ): + if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": assert len(roi) == 2 * rank roi_start = list(roi)[:rank] roi_end = list(roi)[rank:] @@ -1929,15 +1702,11 @@ def _infer_Resize(self, node): # noqa: N802 scales = list(scales) new_sympy_shape = [ sympy.simplify(sympy.floor(d * (end - start) * scale)) - for d, start, end, scale in zip( - input_sympy_shape, roi_start, roi_end, scales - ) + for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales) ] self._update_computed_dims(new_sympy_shape) else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi.CopyFrom( helper.make_tensor_value_info( @@ -1965,31 +1734,19 @@ def _infer_Scan(self, node): # noqa: N802 si.CopyFrom(self.known_vi_[node.input[i]]) if i >= num_scan_states: scan_input_dim = si.type.tensor_type.shape.dim - scan_input_dim.remove( - scan_input_dim[scan_input_axes[i - num_scan_states]] - ) + scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]]) si.name = subgraph_name self._onnx_infer_subgraph(node, subgraph) num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute( - node, "scan_output_axes", [0] * num_scan_outputs - ) - scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[ - scan_input_axes[-1] - ] + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) + scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] for i, o in enumerate(node.output): vi = self.known_vi_[o] if i >= num_scan_states: shape = get_shape_from_type_proto(subgraph.output[i].type) - new_dim = handle_negative_axis( - scan_output_axes[i - num_scan_states], len(shape) + 1 - ) + new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1) shape = shape[:new_dim] + [scan_input_dim] + shape[new_dim:] - vi.CopyFrom( - helper.make_tensor_value_info( - o, subgraph.output[i].type.tensor_type.elem_type, shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape)) else: vi.CopyFrom(subgraph.output[i]) vi.name = o @@ -2045,14 +1802,8 @@ def _infer_Slice(self, node): # noqa: N802 # # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`. 
def flatten_min(expr): - assert isinstance( - expr, sympy.Add - ), f"Expected a sum of two arguments, got {expr}" - min_positions = [ - idx - for idx in range(len(expr.args)) - if isinstance(expr.args[idx], sympy.Min) - ] + assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}" + min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)] if len(min_positions) == 1: min_pos = min_positions[0] @@ -2097,7 +1848,7 @@ def less_equal(x, y): return all(bool(d >= 0) for d in flatten_min(y - x)) def handle_negative_index(index, bound): - """normalizes a negative index to be in [0, bound)""" + """Normalizes a negative index to be in [0, bound)""" try: if not less_equal(0, index): if is_literal(index) and index <= -self.int_max_: @@ -2162,18 +1913,14 @@ def handle_negative_index(index, bound): if not less_equal(e, new_sympy_shape[i]): e = new_sympy_shape[i] # noqa: PLW2901 except Exception: - logger.warning( - f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal" - ) + logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal") e = new_sympy_shape[i] # noqa: PLW2901 s = handle_negative_index(s, new_sympy_shape[i]) # noqa: PLW2901 if is_literal(new_sympy_shape[i]) and is_literal(s): s = max(0, min(s, new_sympy_shape[i])) # noqa: PLW2901 - new_sympy_shape[i] = sympy.simplify( - (e - s + t + (-1 if t > 0 else 1)) // t - ) + new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t) self._update_computed_dims(new_sympy_shape) @@ -2201,9 +1948,7 @@ def handle_negative_index(index, bound): if type(input_sympy_data) == list or ( # noqa: E721 type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 ): - self.sympy_data_[node.output[0]] = input_sympy_data[ - starts[0] : ends[0] : steps[0] - ] + self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]] def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 vi = self.known_vi_[node.output[0]] @@ -2224,9 +1969,7 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis( - get_attribute(node, "axis", 0), len(input_sympy_shape) - ) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) op_set = get_opset(self.out_mp_) # Depending on op-version 'split' are provided as attribute or via 2nd input @@ -2250,11 +1993,7 @@ def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 make_value_info_func( node.output[i_o], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape( - input_sympy_shape[:axis] - + [split[i_o]] - + input_sympy_shape[axis + 1 :] - ), + get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]), ) ) self.known_vi_[vi.name] = vi @@ -2283,9 +2022,7 @@ def _infer_Squeeze(self, node): # noqa: N802 # For symbolic dimensions we guess they are !=1. output_shape = [s for s in input_shape if s != 1] if self.verbose_ > 0: - symbolic_dimensions = [ - s for s in input_shape if type(s) != int - ] # noqa: E721 + symbolic_dimensions = [s for s in input_shape if type(s) != int] # noqa: E721 if len(symbolic_dimensions) > 0: logger.debug( f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. 
" @@ -2298,9 +2035,7 @@ def _infer_Squeeze(self, node): # noqa: N802 if i not in axes: output_shape.append(input_shape[i]) else: - assert ( - input_shape[i] == 1 or type(input_shape[i]) != int - ) # noqa: E721 + assert input_shape[i] == 1 or type(input_shape[i]) != int # noqa: E721 if self.verbose_ > 0 and type(input_shape[i]) != int: # noqa: E721 logger.debug( f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " @@ -2327,9 +2062,7 @@ def _infer_Tile(self, node): # noqa: N802 new_sympy_shape.append(new_dim) self._update_computed_dims(new_sympy_shape) else: - new_sympy_shape = self._new_symbolic_shape( - self._get_shape_rank(node, 0), node - ) + new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( @@ -2366,11 +2099,7 @@ def _infer_TopK(self, node): # noqa: N802 for i_o in range(len(node.output)): vi = self.known_vi_[node.output[i_o]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[i_o], vi.type.tensor_type.elem_type, new_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape)) def _infer_Transpose(self, node): # noqa: N802 if node.input[0] in self.sympy_data_: @@ -2378,11 +2107,7 @@ def _infer_Transpose(self, node): # noqa: N802 perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) input_data = self.sympy_data_[node.input[0]] self.sympy_data_[node.output[0]] = ( - np.transpose( - np.array(input_data).reshape(*data_shape), axes=tuple(perm) - ) - .flatten() - .tolist() + np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist() ) def _infer_Unsqueeze(self, node): # noqa: N802 @@ -2430,9 +2155,7 @@ def _infer_ZipMap(self, node): # noqa: N802 assert map_key_type is not None new_vi = onnx.ValueInfoProto() new_vi.name = node.output[0] - new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = ( - onnx.TensorProto.FLOAT - ) + new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type vi = self.known_vi_[node.output[0]] vi.CopyFrom(new_vi) @@ -2443,9 +2166,7 @@ def _infer_Attention(self, node): # noqa: N802 shape_bias = self._try_get_shape(node, 2) if shape_bias is not None: assert len(shape_bias) == 1 - tripled_hidden_size = ( - shape_bias[0] if shape_bias is not None else shape_weights[1] - ) + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] if shape and len(shape) == 3: qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: @@ -2455,9 +2176,7 @@ def _infer_Attention(self, node): # noqa: N802 shape[2] = int(tripled_hidden_size / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) if len(node.output) > 1: # input shape: (batch_size, sequence_length, hidden_size) @@ -2465,31 +2184,19 @@ def _infer_Attention(self, node): # noqa: N802 # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len) # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where 
total_sequence_length=sequence_length+past_sequence_length input_shape = self._get_shape(node, 0) - past_shape = ( - self._get_shape(node, 4) - if len(node.input) > 4 and node.input[4] - else [] - ) - mask_shape = ( - self._get_shape(node, 3) - if len(node.input) > 3 and node.input[3] - else [] - ) + past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else [] + mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else [] if past_shape and len(past_shape) == 5: if mask_shape and len(mask_shape) in [2, 3]: past_shape[3] = mask_shape[-1] elif input_shape and len(input_shape) == 3: - if isinstance(input_shape[1], int) and isinstance( - past_shape[3], int - ): + if isinstance(input_shape[1], int) and isinstance(past_shape[3], int): past_shape[3] = input_shape[1] + past_shape[3] else: past_shape[3] = f"{past_shape[3]}+{input_shape[1]}" vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) # No past input but present output still exists else: num_heads = get_attribute(node, "num_heads") @@ -2502,11 +2209,7 @@ def _infer_Attention(self, node): # noqa: N802 head_size, ] vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) def _infer_GatedRelativePositionBias(self, node): # noqa: N802 # When padding is removed: @@ -2538,9 +2241,7 @@ def _infer_GatedRelativePositionBias(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_PackedAttention(self, node): # noqa: N802 shape = self._get_shape(node, 0) @@ -2548,9 +2249,7 @@ def _infer_PackedAttention(self, node): # noqa: N802 shape_bias = self._try_get_shape(node, 2) if shape_bias is not None: assert len(shape_bias) == 1 - tripled_hidden_size = ( - shape_bias[0] if shape_bias is not None else shape_weights[1] - ) + tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1] if shape and len(shape) == 2: qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: @@ -2560,9 +2259,7 @@ def _infer_PackedAttention(self, node): # noqa: N802 shape[1] = int(tripled_hidden_size / 3) output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape)) def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 shape_value = self._try_get_shape(node, 2) @@ -2575,51 +2272,32 @@ def _infer_PackedMultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_RemovePadding(self, node): # noqa: N802 shape = self._get_shape(node, 0) if shape and len(shape) == 3: output_dtype = 
self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, ["token_count", shape[2]] - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]])) vi_token_offset = self.known_vi_[node.output[1]] vi_token_offset.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]] - ) + helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]]) ) vi_cumulated_seq_len = self.known_vi_[node.output[2]] vi_cumulated_seq_len.CopyFrom( - helper.make_tensor_value_info( - node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"] - ) + helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"]) ) vi_max_seq_len = self.known_vi_[node.output[3]] - vi_max_seq_len.CopyFrom( - helper.make_tensor_value_info( - node.output[3], onnx.TensorProto.INT32, [1] - ) - ) + vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1])) def _infer_RestorePadding(self, node): # noqa: N802 shape_input = self._get_shape(node, 0) shape_token_offset = self._get_shape(node, 1) - if ( - shape_input - and len(shape_input) == 2 - and shape_token_offset - and len(shape_token_offset) == 2 - ): + if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2: output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] @@ -2628,11 +2306,7 @@ def _infer_RestorePadding(self, node): # noqa: N802 shape_token_offset[1], shape_input[1], ] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_BiasGelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2668,11 +2342,7 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) elif len(query_shape) == 5: if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): @@ -2692,11 +2362,7 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) if len(node.output) > 1: batch_size = query_shape[0] @@ -2715,14 +2381,10 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 past_shape = self._try_get_shape(node, 6) if past_shape is not None: - if isinstance(past_shape[2], int) and isinstance( - total_sequence_length, int - ): + if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int): total_sequence_length = past_shape[2] + total_sequence_length else: - total_sequence_length = ( - f"{past_shape[2]}+{total_sequence_length}" - ) + total_sequence_length = f"{past_shape[2]}+{total_sequence_length}" present_shape = [ batch_size, @@ -2734,17 +2396,9 @@ def _infer_MultiHeadAttention(self, node): # noqa: N802 assert output_dtype is not None if 
len(node.output) > 2 and node.output[1] and node.output[2]: vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - vi.name, output_dtype, present_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape)) def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 # Output 0 has shape (batch_size, 1, v_hidden_size) @@ -2758,23 +2412,15 @@ def _infer_DecoderMaskedMultiHeadAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type assert output_dtype is not None vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], output_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) if len(node.output) > 2 and node.output[1] and node.output[2]: past_shape = self._try_get_shape(node, 5) if past_shape is not None: vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, past_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) def _infer_FastGelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2803,24 +2449,13 @@ def _infer_LayerNormalization(self, node): # noqa: N802 axis = handle_negative_axis(axis, rank) mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)] mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - if ( - mean_dtype == onnx.TensorProto.FLOAT16 - or mean_dtype == onnx.TensorProto.BFLOAT16 - ): + if mean_dtype == onnx.TensorProto.FLOAT16 or mean_dtype == onnx.TensorProto.BFLOAT16: mean_dtype = onnx.TensorProto.FLOAT vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[1], mean_dtype, mean_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape)) if len(node.output) > 2: vi = self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[2], mean_dtype, mean_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape)) def _infer_LongformerAttention(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2833,30 +2468,18 @@ def _infer_EmbedLayerNormalization(self, node): # noqa: N802 word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[0], word_embedding_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape)) if len(node.output) > 1 and node.output[1]: mask_index_shape = [input_ids_shape[0]] vi = self.known_vi_[node.output[1]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[1], onnx.TensorProto.INT32, mask_index_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape)) if len(node.output) > 2: # Optional output of add before layer normalization is done # shape is same as the output vi = 
self.known_vi_[node.output[2]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[2], word_embedding_dtype, output_shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape)) def _infer_SkipLayerNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2882,9 +2505,7 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802 output_shape[2] = int(bias_shape[0] / 2) vi = self.known_vi_[node.output[0]] output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi.CopyFrom( - helper.make_tensor_value_info(vi.name, output_dtype, output_shape) - ) + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape)) def _infer_BiasAdd(self, node): # noqa: N802 self._propagate_shape_and_type(node) @@ -2895,26 +2516,18 @@ def _infer_RotaryEmbedding(self, node): # noqa: N802 elif len(node.output) == 2: # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` self._propagate_shape_and_type(node, input_index=1, output_index=0) - self._propagate_shape_and_type( - node, input_index=0, output_index=1 - ) # true output + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output elif len(node.output) == 3: # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` self._propagate_shape_and_type(node, input_index=1, output_index=0) self._propagate_shape_and_type(node, input_index=1, output_index=1) - self._propagate_shape_and_type( - node, input_index=0, output_index=2 - ) # true output + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output def _infer_PythonOp(self, node): # noqa: N802 output_tensor_types = get_attribute(node, "output_tensor_types") - assert ( - output_tensor_types - ), f"PythonOp '{node.name}' has no output_tensor_types attribute." + assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute." output_tensor_ranks = get_attribute(node, "output_tensor_ranks") - assert ( - output_tensor_ranks - ), f"PythonOp '{node.name}' has no output_tensor_ranks attribute." + assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute." from onnxruntime.capi._pybind_state import get_shape_inference_function @@ -2924,9 +2537,7 @@ def _infer_PythonOp(self, node): # noqa: N802 # Set the context output separately. # The first output is torch.autograd.Function''s context. vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) if shape_inferer is not None: input_shapes = [] @@ -2934,13 +2545,9 @@ def _infer_PythonOp(self, node): # noqa: N802 for input_index in range(len(node.input)): shape = self._get_shape(node, input_index) input_shapes.append(shape) - input_dtype = self.known_vi_[ - node.input[input_index] - ].type.tensor_type.elem_type + input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type input_dtypes.append(input_dtype) - output_shapes, output_dtypes = shape_inferer( - node, input_shapes, input_dtypes - ) + output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes) assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), ( f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, " f"but expected {len(node.output) - 1} outputs." 
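The `_infer_PythonOp` hunks above and below build everything from `onnx.helper.make_tensor_value_info`: a scalar INT64 value_info for the autograd-context output, and ranked outputs whose symbolic dims become `dim_param` strings. A minimal standalone illustration (the tensor names here are made up):

```python
import onnx
from onnx import TensorProto, helper

# First output of a PythonOp: torch.autograd.Function's context,
# modeled as a scalar INT64 (empty shape list).
ctx_vi = helper.make_tensor_value_info("ctx", TensorProto.INT64, [])

# Remaining outputs: known rank, symbolic dims passed as strings
# become dim_param entries in the shape proto.
out_vi = helper.make_tensor_value_info("out0", TensorProto.FLOAT, ["unk__0", "unk__1"])

print(ctx_vi.type.tensor_type.elem_type == TensorProto.INT64)        # True
print([d.dim_param for d in out_vi.type.tensor_type.shape.dim])      # ['unk__0', 'unk__1']
```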
@@ -2949,9 +2556,7 @@ def _infer_PythonOp(self, node): # noqa: N802 output_index = i + 1 vi = self.known_vi_[node.output[output_index]] vi.CopyFrom( - helper.make_tensor_value_info( - node.output[output_index], output_dtypes[i], output_shapes[i] - ) + helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]) ) else: # General shape inference for PythonOp. @@ -2962,22 +2567,14 @@ def _infer_PythonOp(self, node): # noqa: N802 vi = self.known_vi_[node.output[i + 1]] sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node) shape = get_shape_from_sympy_shape(sympy_shape) - value_info = helper.make_tensor_value_info( - node.output[i + 1], output_tensor_types[i], shape - ) + value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape) vi.CopyFrom(value_info) def _propagate_shape_and_type(self, node, input_index=0, output_index=0): shape = self._get_shape(node, input_index) - output_dtype = self.known_vi_[ - node.input[input_index] - ].type.tensor_type.elem_type + output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type vi = self.known_vi_[node.output[output_index]] - vi.CopyFrom( - helper.make_tensor_value_info( - node.output[output_index], output_dtype, shape - ) - ) + vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape)) def _is_none_dim(self, dim_value): if type(dim_value) != str: # noqa: E721 @@ -3012,13 +2609,9 @@ def _infer_impl(self, start_sympy_data=None): for i_dim, dim in enumerate(input_shape): if dim is None: # some models use None for symbolic dim in input, replace it with a string - input_dims[i_dim].dim_param = str( - self._new_symbolic_dim(i.name, i_dim) - ) + input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim)) - self.input_symbols_.update( - [d for d in input_shape if type(d) == str] - ) # noqa: E721 + self.input_symbols_.update([d for d in input_shape if type(d) == str]) # noqa: E721 for s in self.input_symbols_: if s in self.suggested_merge_: @@ -3035,11 +2628,9 @@ def _infer_impl(self, start_sympy_data=None): self.tmp_mp_.CopyFrom(self.out_mp_) self.tmp_mp_.graph.ClearField("initializer") - # compute prerequesite for node for topological sort + # compute prerequisite for node for topological sort # node with subgraphs may have dependency on implicit inputs, which will affect topological sort - prereq_for_node = ( - {} - ) # map from node to all its inputs, including implicit ones in subgraph + prereq_for_node = {} # map from node to all its inputs, including implicit ones in subgraph def get_prereq(node): names = {i for i in node.input if i} @@ -3057,13 +2648,7 @@ def get_prereq(node): for n in g.node: g_outputs_and_initializers.update(n.output) for n in g.node: - g_prereq.update( - [ - i - for i in get_prereq(n) - if i not in g_outputs_and_initializers - ] - ) + g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers]) names.update(g_prereq) # remove subgraph inputs from g_prereq since those are local-only for i in g.input: @@ -3076,26 +2661,16 @@ def get_prereq(node): # topological sort nodes, note there might be dead nodes so we check if all graph outputs are reached to terminate sorted_nodes = [] - sorted_known_vi = { - i.name - for i in list(self.out_mp_.graph.input) - + list(self.out_mp_.graph.initializer) - } + sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)} if any([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): 
# Loop/Scan will have some graph output in graph inputs, so don't do topological sort sorted_nodes = self.out_mp_.graph.node else: - while not all( - [o.name in sorted_known_vi for o in self.out_mp_.graph.output] - ): + while not all([o.name in sorted_known_vi for o in self.out_mp_.graph.output]): old_sorted_nodes_len = len(sorted_nodes) for node in self.out_mp_.graph.node: if (node.output[0] not in sorted_known_vi) and all( - [ - i in sorted_known_vi - for i in prereq_for_node[node.output[0]] - if i - ] + [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i] ): sorted_known_vi.update(node.output) sorted_nodes.append(node) @@ -3121,11 +2696,7 @@ def get_prereq(node): for attr in node.attribute: # TODO: Is overload_name needed? if attr.name == "operator": - aten_op_name = ( - attr.s.decode("utf-8") - if isinstance(attr.s, bytes) - else attr.s - ) + aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s if aten_op_name in self.aten_op_dispatcher_: known_aten_op = True self.aten_op_dispatcher_[aten_op_name](node) @@ -3135,9 +2706,7 @@ def get_prereq(node): logger.debug(node.op_type + ": " + node.name) for i, name in enumerate(node.input): logger.debug( - " Input {}: {} {}".format( - i, name, "initializer" if name in self.initializers_ else "" - ) + " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "") ) # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] @@ -3156,20 +2725,8 @@ def get_prereq(node): vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range( - out_rank - - ( - 2 - if node.op_type - in ["MatMul", "MatMulInteger", "MatMulInteger16"] - else 0 - ) - ): - in_dims = [ - s[len(s) - out_rank + d] - for s in in_shapes - if len(s) + d >= out_rank - ] + for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): + in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -3180,8 +2737,7 @@ def get_prereq(node): # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding # contrib op if ( - node.op_type == "SkipLayerNormalization" - or node.op_type == "SkipSimplifiedLayerNormalization" + node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" ) and i_o in [1, 2]: continue if node.op_type == "RotaryEmbedding" and len(node.output) > 1: @@ -3197,9 +2753,7 @@ def get_prereq(node): if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: if self.verbose_ > 2: if out_type_kind == "sequence_type": - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof( - "value" - ) + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") if seq_cls_type == "tensor_type": logger.debug( " {}: sequence of {} {}".format( @@ -3211,38 +2765,27 @@ def get_prereq(node): ) ) else: - logger.debug( - f" {node.output[i_o]}: sequence of {seq_cls_type}" - ) + logger.debug(f" {node.output[i_o]}: sequence of {seq_cls_type}") else: logger.debug(f" {node.output[i_o]}: {out_type_kind}") continue out_shape = get_shape_from_value_info(vi) - out_type_undefined = ( - out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED - ) + out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED if self.verbose_ > 2: logger.debug( " {}: {} 
{}".format( node.output[i_o], str(out_shape), - onnx.TensorProto.DataType.Name( - vi.type.tensor_type.elem_type - ), + onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), ) ) if node.output[i_o] in self.sympy_data_: - logger.debug( - " Sympy Data: " + str(self.sympy_data_[node.output[i_o]]) - ) + logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain if ( - out_shape is not None - and ( - None in out_shape or self._is_shape_contains_none_dim(out_shape) - ) + out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) ) or out_type_undefined: if self.auto_merge_: if node.op_type in [ @@ -3264,36 +2807,21 @@ def get_prereq(node): "Min", "Max", ]: - shapes = [ - self._get_shape(node, i) for i in range(len(node.input)) - ] + shapes = [self._get_shape(node, i) for i in range(len(node.input))] if node.op_type in [ "MatMul", "MatMulInteger", "MatMulInteger16", ]: - if ( - None in out_shape - or self._is_shape_contains_none_dim(out_shape) - ): + if None in out_shape or self._is_shape_contains_none_dim(out_shape): if None in out_shape: idx = out_shape.index(None) else: - idx = out_shape.index( - self._is_shape_contains_none_dim(out_shape) - ) - dim_idx = [ - len(s) - len(out_shape) + idx for s in shapes - ] + idx = out_shape.index(self._is_shape_contains_none_dim(out_shape)) + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] # only support auto merge for MatMul for dim < rank-2 when rank > 2 - assert ( - len(shapes[0]) > 2 - and dim_idx[0] < len(shapes[0]) - 2 - ) - assert ( - len(shapes[1]) > 2 - and dim_idx[1] < len(shapes[1]) - 2 - ) + assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 + assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 elif node.op_type == "Expand": # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) shapes = [ @@ -3305,15 +2833,11 @@ def get_prereq(node): if shapes: for idx in range(len(out_shape)): - if out_shape[idx] is not None and not self._is_none_dim( - out_shape[idx] - ): + if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]): continue # note that the broadcasting rule aligns from right to left # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge - dim_idx = [ - len(s) - len(out_shape) + idx for s in shapes - ] + dim_idx = [len(s) - len(out_shape) + idx for s in shapes] if len(dim_idx) > 0: self._add_suggested_merge( [ @@ -3329,22 +2853,12 @@ def get_prereq(node): self.run_ = False # create new dynamic dims for ops not handled by symbolic shape inference - if ( - self.run_ is False - and node.op_type not in self.dispatcher_ - and not known_aten_op - ): - is_unknown_op = out_type_undefined and ( - out_shape is None or len(out_shape) == 0 - ) + if self.run_ is False and node.op_type not in self.dispatcher_ and not known_aten_op: + is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0) if is_unknown_op: # unknown op to ONNX, maybe from higher opset or other domain # only guess the output rank from input 0 when using guess_output_rank option - out_rank = ( - self._get_shape_rank(node, 0) - if self.guess_output_rank_ - else -1 - ) + out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1 else: # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape out_rank = len(out_shape) @@ -3353,9 +2867,7 @@ def get_prereq(node): new_shape = 
self._new_symbolic_shape(out_rank, node, i_o) if out_type_undefined: # guess output data type from input vi if not defined - out_dtype = self.known_vi_[ - node.input[0] - ].type.tensor_type.elem_type + out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type else: # otherwise, use original data type out_dtype = vi.type.tensor_type.elem_type @@ -3386,12 +2898,7 @@ def get_prereq(node): continue # continue the inference after guess, no need to stop as no merge is needed if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug( - "Stopping at incomplete shape inference at " - + node.op_type - + ": " - + node.name - ) + logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) logger.debug("node inputs:") for i in node.input: if i in self.known_vi_: @@ -3417,16 +2924,12 @@ def _update_output_from_vi(self): output.CopyFrom(self.known_vi_[output.name]) @staticmethod - def infer_shapes( - in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0 - ): + def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): onnx_opset = get_opset(in_mp) if (not onnx_opset) or onnx_opset < 7: logger.warning("Only support models of onnx opset 7 and above.") return None - symbolic_shape_inference = SymbolicShapeInference( - int_max, auto_merge, guess_output_rank, verbose - ) + symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose) all_shapes_inferred = False symbolic_shape_inference._preprocess(in_mp) while symbolic_shape_inference.run_: diff --git a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py index ba4f521..6c7d704 100644 --- a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -83,16 +83,13 @@ def update_import_domains(graph): DEFAULT_CUSTOM_OPSET_VERSION = 1 for used_domain in all_used_domains: if used_domain not in current_domains: - graph.import_domains.append( - onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION) - ) + graph.import_domains.append(onnx.helper.make_opsetid(used_domain, DEFAULT_CUSTOM_OPSET_VERSION)) current_domains.add(used_domain) return graph.import_domains # Converts a fp32 gs.Constant to a bf16 onnx.TensorProto def tensor_to_onnx_bf16(tensor: Constant): - # Converts the fp32 numpy array to bf16 values and store in a uint16 numpy array def np_float32_to_bf16_as_uint16(arr): new_arr = np.empty(arr.size, dtype=np.uint16) @@ -142,9 +139,7 @@ def export_sparse_tensor_proto(tensor: Constant) -> onnx.SparseTensorProto: return tensor._values.tensor @staticmethod - def export_value_info_proto( - tensor: Tensor, do_type_check: bool - ) -> onnx.ValueInfoProto: + def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueInfoProto: if do_type_check and tensor.dtype is None: G_LOGGER.critical( "Graph input and output tensors must include dtype information. 
Please set the dtype attribute for: {:}".format( @@ -154,9 +149,7 @@ def export_value_info_proto( if tensor.dtype is not None: if isinstance(tensor, Constant) or tensor.type == "tensor_type": - onnx_tensor = onnx.helper.make_tensor_value_info( - tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape - ) + onnx_tensor = onnx.helper.make_tensor_value_info(tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape) elif tensor.type == "sequence_type": onnx_tensor = onnx.helper.make_tensor_sequence_value_info( tensor.name, dtype_to_onnx(tensor.dtype), tensor.shape @@ -186,9 +179,7 @@ def export_attributes(attrs: dict) -> List[onnx.AttributeProto]: # Netron has a bug which makes it crash if a Tensor attribute has no tensor data. # So provide some meaningless tensor data for Netron to read. if val.type == Tensor: - tensor_proto = OnnxExporter.export_tensor_proto( - Constant("", np.array([0], dtype=np.float32)) - ) + tensor_proto = OnnxExporter.export_tensor_proto(Constant("", np.array([0], dtype=np.float32))) onnx_attr.t.CopyFrom(tensor_proto) onnx_attr.ref_attr_name = val.name @@ -232,9 +223,7 @@ def export_function(func: Function) -> onnx.FunctionProto: for tensor in func.tensors().values(): if isinstance(tensor, Constant): # Copying the tensor prevents the new node from appearing in the Constant tensor's inputs. - new_const_nodes.append( - Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()]) - ) + new_const_nodes.append(Node("Constant", attrs={"value": tensor}, outputs=[tensor.copy()])) # Const nodes have no inputs, so this maintains a topological ordering. func_nodes = new_const_nodes + func_nodes @@ -281,20 +270,13 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: """ check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING) nodes = [OnnxExporter.export_node(node) for node in graph.nodes] - inputs = [ - OnnxExporter.export_value_info_proto(inp, do_type_check) - for inp in graph.inputs - ] - outputs = [ - OnnxExporter.export_value_info_proto(out, do_type_check) - for out in graph.outputs - ] + inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs] + outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs] tensor_map = graph.tensors() initializer = [ OnnxExporter.export_tensor_proto(tensor) for tensor in tensor_map.values() - if isinstance(tensor, Constant) - and not isinstance(tensor._values, SparseValues) + if isinstance(tensor, Constant) and not isinstance(tensor._values, SparseValues) ] sparse_initializer = [ OnnxExporter.export_sparse_tensor_proto(tensor) @@ -309,9 +291,7 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: # Omit tensors from value_info if we don't know their shape/dtype def has_value_info(tensor): - return isinstance(tensor, Variable) and ( - tensor.dtype is not None or tensor.shape is not None - ) + return isinstance(tensor, Variable) and (tensor.dtype is not None or tensor.shape is not None) value_info = [ OnnxExporter.export_value_info_proto(tensor, do_type_check) diff --git a/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py b/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py index fdb8e09..c54404e 100644 --- a/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py +++ b/onnxslim/onnx_graphsurgeon/graph_pattern/graph_pattern.py @@ -22,9 +22,7 @@ class PatternMapping(dict): - """ - Represents a graph pattern mapping result. 
- """ + """Represents a graph pattern mapping result.""" def __init__(self, onnx_node=None) -> None: super().__init__() @@ -42,10 +40,7 @@ def set_input_onnx_tensor(self, onnx_tensor, index): length = len(self.inputs) for _ in range(index - length + 1): self.inputs.append(None) - if ( - self.inputs[index] is not None - and self.inputs[index].name != onnx_tensor.name - ): + if self.inputs[index] is not None and self.inputs[index].name != onnx_tensor.name: return False # This input tensor has been set up by another onnx tensor self.inputs[index] = onnx_tensor return True @@ -54,10 +49,7 @@ def set_output_onnx_tensor(self, onnx_tensor, index): length = len(self.outputs) for _ in range(index - length + 1): self.outputs.append(None) - if ( - self.outputs[index] is not None - and self.outputs[index].name != onnx_tensor.name - ): + if self.outputs[index] is not None and self.outputs[index].name != onnx_tensor.name: return False # This output tensor has been set up by another onnx tensor self.outputs[index] = onnx_tensor return True @@ -88,13 +80,7 @@ def get(self, name: str): def __str__(self) -> str: if self.onnx_node is None: - return ( - "{" - + str.join( - ", ", [f"{key}: {str(value)}" for key, value in self.items()] - ) - + "}" - ) + return "{" + str.join(", ", [f"{key}: {str(value)}" for key, value in self.items()]) + "}" return self.onnx_node.name @@ -252,22 +238,16 @@ def _single_node_match(self, onnx_node: Node) -> bool: if not self.check_func(onnx_node): G_LOGGER.info("No match because: check_func returned false.") return False - G_LOGGER.info( - "Single node is matched: {:}, {:}".format(self.op, onnx_node.name) - ) + G_LOGGER.info("Single node is matched: {:}, {:}".format(self.op, onnx_node.name)) return True - def _get_tensor_index_for_node( - self, node: str, tensor_id: int, is_node_input: bool - ): + def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool): if is_node_input: return self.node_inputs[node].index(tensor_id) else: return self.node_outputs[node].index(tensor_id) - def get_inbound_or_outbound_onnx_node( - self, mapping: PatternMapping, is_inbound: bool, tensor_index: int - ): + def get_inbound_or_outbound_onnx_node(self, mapping: PatternMapping, is_inbound: bool, tensor_index: int): if self.op is not None: onnx_node = mapping._get_node() return onnx_node @@ -277,9 +257,7 @@ def get_inbound_or_outbound_onnx_node( return self.nodes[inbound_node].get_inbound_or_outbound_onnx_node( mapping[inbound_node], is_inbound=True, - tensor_index=self._get_tensor_index_for_node( - inbound_node, inbound_tensor, is_node_input=True - ), + tensor_index=self._get_tensor_index_for_node(inbound_node, inbound_tensor, is_node_input=True), ) else: @@ -288,9 +266,7 @@ def get_inbound_or_outbound_onnx_node( return self.nodes[outbound_node].get_inbound_or_outbound_onnx_node( mapping[outbound_node], is_inbound=False, - tensor_index=self._get_tensor_index_for_node( - outbound_node, outbound_tensor, is_node_input=False - ), + tensor_index=self._get_tensor_index_for_node(outbound_node, outbound_tensor, is_node_input=False), ) return None @@ -346,14 +322,8 @@ def _match_node( from_inbound: bool, ) -> bool: with G_LOGGER.indent(): - G_LOGGER.info( - "Checking node: {:} against pattern node: {:}.".format( - onnx_node.name, node_name - ) - ) - tensor_index_for_node = self._get_tensor_index_for_node( - node_name, from_tensor, is_node_input=from_inbound - ) + G_LOGGER.info("Checking node: {:} against pattern node: {:}.".format(onnx_node.name, node_name)) + tensor_index_for_node = 
self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound) subgraph_mapping = self.nodes[node_name].match( onnx_node, from_inbound, @@ -369,40 +339,30 @@ def _match_node( input_onnx_tensors = subgraph_mapping.inputs if len(input_onnx_tensors) != len(self.node_inputs[node_name]): return False # Number of node inputs should equal to number of input onnx tensors of the node. - for node_input_tensor, onnx_tensor in zip( - self.node_inputs[node_name], input_onnx_tensors - ): + for node_input_tensor, onnx_tensor in zip(self.node_inputs[node_name], input_onnx_tensors): if onnx_tensor is None: return False # tensor paired up. if node_input_tensor in self.input_tensors: - if not mapping.set_input_onnx_tensor( - onnx_tensor, self.input_tensors.index(node_input_tensor) - ): + if not mapping.set_input_onnx_tensor(onnx_tensor, self.input_tensors.index(node_input_tensor)): return False # this tensor is mapped to another onnx tensor continue if node_input_tensor in self.constant_tensors: if not isinstance(onnx_tensor, Constant): return False # constant tensor not match - if not mapping.set_constant_onnx_tensor( - onnx_tensor, self.constant_tensors[node_input_tensor] - ): + if not mapping.set_constant_onnx_tensor(onnx_tensor, self.constant_tensors[node_input_tensor]): # this constant tensor is mapped to another onnx tensor return False continue if len(self.tensor_inputs[node_input_tensor]) != len(onnx_tensor.inputs): return False - for input_node, input_onnx_node in zip( - self.tensor_inputs[node_input_tensor], onnx_tensor.inputs - ): + for input_node, input_onnx_node in zip(self.tensor_inputs[node_input_tensor], onnx_tensor.inputs): # dfs ends when revisiting a node. We need to check if the edges are matched. if input_node in mapping: outbound_tensor_index = self._get_tensor_index_for_node( input_node, node_input_tensor, is_node_input=False ) - outbound_onnx_node_of_input_node = self.nodes[ - input_node - ].get_inbound_or_outbound_onnx_node( + outbound_onnx_node_of_input_node = self.nodes[input_node].get_inbound_or_outbound_onnx_node( mapping[input_node], is_inbound=False, tensor_index=outbound_tensor_index, @@ -428,16 +388,12 @@ def _match_node( output_onnx_tensors = subgraph_mapping.outputs if len(output_onnx_tensors) != len(self.node_outputs[node_name]): return False # Number of node outputs should be equal to number of output onnx tensors of the node. - for node_output_tensor, onnx_tensor in zip( - self.node_outputs[node_name], output_onnx_tensors - ): + for node_output_tensor, onnx_tensor in zip(self.node_outputs[node_name], output_onnx_tensors): if onnx_tensor is None: return False # tensor matched if node_output_tensor in self.output_tensors: - if not mapping.set_output_onnx_tensor( - onnx_tensor, self.output_tensors.index(node_output_tensor) - ): + if not mapping.set_output_onnx_tensor(onnx_tensor, self.output_tensors.index(node_output_tensor)): return False # this tensor is mapped to another onnx tensor continue if onnx_tensor.name in onnx_graph_output_tensors: @@ -446,25 +402,20 @@ def _match_node( # For sub-patterns, each input tensor can only have 1 output node. Otherwise the following test will fail. if len(self.tensor_outputs[node_output_tensor]) != len(onnx_tensor.outputs): return False - for output_node, output_onnx_node in zip( - self.tensor_outputs[node_output_tensor], onnx_tensor.outputs - ): + for output_node, output_onnx_node in zip(self.tensor_outputs[node_output_tensor], onnx_tensor.outputs): # dfs ends when revisiting a node. 
We need to check if the edges are matched. if output_node in mapping: inbound_tensor_index = self._get_tensor_index_for_node( output_node, node_output_tensor, is_node_input=True ) - inbound_onnx_node_of_output_node = self.nodes[ - output_node - ].get_inbound_or_outbound_onnx_node( + inbound_onnx_node_of_output_node = self.nodes[output_node].get_inbound_or_outbound_onnx_node( mapping[output_node], is_inbound=True, tensor_index=inbound_tensor_index, ) if ( inbound_onnx_node_of_output_node is None - or inbound_onnx_node_of_output_node.name - != output_onnx_node.name + or inbound_onnx_node_of_output_node.name != output_onnx_node.name ): return False continue diff --git a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py index f2bcc7f..66b0273 100644 --- a/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py +++ b/onnxslim/onnx_graphsurgeon/importers/onnx_importer.py @@ -53,13 +53,9 @@ } -def get_onnx_tensor_shape( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> List[int]: +def get_onnx_tensor_shape(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> List[int]: shape = None - if isinstance(onnx_tensor, onnx.TensorProto) or isinstance( - onnx_tensor, onnx.SparseTensorProto - ): + if isinstance(onnx_tensor, onnx.TensorProto) or isinstance(onnx_tensor, onnx.SparseTensorProto): shape = onnx_tensor.dims else: if onnx_tensor.type.tensor_type.HasField("shape"): @@ -111,16 +107,13 @@ def get_numpy_type(onnx_type): # TENSOR_TYPE_TO_NP_TYPE maps types unsupported by NumPy to random other types. # This obviously breaks things, so we need to treat this as a special case. - if ( - onnx_type not in numpy_unsupported_types - and onnx_type in onnx.helper.get_all_tensor_dtypes() - ): + if onnx_type not in numpy_unsupported_types and onnx_type in onnx.helper.get_all_tensor_dtypes(): return onnx.helper.tensor_dtype_to_np_dtype(onnx_type) return None def get_onnx_tensor_dtype( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] + onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto], ) -> Union[np.dtype, "onnx.TensorProto.DataType"]: if isinstance(onnx_tensor, onnx.TensorProto): onnx_dtype = onnx_tensor.data_type @@ -152,9 +145,7 @@ def get_onnx_tensor_dtype( return onnx_dtype -def get_onnx_tensor_type( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> str: +def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> str: if isinstance(onnx_tensor, onnx.TensorProto): onnx_type = "tensor_type" else: @@ -176,9 +167,7 @@ def get_onnx_tensor_type( return onnx_type -def get_onnx_tensor_type( - onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto] -) -> str: +def get_onnx_tensor_type(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto]) -> str: if isinstance(onnx_tensor, onnx.TensorProto): onnx_type = "tensor_type" else: @@ -203,21 +192,15 @@ def get_onnx_tensor_type( class OnnxImporter(BaseImporter): @staticmethod def get_opset(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto]): - class_name = ( - "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model" - ) + class_name = "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model" try: for importer in OnnxImporter.get_import_domains(model_or_func): if importer.domain == "" or importer.domain == "ai.onnx": return importer.version - G_LOGGER.warning( - f"{class_name} does not contain ONNX domain opset information! Using default opset." 
- ) + G_LOGGER.warning(f"{class_name} does not contain ONNX domain opset information! Using default opset.") return None except: - G_LOGGER.warning( - f"{class_name} does not contain opset information! Using default opset." - ) + G_LOGGER.warning(f"{class_name} does not contain opset information! Using default opset.") return None @staticmethod @@ -225,11 +208,7 @@ def get_import_domains(model_or_func: Union[onnx.ModelProto, onnx.FunctionProto] return model_or_func.opset_import @staticmethod - def import_tensor( - onnx_tensor: Union[ - onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto - ] - ) -> Tensor: + def import_tensor(onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto]) -> Tensor: if isinstance(onnx_tensor, onnx.SparseTensorProto): return Constant( name=onnx_tensor.values.name, @@ -237,11 +216,7 @@ def import_tensor( data_location=onnx_tensor.values.data_location, ) elif isinstance(onnx_tensor, onnx.TensorProto): - data_location = ( - int(onnx_tensor.data_location) - if onnx_tensor.HasField("data_location") - else None - ) + data_location = int(onnx_tensor.data_location) if onnx_tensor.HasField("data_location") else None return Constant( name=onnx_tensor.name, values=LazyValues(onnx_tensor), @@ -295,9 +270,7 @@ def process_attr(attr_str: str): attr_dict[attr.name] = process_attr(attr_str) else: G_LOGGER.warning( - "Attribute of type {:} is currently unsupported. Skipping attribute.".format( - attr_str - ) + "Attribute of type {:} is currently unsupported. Skipping attribute.".format(attr_str) ) else: G_LOGGER.warning( @@ -371,9 +344,7 @@ def import_function( model_import_domains: onnx.OperatorSetIdProto = None, ) -> Function: opset = OnnxImporter.get_opset(onnx_function) or model_opset - import_domains = ( - OnnxImporter.get_import_domains(onnx_function) or model_import_domains - ) + import_domains = OnnxImporter.get_import_domains(onnx_function) or model_import_domains subgraph_tensor_map = OrderedDict() # Tensors in this function def make_tensor(name: str) -> Tensor: @@ -384,9 +355,7 @@ def make_tensor(name: str) -> Tensor: function_inputs = [make_tensor(inp) for inp in onnx_function.input] function_outputs = [make_tensor(out) for out in onnx_function.output] nodes = [ - OnnxImporter.import_node( - onnx_node, dict(), subgraph_tensor_map, opset, import_domains - ) + OnnxImporter.import_node(onnx_node, dict(), subgraph_tensor_map, opset, import_domains) for onnx_node in onnx_function.node ] @@ -438,18 +407,14 @@ def import_graph( functions (List[Function]): The list of custom functions which are available to use in the model. """ functions = misc.default_value(functions, []) - tensor_map = copy.copy( - misc.default_value(tensor_map, OrderedDict()) - ) # Outer graph tensors, read-only + tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict())) # Outer graph tensors, read-only subgraph_tensor_map = OrderedDict() # Tensors in this subgraph # Retrieves a Tensor from subgraph_tensor_map or the outer graph (tensor_map) if present, otherwise imports the tensor # If overwrite=True, this function will overwrite previously imported tensors # if the new tensor has more information available. 
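The `get_tensor` helper defined below implements the overwrite rule from the comment above: metadata already attached to a tensor wins, and re-importing only fills in fields that were missing. A toy version of that merge policy (`Var` and `merge` are illustrative stand-ins, not graphsurgeon types):

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class Var:
    name: str
    dtype: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None

def merge(existing: Var, imported: Var) -> Var:
    # Keep what is already known; only fill in fields the
    # existing tensor is missing, as in get_tensor(overwrite=True).
    existing.dtype = existing.dtype or imported.dtype
    existing.shape = existing.shape or imported.shape
    return existing

v = merge(Var("x", dtype="float32"), Var("x", shape=(1, 3)))
print(v)  # Var(name='x', dtype='float32', shape=(1, 3))
```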
def get_tensor( - onnx_tensor: Union[ - onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto - ], + onnx_tensor: Union[onnx.ValueInfoProto, onnx.TensorProto, onnx.SparseTensorProto], overwrite=False, check_outer_graph=True, ) -> Tensor: @@ -462,12 +427,8 @@ def get_tensor( if overwrite: tensor = OnnxImporter.import_tensor(onnx_tensor) if isinstance(subgraph_tensor_map[name], Variable): - subgraph_tensor_map[name].dtype = ( - subgraph_tensor_map[name].dtype or tensor.dtype - ) - subgraph_tensor_map[name].shape = ( - subgraph_tensor_map[name].shape or tensor.shape - ) + subgraph_tensor_map[name].dtype = subgraph_tensor_map[name].dtype or tensor.dtype + subgraph_tensor_map[name].shape = subgraph_tensor_map[name].shape or tensor.shape return subgraph_tensor_map[name] if check_outer_graph and name in tensor_map: @@ -512,9 +473,7 @@ def get_tensor( G_LOGGER.verbose("Importing nodes") nodes = [] # List[Node] for onnx_node in onnx_graph.node: - node = OnnxImporter.import_node( - onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains - ) + node = OnnxImporter.import_node(onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains) nodes.append(node) return Graph( diff --git a/onnxslim/onnx_graphsurgeon/ir/function.py b/onnxslim/onnx_graphsurgeon/ir/function.py index aa2bd38..a6feae5 100644 --- a/onnxslim/onnx_graphsurgeon/ir/function.py +++ b/onnxslim/onnx_graphsurgeon/ir/function.py @@ -27,15 +27,14 @@ class Function(Graph): """ - Represents a local function, which is a default implementation of a Custom Op. - This default implementation is represented as a Graph of other Ops. + Represents a local function, which is a default implementation of a Custom Op. This default implementation is + represented as a Graph of other Ops. Functions are used in a model by creating a Node with the same name and domain as the function. This can be done - using the __call__() method of a Function, which creates this new node and appends it to a Graph. - A Function is not a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate - from the main Graph. + using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not + a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph. - Functions can be composed of other functions, but cyclical or recursive defintions are not allowed in ONNX. + Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX. """ DEFAULT_DOMAIN = "onnx_graphsurgeon" @@ -90,9 +89,7 @@ def __init__( @property def unique_id(self): - """ - Returns a tuple which uniquely identifies this function. - """ + """Returns a tuple which uniquely identifies this function.""" return (self.domain, self.name) def cleanup( @@ -102,9 +99,8 @@ def cleanup( remove_unused_graph_inputs=False, recurse_functions=False, ): - """ - See Graph.cleanup() - The only difference is that 'recurse_functions' defaults to False, so that only this Function is cleaned up. + """See Graph.cleanup() The only difference is that 'recurse_functions' defaults to False, so that only this + Function is cleaned up. """ if recurse_functions: G_LOGGER.warning( @@ -118,9 +114,8 @@ def cleanup( ) def fold_constants(self, recurse_functions=False, **kwargs): - """ - See Graph.fold_constants() - The only difference is that 'recurse_functions' defaults to False, so that only this Function's constants are folded. 
+ """See Graph.fold_constants() The only difference is that 'recurse_functions' defaults to False, so that only + this Function's constants are folded. """ if recurse_functions: G_LOGGER.warning( @@ -134,10 +129,8 @@ def toposort( recurse_functions=False, mode="nodes", ): - """ - See Graph.toposort() - The only difference is that 'recurse_functions' defaults to False and mode defaults to "nodes", - so that by default only this function's nodes will be sorted. + """See Graph.toposort() The only difference is that 'recurse_functions' defaults to False and mode defaults to + "nodes", so that by default only this function's nodes will be sorted. """ if recurse_functions: G_LOGGER.warning( @@ -149,12 +142,10 @@ def toposort( mode=mode, ) - def __call__( - self, graph, inputs=None, outputs=None, *args, **kwargs - ) -> List[Tensor]: + def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> List[Tensor]: """ - Creates a Node which is an instance of this function. - The created node can be used in a Graph or another Function. + Creates a Node which is an instance of this function. The created node can be used in a Graph or another + Function. The provided inputs are processed the same way as in Graph.layer(). If outputs are not provided, they are created based on the Function's outputs. @@ -171,12 +162,8 @@ def __call__( List[Tensor]: The output tensors of the node. """ if inputs is not None and len(inputs) != len(self.inputs): - msg_template = ( - "Function {} expects {} inputs, but was called with {} inputs." - ) - G_LOGGER.warning( - msg_template.format(self.name, len(self.inputs), len(inputs)) - ) + msg_template = "Function {} expects {} inputs, but was called with {} inputs." + G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs))) new_output_indices = [] if outputs is None: @@ -184,16 +171,10 @@ def __call__( outputs = [out.name for out in self.outputs] new_output_indices = list(range(len(outputs))) elif len(outputs) != len(self.outputs): - msg_template = ( - "Function {} expects {} outputs, but was called with {} outputs." - ) - G_LOGGER.warning( - msg_template.format(self.name, len(self.outputs), len(outputs)) - ) + msg_template = "Function {} expects {} outputs, but was called with {} outputs." + G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs))) else: - new_output_indices = [ - i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor) - ] + new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)] attrs = kwargs.get("attrs", None) if attrs is not None: @@ -213,7 +194,7 @@ def __call__( outputs=outputs, ) - # For newly created output tensors, set their shape and dtype to match the Function defintion. + # For newly created output tensors, set their shape and dtype to match the Function definition. 
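Per the `__call__` docstring above, calling a `Function` stamps out a `Node` with the function's name and domain and returns its output tensors. A hypothetical sketch (`celu_fn`, `graph`, and `x` are placeholders, not names from this repository):

```python
# Hypothetical usage: `celu_fn` is a one-input/one-output Function already
# present in `graph.functions`, and `x` is a Tensor in `graph`.
[y] = celu_fn(graph, inputs=[x])  # appends a Node named/domained like the Function
# Since no outputs were supplied, `y` was created fresh; its shape and dtype
# are copied from the Function's declared output (the loop just below).
```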
for i in new_output_indices: outputs[i].dtype = self.outputs[i].dtype outputs[i].shape = self.outputs[i].shape @@ -268,9 +249,7 @@ def get_tensor(name): def __eq__(self, other: "Function"): def sequences_equal(seq1, seq2): - return len(seq1) == len(seq2) and all( - [elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)] - ) + return len(seq1) == len(seq2) and all([elem1 == elem2 for elem1, elem2 in zip(seq1, seq2)]) return ( self.unique_id == other.unique_id diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index c31df88..5ee32ee 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -17,7 +17,7 @@ import copy import numbers -from collections import defaultdict, OrderedDict +from collections import OrderedDict, defaultdict from typing import List, Sequence import numpy as np @@ -44,9 +44,7 @@ def __exit__(self, exc_type, exc_value, traceback): class Graph(object): - """ - Represents a graph containing nodes and tensors. - """ + """Represents a graph containing nodes and tensors.""" DEFAULT_OPSET = 11 OPSET_FUNC_MAP = defaultdict(dict) # Ops registered for specific opsets. @@ -55,8 +53,8 @@ class Graph(object): @staticmethod def register(opsets=None): """ - Registers a function with the Graph class for the specified group of opsets. - After registering the function, it can be accessed like a normal member function. + Registers a function with the Graph class for the specified group of opsets. After registering the function, it + can be accessed like a normal member function. For example: :: @@ -150,14 +148,9 @@ def __getattr__(self, name): method_descs = [] # Opset specific ops always take priority over global ops. - if ( - self.opset in Graph.OPSET_FUNC_MAP - and name in Graph.OPSET_FUNC_MAP[self.opset] - ): + if self.opset in Graph.OPSET_FUNC_MAP and name in Graph.OPSET_FUNC_MAP[self.opset]: methods.append(Graph.OPSET_FUNC_MAP[self.opset][name]) - method_descs.append( - f'GraphSurgeon-registered function "{name}" with opset {self.opset}' - ) + method_descs.append(f'GraphSurgeon-registered function "{name}" with opset {self.opset}') # Registered ops take priority over Local Functions. if name in Graph.GLOBAL_FUNC_MAP: @@ -167,27 +160,19 @@ def __getattr__(self, name): for func in self.functions: if func.name == name: methods.append(func.__call__) - method_descs.append( - f'Local Function "{func.name}" with domain "{func.domain}"' - ) + method_descs.append(f'Local Function "{func.name}" with domain "{func.domain}"') if methods: if len(methods) > 1: msg_template = "Method name {} is overloaded with the following candidates: {}. " msg_template += "Choosing candidate {}" G_LOGGER.warning( - message=msg_template.format( - name, method_descs, method_descs[0] - ), + message=msg_template.format(name, method_descs, method_descs[0]), mode=LogMode.ONCE, ) return lambda *args, **kwargs: methods[0](self, *args, **kwargs) - found_in_other_opsets = { - opset - for opset, opset_map in Graph.OPSET_FUNC_MAP.items() - if name in opset_map - } + found_in_other_opsets = {opset for opset, opset_map in Graph.OPSET_FUNC_MAP.items() if name in opset_map} G_LOGGER.error( f"Function: '{name}' was not registered for opset {self.opset}. 
" @@ -229,9 +214,7 @@ def __eq__(self, other: "Graph"): if not outputs_match: return False - opset_matches = ( - self.opset == other.opset and self.import_domains == other.import_domains - ) + opset_matches = self.opset == other.opset and self.import_domains == other.import_domains if not opset_matches: return False @@ -259,20 +242,14 @@ def _get_node_id(self, node): except AttributeError: G_LOGGER.critical( "Encountered a node not in the graph:\n{:}.\n\n" - "To fix this, please append the node to this graph's `nodes` attribute.".format( - node - ) + "To fix this, please append the node to this graph's `nodes` attribute.".format(node) ) # A tensor is local if it is produced in this graph, or is explicitly a graph input. def _local_tensors(self): - local_tensors = { - t.name: t for node in self.nodes for t in node.outputs if not t.is_empty() - } + local_tensors = {t.name: t for node in self.nodes for t in node.outputs if not t.is_empty()} local_tensors.update({t.name: t for t in self.inputs}) - local_tensors.update( - {t.name: t for t in self.tensors().values() if isinstance(t, Constant)} - ) + local_tensors.update({t.name: t for t in self.tensors().values() if isinstance(t, Constant)}) return local_tensors # Returns tensors used by this graph which are not present in the graph. @@ -285,17 +262,13 @@ def is_foreign_tensor(tensor): return tensor.name not in local_tensors for node in self.nodes: - foreign_tensors.update( - {t.name: t for t in node.inputs if is_foreign_tensor(t)} - ) + foreign_tensors.update({t.name: t for t in node.inputs if is_foreign_tensor(t)}) for subgraph in node.subgraphs(): subgraph_foreign_tensors = subgraph._foreign_tensors() # Some of the foreign tensors from a subgraph may come from this graph. subgraph_foreign_tensors = { - t.name: t - for t in subgraph_foreign_tensors.values() - if is_foreign_tensor(t) + t.name: t for t in subgraph_foreign_tensors.values() if is_foreign_tensor(t) } foreign_tensors.update(subgraph_foreign_tensors) @@ -366,8 +339,8 @@ def absorb_function_list(func_list): def subgraphs(self, recursive=False): """ - Convenience function to iterate over all subgraphs which are contained in this graph. - Subgraphs are found in the attributes of ONNX control flow nodes such as 'If' and 'Loop'. + Convenience function to iterate over all subgraphs which are contained in this graph. Subgraphs are found in the + attributes of ONNX control flow nodes such as 'If' and 'Loop'. Args: recursive (bool): Whether to recursively search this graph's subgraphs for more subgraphs. Defaults to False. @@ -387,8 +360,8 @@ def cleanup( recurse_functions=True, ): """ - Removes unused nodes and tensors from the graph. - A node or tensor is considered unused if it does not contribute to any of the graph outputs. + Removes unused nodes and tensors from the graph. A node or tensor is considered unused if it does not contribute + to any of the graph outputs. Additionally, any producer nodes of graph input tensors, as well as consumer nodes of graph output tensors that are not in the graph, are removed from the graph. 
@@ -466,9 +439,7 @@ def cleanup_subgraphs(): def is_hanging_tensor(tensor): return ( - not tensor.is_empty() - and len(tensor.outputs) == 0 - and tensor.name not in graph_output_names + not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names ) to_remove = [out for out in node.outputs if is_hanging_tensor(out)] @@ -523,9 +494,7 @@ def toposort( if sort_nodes and recurse_subgraphs: for subgraph in self.subgraphs(): - subgraph.toposort( - recurse_subgraphs=True, recurse_functions=False, mode="nodes" - ) + subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes") G_LOGGER.debug("Topologically sorting {:}".format(self.name)) @@ -558,9 +527,7 @@ def get_hierarchy_level(node_or_func, visited=None): if isinstance(node_or_func, Function): G_LOGGER.critical("Cycle detected in function definitions!") - G_LOGGER.critical( - "Cycle detected in graph! Are there tensors with duplicate names in the graph?" - ) + G_LOGGER.critical("Cycle detected in graph! Are there tensors with duplicate names in the graph?") visited.add(get_id(node_or_func)) def get_inputs(node_or_func): @@ -606,25 +573,17 @@ def get_used_funcs(nodes): # The level of a node is the level of its highest input + 1. max_input_level = max( - [ - get_hierarchy_level(inp, visited=visited) - for inp in get_inputs(node_or_func) - ] - + [-1] + [get_hierarchy_level(inp, visited=visited) for inp in get_inputs(node_or_func)] + [-1] ) visited.remove(get_id(node_or_func)) - hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor( - node_or_func, level=max_input_level + 1 - ) + hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor(node_or_func, level=max_input_level + 1) return max_input_level + 1 if sort_nodes: with self.node_ids(): for node in self.nodes: - hierarchy_levels[get_id(node)] = HierarchyDescriptor( - node, level=get_hierarchy_level(node) - ) + hierarchy_levels[get_id(node)] = HierarchyDescriptor(node, level=get_hierarchy_level(node)) self.nodes = [hd.node_or_func for hd in sorted(hierarchy_levels.values())] if sort_functions: @@ -632,18 +591,15 @@ def get_used_funcs(nodes): func_id_to_func.update({func.unique_id: func for func in self.functions}) hierarchy_levels.clear() for func in self.functions: - hierarchy_levels[func.unique_id] = HierarchyDescriptor( - func, level=get_hierarchy_level(func) - ) - self.functions = [ - hd.node_or_func for hd in sorted(hierarchy_levels.values()) - ] + hierarchy_levels[func.unique_id] = HierarchyDescriptor(func, level=get_hierarchy_level(func)) + self.functions = [hd.node_or_func for hd in sorted(hierarchy_levels.values())] return self def tensors(self, check_duplicates=False): """ - Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted from this map. + Creates a tensor map of all the tensors used by this graph by walking over all nodes. Empty tensors are omitted + from this map. Tensors are guaranteed to be in order of the nodes in the graph. Hence, if the graph is topologically sorted, the tensor map will be too. 
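A short sketch of the `tensors()` map described above:

```python
# Sketch: name -> Tensor map over all non-empty tensors, in node order.
# With check_duplicates=True the call fails loudly on the duplicate-name
# condition reported in the hunk below.
tensor_map = graph.tensors(check_duplicates=True)
weight = tensor_map["conv1.weight"]  # hypothetical tensor name
```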
@@ -660,18 +616,18 @@ def tensors(self, check_duplicates=False): def add_to_tensor_map(tensor): if not tensor.is_empty(): - if tensor.name in tensor_map and not ( - tensor_map[tensor.name] is tensor - ): + if tensor.name in tensor_map and not (tensor_map[tensor.name] is tensor): msg = "Found distinct tensors that share the same name:\n[id: {:}] {:}\n[id: {:}] {:}\n".format( id(tensor_map[tensor.name]), tensor_map[tensor.name], id(tensor), tensor, ) - msg += "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format( - tensor_map[tensor.name].inputs, - tensor.inputs, + msg += ( + "Note: Producer node(s) of first tensor:\n{:}\nProducer node(s) of second tensor:\n{:}".format( + tensor_map[tensor.name].inputs, + tensor.inputs, + ) ) if check_duplicates: @@ -705,8 +661,8 @@ def fold_constants( recurse_functions=True, ): """ - Folds constants in-place in the graph. The graph's nodes and functions must be topologically - sorted prior to calling this function (see `toposort()`). + Folds constants in-place in the graph. The graph's nodes and functions must be topologically sorted prior to + calling this function (see `toposort()`). This function will not remove constants after folding them. In order to get rid of these hanging nodes, you can run the `cleanup()` function. @@ -730,8 +686,8 @@ def fold_constants( - None: Do not partition the graph. If inference fails, no constants are folded. - "basic": Partition the graph. If inference fails in one partition, other partitions will remain unaffected. - - "recursive": Parition the graph recursively. If inference fails in a partition, the partition - will be further paritioned. + - "recursive": Partition the graph recursively. If inference fails in a partition, the partition + will be further partitioned. Defaults to None. error_ok (bool): @@ -767,9 +723,7 @@ def fold_constants( export_onnx, ) - custom_should_exclude_node = misc.default_value( - should_exclude_node, lambda node: False - ) + custom_should_exclude_node = misc.default_value(should_exclude_node, lambda node: False) # Don't fold nodes with attribute values which are variable. 
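The partitioning modes documented above, in practice; a minimal sketch of the usual folding flow:

```python
# Sketch: fold, then cleanup() to remove the now-disconnected producer nodes.
# "recursive" partitioning keeps one failing partition from aborting the rest.
graph.fold_constants(partitioning="recursive", error_ok=True)
graph.cleanup()
```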
def should_exclude_node(node): @@ -780,11 +734,7 @@ def should_exclude_node(node): PARTITIONING_MODES = [None, "basic", "recursive"] if partitioning not in PARTITIONING_MODES: - G_LOGGER.critical( - "Argument for parameter 'partitioning' must be one of: {:}".format( - PARTITIONING_MODES - ) - ) + G_LOGGER.critical("Argument for parameter 'partitioning' must be one of: {:}".format(PARTITIONING_MODES)) ORT_PROVIDERS = ["CPUExecutionProvider"] G_LOGGER.debug("Folding constants in {:}".format(self.name)) @@ -805,9 +755,7 @@ def should_exclude_node(node): node = tensor.inputs[0] if node.op == "Constant": if len(node.attrs) != 1: - G_LOGGER.warning( - "Constant node must contain exactly one attribute" - ) + G_LOGGER.warning("Constant node must contain exactly one attribute") continue attr_name, attr_val = list(node.attrs.items())[0] allowed_attrs = { @@ -818,9 +766,7 @@ def should_exclude_node(node): "value_ints", } if attr_name not in allowed_attrs: - G_LOGGER.warning( - f"Unsupported attribute for Constant node: {attr_name}" - ) + G_LOGGER.warning(f"Unsupported attribute for Constant node: {attr_name}") continue if isinstance(attr_val, Node.AttributeRef): continue @@ -868,8 +814,7 @@ def run_cast_elision(node): inp_node for inp_tensor in node.inputs for inp_node in inp_tensor.inputs - if inp_node.op == "Cast" - and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT + if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT ] # No cast nodes found, return early @@ -877,9 +822,7 @@ def run_cast_elision(node): return # Ensure that all input cast nodes are casting from the same type - inp_dtypes = [ - dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts - ] + inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts] if len(set(inp_dtypes)) != 1: return @@ -891,8 +834,7 @@ def run_cast_elision(node): for out_tensor in node.outputs for out_node in out_tensor.outputs if out_node.op == "Cast" - and out_node.attrs["to"] - in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64] + and out_node.attrs["to"] in [onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64] ] # No cast node found on outputs, return early @@ -912,9 +854,7 @@ def run_cast_elision(node): # `cast_node.inputs[0].outputs[0] == cast_node`. for index, inp in enumerate(node.inputs): if isinstance(inp, Constant): - inp.values = inp.values.astype( - onnx.helper.tensor_dtype_to_np_dtype(final_type) - ) + inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type)) for cast in inp_casts: if cast.outputs[0] == inp: @@ -929,9 +869,7 @@ def run_cast_elision(node): if fold_shapes: # Perform shape tensor cast elision prior to most other folding - G_LOGGER.debug( - "Performing shape tensor cast elision in {:}".format(self.name) - ) + G_LOGGER.debug("Performing shape tensor cast elision in {:}".format(self.name)) try: with self.node_ids(): for node in self.nodes: @@ -939,11 +877,7 @@ def run_cast_elision(node): except Exception as err: if not error_ok: raise err - G_LOGGER.warning( - "'{:}' routine failed with: {:}".format( - "Shape tensor cast elision", err - ) - ) + G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err)) # Note that most of the remaining passes operate on a clone of the original graph. # Pass 3: Find all descendants of constant tensors @@ -967,9 +901,7 @@ def is_foldable(node): def all_tensors_const(tensors): # Ignore omitted optional inputs. 
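The `should_exclude_node` hook wrapped above lets callers veto folding per node; a sketch (the op names are placeholders for "nodes you want to keep symbolic"):

```python
# Sketch: keep selected ops unfolded even if all of their inputs are constant.
graph.fold_constants(
    should_exclude_node=lambda node: node.op in {"NonMaxSuppression", "Range"}
)
```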
- return all( - [t.name in graph_constants for t in tensors if not t.is_empty()] - ) + return all([t.name in graph_constants for t in tensors if not t.is_empty()]) if not all_tensors_const(node.inputs): return False @@ -977,13 +909,9 @@ def all_tensors_const(tensors): all_subgraph_foreign_tensors_const = True for subgraph in node.subgraphs(): foreign_tensors = subgraph._foreign_tensors().values() - all_subgraph_foreign_tensors_const &= all_tensors_const( - foreign_tensors - ) + all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors) - return all_subgraph_foreign_tensors_const and not should_exclude_node( - node - ) + return all_subgraph_foreign_tensors_const and not should_exclude_node(node) # Walks along the outputs of graph_constants to see if they can also be computed statically. # Since the graph is topologically sorted, this should find all constant nodes in the graph. @@ -992,19 +920,13 @@ def all_tensors_const(tensors): graph_constants.update({out.name: out for out in node.outputs}) return graph_constants - graph_constants = { - name: tensor - for name, tensor in clone_tensors.items() - if isinstance(tensor, Constant) - } + graph_constants = {name: tensor for name, tensor in clone_tensors.items() if isinstance(tensor, Constant)} graph_constants = update_foldable_outputs(graph_constants) # Pass 4: Shape Folding def get_producer(tensor, op): - """ - Get the producer of the specified tensor iff it matches op - """ + """Get the producer of the specified tensor iff it matches op.""" if len(tensor.inputs) != 1: return None @@ -1014,9 +936,7 @@ def get_producer(tensor, op): return node def get_input(node, index=0): - """ - Get the input tensor of a node iff the input tensor is not already marked a graph constant. - """ + """Get the input tensor of a node iff the input tensor is not already marked a graph constant.""" if node is None: return None @@ -1029,9 +949,7 @@ def get_input(node, index=0): return inp def get_scalar_value(tensor): - """ - Gets the scalar value of a constant tensor with a single item - """ + """Gets the scalar value of a constant tensor with a single item.""" if not tensor.shape: return tensor.values else: @@ -1130,21 +1048,13 @@ def fold_shape_slice(tensor): shape_of = shape_fold_func(tensor) if shape_of is not None: - G_LOGGER.ultra_verbose( - "Folding shape tensor: {:} to: {:}".format( - tensor.name, shape_of - ) - ) + G_LOGGER.ultra_verbose("Folding shape tensor: {:} to: {:}".format(tensor.name, shape_of)) graph_constants[tensor.name] = tensor.to_constant(shape_of) graph_constants[tensor.name].inputs.clear() except Exception as err: if not error_ok: raise err - G_LOGGER.warning( - "'{:}' routine failed with:\n{:}".format( - shape_fold_func.__name__, err - ) - ) + G_LOGGER.warning("'{:}' routine failed with:\n{:}".format(shape_fold_func.__name__, err)) else: graph_constants = update_foldable_outputs(graph_constants) @@ -1183,11 +1093,7 @@ def get_out_node_ids(): ) values = sess.run(names, {}) except Exception as err: - G_LOGGER.warning( - "Inference failed for subgraph: {:}. Note: Error was:\n{:}".format( - part.name, err - ) - ) + G_LOGGER.warning("Inference failed for subgraph: {:}. Note: Error was:\n{:}".format(part.name, err)) if partitioning == "recursive": G_LOGGER.verbose("Attempting to recursively partition subgraph") # Partition failed, peel off last node. 
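The shape-folding helpers above (e.g. `fold_shape_slice`) turn statically-known shapes into constants without running inference. A sketch, assuming the vendored IR is importable as laid out in this patch:

```python
# Sketch only; import paths are assumptions based on this patch's file layout.
import numpy as np
from onnxslim.onnx_graphsurgeon.ir.graph import Graph
from onnxslim.onnx_graphsurgeon.ir.tensor import Variable

x = Variable("x", dtype=np.float32, shape=(1, 3, 224, 224))
graph = Graph(inputs=[x])
graph.outputs = graph.layer(op="Shape", inputs=[x], outputs=["x_shape"])
graph.fold_constants().cleanup()  # "x_shape" folds to Constant [1, 3, 224, 224]
```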
@@ -1197,17 +1103,13 @@ def get_out_node_ids(): out_node.outputs.clear() out_node.inputs.clear() else: - G_LOGGER.info( - "You may see better results if you set partitioning='recursive'" - ) + G_LOGGER.info("You may see better results if you set partitioning='recursive'") if not error_ok: raise err constant_values.update(partition_and_infer(part)) else: - constant_values.update( - {name: val for name, val in zip(names, values)} - ) + constant_values.update({name: val for name, val in zip(names, values)}) return constant_values @@ -1219,35 +1121,23 @@ def should_eval_foldable(tensor): non_const = not isinstance(tensor, Constant) is_graph_output = not tensor.outputs - has_non_foldable_outputs = any( - out.name not in graph_constants for out in tensor.outputs - ) + has_non_foldable_outputs = any(out.name not in graph_constants for out in tensor.outputs) exceeds_size_threshold = ( tensor.shape is not None and not misc.is_dynamic_shape(tensor.shape) and tensor.dtype is not None and size_threshold is not None - ) and ( - misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold - ) + ) and (misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold) - return ( - non_const - and (is_graph_output or has_non_foldable_outputs) - and not exceeds_size_threshold - ) + return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold - graph_clone.outputs = [ - t for t in graph_constants.values() if should_eval_foldable(t) - ] + graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)] G_LOGGER.debug("Folding tensors: {:}".format(graph_clone.outputs)) graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False) # Using ._values avoids a deep copy of the values. constant_values = { - name: tensor._values - for name, tensor in graph_constants.items() - if isinstance(tensor, Constant) + name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant) } if graph_clone.outputs: if partitioning: @@ -1258,15 +1148,11 @@ def should_eval_foldable(tensor): import onnxruntime as onnxrt sess = onnxrt.InferenceSession( - export_onnx( - graph_clone, do_type_check=False - ).SerializeToString(), + export_onnx(graph_clone, do_type_check=False).SerializeToString(), providers=ORT_PROVIDERS, ) values = sess.run(names, {}) - constant_values.update( - {name: val for name, val in zip(names, values)} - ) + constant_values.update({name: val for name, val in zip(names, values)}) except Exception as err: G_LOGGER.warning( "Inference failed. You may want to try enabling partitioning to see better results. 
" @@ -1307,15 +1193,12 @@ def should_eval_foldable(tensor): if large_tensors: large_tensors_mib = { - tensor_name: "{:} MiB".format(value // (1 << 20)) - for tensor_name, value in large_tensors.items() + tensor_name: "{:} MiB".format(value // (1 << 20)) for tensor_name, value in large_tensors.items() } G_LOGGER.warning( "It looks like this model contains foldable nodes that produce large outputs.\n" "In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n" - "Note: Large tensors and their corresponding sizes were: {:}".format( - large_tensors_mib - ), + "Note: Large tensors and their corresponding sizes were: {:}".format(large_tensors_mib), mode=LogMode.ONCE, ) @@ -1343,9 +1226,7 @@ def fold_subgraphs(): if node.op == "If" and isinstance(node.inputs[0], Constant): G_LOGGER.debug("Flattening conditional: {:}".format(node)) cond = get_scalar_value(node.inputs[0]) - subgraph = ( - node.attrs["then_branch"] if cond else node.attrs["else_branch"] - ) + subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"] # Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors for tensor in subgraph._local_tensors().values(): tensor.name += "_subg_{:}_{:}".format(index, subgraph.name) @@ -1440,27 +1321,15 @@ def process_io(io, existing_names): tensor = Variable(name=name) new_io.append(tensor) elif isinstance(elem, np.ndarray): - name = self._generate_name( - "onnx_graphsurgeon_constant", existing_names - ) + name = self._generate_name("onnx_graphsurgeon_constant", existing_names) new_io.append(Constant(name=name, values=elem)) - elif ( - isinstance(elem, list) - or isinstance(elem, tuple) - or isinstance(elem, numbers.Number) - ): + elif isinstance(elem, list) or isinstance(elem, tuple) or isinstance(elem, numbers.Number): if isinstance(elem, list) or isinstance(elem, tuple): - dtype = ( - np.float32 - if any([isinstance(x, float) for x in elem]) - else np.int64 - ) + dtype = np.float32 if any([isinstance(x, float) for x in elem]) else np.int64 else: dtype = np.float32 if isinstance(elem, float) else np.int64 arr = np.array(elem, dtype=dtype) - name = self._generate_name( - "onnx_graphsurgeon_lst_constant", existing_names - ) + name = self._generate_name("onnx_graphsurgeon_lst_constant", existing_names) new_io.append(Constant(name=name, values=arr)) else: G_LOGGER.critical( @@ -1477,9 +1346,7 @@ def process_io(io, existing_names): outputs = process_io(outputs, existing_names) if "name" not in kwargs: - kwargs["name"] = self._generate_name( - "onnx_graphsurgeon_node", {node.name for node in self.nodes} - ) + kwargs["name"] = self._generate_name("onnx_graphsurgeon_node", {node.name for node in self.nodes}) node = Node(*args, **kwargs, inputs=inputs, outputs=outputs) self.nodes.append(node) @@ -1511,9 +1378,7 @@ def copy(self, tensor_map: "OrderedDict[str, Tensor]" = None): # However, we should prioritize copies already made by the outer graph. local_tensor_copies.update(tensor_map) # And locally produced tensors should take precedence over everything else. 
- local_tensor_copies.update( - {n: t.copy() for n, t in self._local_tensors().items()} - ) + local_tensor_copies.update({n: t.copy() for n, t in self._local_tensors().items()}) def get_tensor(name): if not name: diff --git a/onnxslim/onnx_graphsurgeon/ir/node.py b/onnxslim/onnx_graphsurgeon/ir/node.py index 307d20a..0a22571 100644 --- a/onnxslim/onnx_graphsurgeon/ir/node.py +++ b/onnxslim/onnx_graphsurgeon/ir/node.py @@ -25,12 +25,11 @@ class Node(object): - @dataclass class AttributeRef: """ - An AttributeRef is an attribute value which references an attribute in the parent function. - A node's attribute can only be an AttributeRef if the node lives inside a Function. + An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute + can only be an AttributeRef if the node lives inside a Function. Args: name (str): The name of the referenced attribute in the parent Function. @@ -64,18 +63,14 @@ def __init__( self.op = op self.name = misc.default_value(name, "") self.attrs = misc.default_value(attrs, OrderedDict()) - self.inputs = misc.SynchronizedList( - self, field_name="outputs", initial=misc.default_value(inputs, []) - ) - self.outputs = misc.SynchronizedList( - self, field_name="inputs", initial=misc.default_value(outputs, []) - ) + self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, [])) + self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, [])) self.domain = domain def i(self, tensor_idx=0, producer_idx=0): """ - Convenience function to get a producer node of one of this node's input tensors. - Note that the parameters are swapped compared to the o() function; this is because tensors are likely to have only a single producer + Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are + swapped compared to the o() function; this is because tensors are likely to have only a single producer. For example: :: @@ -113,8 +108,8 @@ def o(self, consumer_idx=0, tensor_idx=0): def subgraphs(self, recursive=False): """ - Convenience function to iterate over all subgraphs which are contained in this node. - Node subgraphs are found in attributes of ONNX control flow nodes such as 'If' and 'Loop'. + Convenience function to iterate over all subgraphs which are contained in this node. Node subgraphs are found in + attributes of ONNX control flow nodes such as 'If' and 'Loop'. Args: recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False. @@ -208,15 +203,9 @@ def __repr__(self): return self.__str__() def __eq__(self, other): - """ - Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs. 
- """ + """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs.""" G_LOGGER.verbose("Comparing node: {:} with {:}".format(self.name, other.name)) - attrs_match = ( - self.name == other.name - and self.op == other.op - and self.attrs == other.attrs - ) + attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs if not attrs_match: return False diff --git a/onnxslim/onnx_graphsurgeon/ir/tensor.py b/onnxslim/onnx_graphsurgeon/ir/tensor.py index b90da39..8b6e6c4 100644 --- a/onnxslim/onnx_graphsurgeon/ir/tensor.py +++ b/onnxslim/onnx_graphsurgeon/ir/tensor.py @@ -24,14 +24,12 @@ class Tensor(object): - """Abstract base class for tensors in a graph""" + """Abstract base class for tensors in a graph.""" DYNAMIC = -1 def __init__(self): - """ - **This class is abstract and cannot be constructed directly.** - """ + """**This class is abstract and cannot be constructed directly.**""" raise NotImplementedError("Tensor is an abstract class") def __setattr__(self, name, value): @@ -69,7 +67,8 @@ def to_constant( export_dtype: Union[np.dtype, "onnx.TensorProto.DataType"] = None, ): """ - Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor will see the update. + Modifies this tensor in-place to convert it to a Constant. This means that all consumers/producers of the tensor + will see the update. Args: values (np.ndarray): The values in this tensor @@ -95,7 +94,8 @@ def to_variable( shape: Sequence[Union[int, str]] = [], ): """ - Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor will see the update. + Modifies this tensor in-place to convert it to a Variable. This means that all consumers/producers of the tensor + will see the update. Args: dtype (Union[numpy.dtype, onnx.TensorProto.DataType]): The data type of the tensor. @@ -115,8 +115,8 @@ def to_variable( def i(self, tensor_idx=0, producer_idx=0): """ - Convenience function to get an input tensor of one of this tensor's input nodes. - Note that the parameters are swapped compared to the o() function; this is because tensors are likely to have only a single producer + Convenience function to get an input tensor of one of this tensor's input nodes. Note that the parameters are + swapped compared to the o() function; this is because tensors are likely to have only a single producer. For example: :: @@ -153,9 +153,7 @@ def o(self, consumer_idx=0, tensor_idx=0): return self.outputs[consumer_idx].outputs[tensor_idx] def __str__(self): - return "{:} ({:}): (shape={:}, dtype={:})".format( - type(self).__name__, self.name, self.shape, self.dtype - ) + return "{:} ({:}): (shape={:}, dtype={:})".format(type(self).__name__, self.name, self.shape, self.dtype) def __repr__(self): # Hack to make logging output pretty. return self.__str__() @@ -216,44 +214,27 @@ def copy(self): return Variable(self.name, self.dtype, self.shape) def __eq__(self, other): - """ - Perform a check to see if two variables are equal. 
- """ + """Perform a check to see if two variables are equal.""" if not isinstance(other, Variable): return False name_match = self.name == other.name inputs_match = len(self.inputs) == len(other.inputs) and all( - [ - inp.name == other_inp.name - for inp, other_inp in zip(self.inputs, other.inputs) - ] + [inp.name == other_inp.name for inp, other_inp in zip(self.inputs, other.inputs)] ) outputs_match = len(self.outputs) == len(other.outputs) and all( - [ - out.name == other_out.name - for out, other_out in zip(self.outputs, other.outputs) - ] + [out.name == other_out.name for out, other_out in zip(self.outputs, other.outputs)] ) dtype_match = self.dtype == other.dtype shape_match = self.shape == other.shape type_match = self.type == other.type - return ( - name_match - and inputs_match - and outputs_match - and dtype_match - and shape_match - and type_match - ) + return name_match and inputs_match and outputs_match and dtype_match and shape_match and type_match class LazyValues(object): - """ - A special object that represents constant tensor values that should be lazily loaded. - """ + """A special object that represents constant tensor values that should be lazily loaded.""" def __init__(self, tensor): """ @@ -303,9 +284,7 @@ def __repr__(self): # Hack to make logging output pretty. return self.__str__() def __eq__(self, other): - """ - Perform a check to see if two variables are equal. - """ + """Perform a check to see if two variables are equal.""" if not isinstance(other, LazyValues): return False @@ -317,9 +296,7 @@ def __eq__(self, other): class SparseValues(LazyValues): - """ - A special object that represents constant tensor values that is sparse - """ + """A special object that represents constant tensor values that is sparse.""" def load(self): """ @@ -343,9 +320,7 @@ def load(self): ) if self.tensor.values.data_type == onnx.TensorProto.FLOAT16: - values_data = np.asarray( - self.tensor.values.int32_data, dtype=np.uint16 - ).view(np.float16) + values_data = np.asarray(self.tensor.values.int32_data, dtype=np.uint16).view(np.float16) else: field_name = onnx.helper.tensor_dtype_to_field(self.tensor.values.data_type) values = getattr(self.tensor.values, field_name) @@ -366,9 +341,7 @@ def load(self): for i in range(len(values_data)): values[tuple(indices_data[i])] = values_data[i] else: - G_LOGGER.critical( - f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}" - ) + G_LOGGER.critical(f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}") return values @@ -411,17 +384,13 @@ def __init__( G_LOGGER.critical( "Provided `values` argument is not a NumPy array, a LazyValues instance or a" "SparseValues instance. Please provide a NumPy array or LazyValues instance " - "to construct a Constant. Note: Provided `values` parameter was: {:}".format( - values - ) + "to construct a Constant. 
Note: Provided `values` parameter was: {:}".format(values) ) self._values = values self.data_location = data_location self._export_dtype = export_dtype - def to_variable( - self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = [] - ): + def to_variable(self, dtype: np.dtype = None, shape: Sequence[Union[int, str]] = []): var_dtype = self.export_dtype del self._export_dtype @@ -442,7 +411,7 @@ def copy(self): @property def values(self): - # Load values when they are first accesed + # Load values when they are first accessed if isinstance(self._values, LazyValues): self._values = self._values.load() return self._values @@ -476,15 +445,11 @@ def __repr__(self): # Hack to make logging output pretty. return ret def __eq__(self, other): - """ - Perform a check to see if two variables are equal. - """ + """Perform a check to see if two variables are equal.""" if not isinstance(other, Constant): return False - if isinstance(self._values, LazyValues) and isinstance( - other._values, LazyValues - ): + if isinstance(self._values, LazyValues) and isinstance(other._values, LazyValues): value_match = self._values == other._values else: value_match = np.array_equal(self.values, other.values) diff --git a/onnxslim/onnx_graphsurgeon/logger/logger.py b/onnxslim/onnx_graphsurgeon/logger/logger.py index 6e46bd0..326ef51 100644 --- a/onnxslim/onnx_graphsurgeon/logger/logger.py +++ b/onnxslim/onnx_graphsurgeon/logger/logger.py @@ -16,7 +16,6 @@ # import enum - import inspect import os import sys @@ -89,9 +88,7 @@ class Logger(object): CRITICAL: "red_1", } - def __init__( - self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False - ): + def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False): """ Logger. @@ -104,9 +101,7 @@ def __init__( """ self._severity = severity self.logging_indent = 0 - self.root_dir = os.path.abspath( - os.path.join(os.path.dirname(__file__), os.pardir, os.pardir) - ) + self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) self.once_logged = set() self.colors = colors self.letter = letter @@ -126,8 +121,8 @@ def severity(self, value): def register_callback(self, callback): """ - Registers a callback with the logger, which will be invoked when the logging severity is modified. - The callback is guaranteed to be called at least once in the register_callback function. + Registers a callback with the logger, which will be invoked when the logging severity is modified. The callback + is guaranteed to be called at least once in the register_callback function. Args: callback (Callable(Logger.Severity)): A callback that accepts the current logger severity. @@ -136,9 +131,7 @@ def register_callback(self, callback): self.logger_callbacks.append(callback) def indent(self, level=1): - """ - Returns a context manager that indents all strings logged by the specified amount. - """ + """Returns a context manager that indents all strings logged by the specified amount.""" return LoggerIndent(self, level + self.logging_indent) def suppress(self, severity=CRITICAL): @@ -165,9 +158,7 @@ def get_line_info(): # If the file is not located in trt_smeagol, use its basename instead. 
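The severity property and `register_callback()` above combine as follows; a sketch assuming the module exposes the usual `G_LOGGER` singleton, as in upstream onnx-graphsurgeon:

```python
# Sketch; assumes a module-level G_LOGGER instance as in upstream graphsurgeon.
from onnxslim.onnx_graphsurgeon.logger.logger import G_LOGGER

G_LOGGER.register_callback(lambda sev: print("severity is now", sev))  # fires once
G_LOGGER.severity = G_LOGGER.VERBOSE  # the setter re-invokes the callback
```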
if os.pardir in filename: filename = os.path.basename(filename) - return "[{:}:{:}] ".format( - filename, sys._getframe(stack_depth).f_lineno - ) + return "[{:}:{:}] ".format(filename, sys._getframe(stack_depth).f_lineno) prefix = "" if self.letter: @@ -180,9 +171,7 @@ def get_line_info(): def apply_indentation(message): message_lines = str(message).splitlines() - return "\n".join( - ["\t" * self.logging_indent + line for line in message_lines] - ) + return "\n".join(["\t" * self.logging_indent + line for line in message_lines]) def apply_color(message): if self.colors: diff --git a/onnxslim/onnx_graphsurgeon/util/exception.py b/onnxslim/onnx_graphsurgeon/util/exception.py index addf20a..deb84a6 100644 --- a/onnxslim/onnx_graphsurgeon/util/exception.py +++ b/onnxslim/onnx_graphsurgeon/util/exception.py @@ -17,6 +17,4 @@ class OnnxGraphSurgeonException(Exception): - """ - An exception raised by ONNX-GraphSurgeon. - """ + """An exception raised by ONNX-GraphSurgeon.""" diff --git a/onnxslim/onnx_graphsurgeon/util/misc.py b/onnxslim/onnx_graphsurgeon/util/misc.py index 97b1424..ff8bf2c 100644 --- a/onnxslim/onnx_graphsurgeon/util/misc.py +++ b/onnxslim/onnx_graphsurgeon/util/misc.py @@ -58,7 +58,9 @@ def default_value(value, default): def combine_dicts(dict0, dict1): """ - Combine two dictionaries. Values in the second will overwrite values in the first. + Combine two dictionaries. + + Values in the second will overwrite values in the first. """ combined = OrderedDict() combined.update(dict0) diff --git a/onnxslim/utils/tabulate.py b/onnxslim/utils/tabulate.py index 7e79985..0e00eb8 100644 --- a/onnxslim/utils/tabulate.py +++ b/onnxslim/utils/tabulate.py @@ -10,7 +10,8 @@ from collections.abc import Iterable, Sized from functools import partial, reduce from html import escape as htmlescape -from itertools import chain, zip_longest as izip_longest +from itertools import chain +from itertools import zip_longest as izip_longest try: import wcwidth # optional wide-character (CJK) support @@ -105,15 +106,15 @@ def _is_file(f): def _is_separating_line(row): row_type = type(row) is_sl = (row_type == list or row_type == str) and ( - (len(row) >= 1 and row[0] == SEPARATING_LINE) - or (len(row) >= 2 and row[1] == SEPARATING_LINE) + (len(row) >= 1 and row[0] == SEPARATING_LINE) or (len(row) >= 2 and row[1] == SEPARATING_LINE) ) return is_sl def _pipe_segment_with_colons(align, colwidth): - """Return a segment of a horizontal line with optional colons which - indicate column's alignment (as in `pipe` output format).""" + """Return a segment of a horizontal line with optional colons which indicate column's alignment (as in `pipe` output + format). + """ w = colwidth if align in ["right", "decimal"]: return ("-" * (w - 1)) + ":" @@ -126,8 +127,7 @@ def _pipe_segment_with_colons(align, colwidth): def _pipe_line_with_colons(colwidths, colaligns): - """Return a horizontal line with optional colons to indicate column's - alignment (as in `pipe` output format).""" + """Return a horizontal line with optional colons to indicate column's alignment (as in `pipe` output format).""" if not colaligns: # e.g. 
printing an empty data frame (github issue #15) colaligns = [""] * len(colwidths) segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)] @@ -143,9 +143,7 @@ def _mediawiki_row_with_attrs(separator, cell_values, colwidths, colaligns): } # hard-coded padding _around_ align attribute and value together # rather than padding parameter which affects only the value - values_with_attrs = [ - " " + alignment.get(a, "") + c + " " for c, a in zip(cell_values, colaligns) - ] + values_with_attrs = [" " + alignment.get(a, "") + c + " " for c, a in zip(cell_values, colaligns)] colsep = separator * 2 return (separator + colsep.join(values_with_attrs)).rstrip() @@ -171,8 +169,7 @@ def _html_row_with_attrs(celltag, unsafe, cell_values, colwidths, colaligns): } if unsafe: values_with_attrs = [ - "<{0}{1}>{2}".format(celltag, alignment.get(a, ""), c) - for c, a in zip(cell_values, colaligns) + "<{0}{1}>{2}".format(celltag, alignment.get(a, ""), c) for c, a in zip(cell_values, colaligns) ] else: values_with_attrs = [ @@ -193,8 +190,7 @@ def _moin_row_with_attrs(celltag, cell_values, colwidths, colaligns, header=""): "decimal": '', } values_with_attrs = [ - "{}{} {} ".format(celltag, alignment.get(a, ""), header + c + header) - for c, a in zip(cell_values, colaligns) + "{}{} {} ".format(celltag, alignment.get(a, ""), header + c + header) for c, a in zip(cell_values, colaligns) ] return "".join(values_with_attrs) + "||" @@ -204,28 +200,22 @@ def _latex_line_begin_tabular(colwidths, colaligns, booktabs=False, longtable=Fa tabular_columns_fmt = "".join([alignment.get(a, "l") for a in colaligns]) return "\n".join( [ - ("\\begin{tabular}{" if not longtable else "\\begin{longtable}{") - + tabular_columns_fmt - + "}", + ("\\begin{tabular}{" if not longtable else "\\begin{longtable}{") + tabular_columns_fmt + "}", "\\toprule" if booktabs else "\\hline", ] ) def _asciidoc_row(is_header, *args): - """handle header and data rows for asciidoc format""" + """Handle header and data rows for asciidoc format.""" def make_header_line(is_header, colwidths, colaligns): # generate the column specifiers alignment = {"left": "<", "right": ">", "center": "^", "decimal": ">"} # use the column widths generated by tabulate for the asciidoc column width specifiers - asciidoc_alignments = zip( - colwidths, [alignment[colalign] for colalign in colaligns] - ) - asciidoc_column_specifiers = [ - "{:d}{}".format(width, align) for width, align in asciidoc_alignments - ] + asciidoc_alignments = zip(colwidths, [alignment[colalign] for colalign in colaligns]) + asciidoc_column_specifiers = ["{:d}{}".format(width, align) for width, align in asciidoc_alignments] header_list = ['cols="' + (",".join(asciidoc_column_specifiers)) + '"'] # generate the list of options (currently only "header") @@ -772,18 +762,16 @@ def escape_empty(val): _ansi_codes_bytes = re.compile(_ansi_escape_pat.encode("utf8"), re.VERBOSE) _ansi_color_reset_code = "\033[0m" -_float_with_thousands_separators = re.compile( - r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$" -) +_float_with_thousands_separators = re.compile(r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$") def simple_separated_format(separator): - """Construct a simple TableFormat with columns separated by a separator. + """ + Construct a simple TableFormat with columns separated by a separator. 
>>> tsv = simple_separated_format("\\t") ; \ tabulate([["foo", 1], ["spam", 23]], tablefmt=tsv) == 'foo \\t 1\\nspam\\t23' True - """ return TableFormat( None, @@ -853,9 +841,7 @@ def _isnumber(string): """ if not _isconvertible(float, string): return False - elif isinstance(string, (str, bytes)) and ( - math.isinf(float(string)) or math.isnan(float(string)) - ): + elif isinstance(string, (str, bytes)) and (math.isinf(float(string)) or math.isnan(float(string))): return string.lower() in ["inf", "-inf", "nan"] return True @@ -873,9 +859,7 @@ def _isint(string, inttype=int): (hasattr(string, "is_integer") or hasattr(string, "__array__")) and str(type(string)).startswith(">> _isbool(1) False """ - return type(string) is bool or ( - isinstance(string, (bytes, str)) and string in ("True", "False") - ) + return type(string) is bool or (isinstance(string, (bytes, str)) and string in ("True", "False")) def _type(string, has_invisible=True, numparse=True): - """The least generic type (type(None), int, float, str, unicode). + """ + The least generic type (type(None), int, float, str, unicode). >>> _type(None) is type(None) True @@ -906,7 +889,6 @@ def _type(string, has_invisible=True, numparse=True): True >>> _type('\x1b[31m42\x1b[0m') is type(42) True - """ if has_invisible and isinstance(string, (str, bytes)): @@ -929,7 +911,8 @@ def _type(string, has_invisible=True, numparse=True): def _afterpoint(string): - """Symbols after a decimal point, -1 if the string lacks the decimal point. + """ + Symbols after a decimal point, -1 if the string lacks the decimal point. >>> _afterpoint("123.45") 2 @@ -941,7 +924,6 @@ def _afterpoint(string): 2 >>> _afterpoint("123,456.78") 2 - """ if _isnumber(string) or _isnumber_with_thousands_separator(string): if _isint(string): @@ -958,33 +940,33 @@ def _afterpoint(string): def _padleft(width, s): - """Flush right. + """ + Flush right. >>> _padleft(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430' True - """ fmt = "{0:>%ds}" % width return fmt.format(s) def _padright(width, s): - """Flush left. + """ + Flush left. >>> _padright(6, '\u044f\u0439\u0446\u0430') == '\u044f\u0439\u0446\u0430 ' True - """ fmt = "{0:<%ds}" % width return fmt.format(s) def _padboth(width, s): - """Center string. + """ + Center string. >>> _padboth(6, '\u044f\u0439\u0446\u0430') == ' \u044f\u0439\u0446\u0430 ' True - """ fmt = "{0:^%ds}" % width return fmt.format(s) @@ -995,7 +977,8 @@ def _padnone(ignore_width, s): def _strip_ansi(s): - r"""Remove ANSI escape sequences, both CSI (color codes, etc) and OSC hyperlinks. + r""" + Remove ANSI escape sequences, both CSI (color codes, etc) and OSC hyperlinks. CSI sequences are simply removed from the output, while OSC hyperlinks are replaced with the link text. Note: it may be desirable to show the URI instead but this is not @@ -1006,7 +989,6 @@ def _strip_ansi(s): >>> repr(_strip_ansi('\x1b[31mred\x1b[0m text')) "'red text'" - """ if isinstance(s, str): return _ansi_codes.sub(r"\4", s) @@ -1015,11 +997,11 @@ def _strip_ansi(s): def _visible_width(s): - """Visible width of a printed string. ANSI color codes are removed. + """ + Visible width of a printed string. ANSI color codes are removed. 
>>> _visible_width('\x1b[31mhello\x1b[0m'), _visible_width("world") (5, 5) - """ # optional wide-character support if wcwidth is not None and WIDE_CHARS_MODE: @@ -1125,26 +1107,18 @@ def _align_column( ): """[string] -> [padded_string]""" strings, padfn = _align_column_choose_padfn(strings, alignment, has_invisible) - width_fn = _align_column_choose_width_fn( - has_invisible, enable_widechars, is_multiline - ) + width_fn = _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline) s_widths = list(map(width_fn, strings)) maxwidth = max(max(_flat_list(s_widths)), minwidth) # TODO: refactor column alignment in single-line and multiline modes if is_multiline: if not enable_widechars and not has_invisible: - padded_strings = [ - "\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) - for ms in strings - ] + padded_strings = ["\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) for ms in strings] else: # enable wide-character width corrections s_lens = [[len(s) for s in re.split("[\r\n]", ms)] for ms in strings] - visible_widths = [ - [maxwidth - (w - l) for w, l in zip(mw, ml)] - for mw, ml in zip(s_widths, s_lens) - ] + visible_widths = [[maxwidth - (w - l) for w, l in zip(mw, ml)] for mw, ml in zip(s_widths, s_lens)] # wcswidth and _visible_width don't count invisible characters; # padfn doesn't need to apply another correction padded_strings = [ @@ -1186,7 +1160,8 @@ def _more_generic(type1, type2): def _column_type(strings, has_invisible=True, numparse=True): - """The least generic type all column values are convertible to. + """ + The least generic type all column values are convertible to. >>> _column_type([True, False]) is bool True @@ -1205,14 +1180,14 @@ def _column_type(strings, has_invisible=True, numparse=True): >>> import datetime as dt >>> _column_type([dt.datetime(1991,2,19), dt.time(17,35)]) is str True - """ types = [_type(s, has_invisible, numparse) for s in strings] return reduce(_more_generic, types, bool) def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): - """Format a value according to its type. + """ + Format a value according to its type. Unicode is supported: @@ -1221,7 +1196,6 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): good_result = '\\u0431\\u0443\\u043a\\u0432\\u0430 \\u0446\\u0438\\u0444\\u0440\\u0430\\n------- -------\\n\\u0430\\u0437 2\\n\\u0431\\u0443\\u043a\\u0438 4' ; \ tabulate(tbl, headers=hrow) == good_result True - """ # noqa if val is None: return missingval @@ -1247,15 +1221,11 @@ def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): return f"{val}" -def _align_header( - header, alignment, width, visible_width, is_multiline=False, width_fn=None -): +def _align_header(header, alignment, width, visible_width, is_multiline=False, width_fn=None): "Pad string header to width chars given known visible_width of the header." if is_multiline: header_lines = re.split(_multiline_codes, header) - padded_lines = [ - _align_header(h, alignment, width, width_fn(h)) for h in header_lines - ] + padded_lines = [_align_header(h, alignment, width, width_fn(h)) for h in header_lines] return "\n".join(padded_lines) # else: not multiline ninvisible = len(header) - visible_width @@ -1319,7 +1289,8 @@ def _bool(val): def _normalize_tabular_data(tabular_data, headers, showindex="default"): - """Transform a supported data type to a list of lists, and a list of headers, with headers padding. 
+ """ + Transform a supported data type to a list of lists, and a list of headers, with headers padding. Supported tabular data types: @@ -1348,7 +1319,6 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): If showindex="always", show row indices for all types of data. If showindex="never", don't show row indices for all types of data. If showindex is an iterable, show its values as row indices. - """ try: @@ -1364,16 +1334,11 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): if hasattr(tabular_data.values, "__call__"): # likely a conventional dict keys = tabular_data.keys() - rows = list( - izip_longest(*tabular_data.values()) - ) # columns have to be transposed + rows = list(izip_longest(*tabular_data.values())) # columns have to be transposed elif hasattr(tabular_data, "index"): # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) keys = list(tabular_data) - if ( - showindex in ["default", "always", True] - and tabular_data.index.name is not None - ): + if showindex in ["default", "always", True] and tabular_data.index.name is not None: if isinstance(tabular_data.index.name, list): keys[:0] = tabular_data.index.name else: @@ -1394,19 +1359,10 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): if headers == "keys" and not rows: # an empty table (issue #81) headers = [] - elif ( - headers == "keys" - and hasattr(tabular_data, "dtype") - and getattr(tabular_data.dtype, "names") - ): + elif headers == "keys" and hasattr(tabular_data, "dtype") and getattr(tabular_data.dtype, "names"): # numpy record array headers = tabular_data.dtype.names - elif ( - headers == "keys" - and len(rows) > 0 - and isinstance(rows[0], tuple) - and hasattr(rows[0], "_fields") - ): + elif headers == "keys" and len(rows) > 0 and isinstance(rows[0], tuple) and hasattr(rows[0], "_fields"): # namedtuple headers = list(map(str, rows[0]._fields)) elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"): @@ -1437,9 +1393,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): else: headers = [] elif headers: - raise ValueError( - "headers for a list of dicts is not a dict or a keyword" - ) + raise ValueError("headers for a list of dicts is not a dict or a keyword") rows = [[row.get(k) for k in keys] for row in rows] elif ( @@ -1452,11 +1406,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"): # print tabulate(cursor, headers='keys') headers = [column[0] for column in tabular_data.description] - elif ( - dataclasses is not None - and len(rows) > 0 - and dataclasses.is_dataclass(rows[0]) - ): + elif dataclasses is not None and len(rows) > 0 and dataclasses.is_dataclass(rows[0]): # Python 3.7+'s dataclass field_names = [field.name for field in dataclasses.fields(rows[0])] if headers == "keys": @@ -1528,14 +1478,8 @@ def _wrap_text_to_colwidths(list_of_lists, colwidths, numparses=True): # Cast based on our internal type handling # Any future custom formatting of types (such as datetimes) # may need to be more explicit than just `str` of the object - casted_cell = ( - str(cell) if _isnumber(cell) else _type(cell, False, numparse)(cell) - ) - wrapped = [ - "\n".join(wrapper.wrap(line)) - for line in casted_cell.splitlines() - if line.strip() != "" - ] + casted_cell = str(cell) if _isnumber(cell) else _type(cell, False, numparse)(cell) + wrapped = ["\n".join(wrapper.wrap(line)) for line in casted_cell.splitlines() if line.strip() != ""] 
new_row.append("\n".join(wrapped)) else: new_row.append(cell) @@ -1588,7 +1532,8 @@ def tabulate( rowalign=None, maxheadercolwidths=None, ): - """Format a fixed width table for pretty printing. + """ + Format a fixed width table for pretty printing. >>> print(tabulate([[1, 2.34], [-56, "8.999"], ["2", "10001"]])) --- --------- @@ -2079,15 +2024,12 @@ def tabulate( +------------+------------+-------------------------------+ Header column width can be specified in a similar way using `maxheadercolwidth` - """ if tabular_data is None: tabular_data = [] - list_of_lists, headers, headers_pad = _normalize_tabular_data( - tabular_data, headers, showindex=showindex - ) + list_of_lists, headers, headers_pad = _normalize_tabular_data(tabular_data, headers, showindex=showindex) list_of_lists, separating_lines = _remove_separating_lines(list_of_lists) if maxcolwidths is not None: @@ -2101,23 +2043,17 @@ def tabulate( maxcolwidths = _expand_iterable(maxcolwidths, num_cols, None) numparses = _expand_numparse(disable_numparse, num_cols) - list_of_lists = _wrap_text_to_colwidths( - list_of_lists, maxcolwidths, numparses=numparses - ) + list_of_lists = _wrap_text_to_colwidths(list_of_lists, maxcolwidths, numparses=numparses) if maxheadercolwidths is not None: num_cols = len(list_of_lists[0]) if isinstance(maxheadercolwidths, int): # Expand scalar for all columns - maxheadercolwidths = _expand_iterable( - maxheadercolwidths, num_cols, maxheadercolwidths - ) + maxheadercolwidths = _expand_iterable(maxheadercolwidths, num_cols, maxheadercolwidths) else: # Ignore col width for any 'trailing' columns maxheadercolwidths = _expand_iterable(maxheadercolwidths, num_cols, None) numparses = _expand_numparse(disable_numparse, num_cols) - headers = _wrap_text_to_colwidths( - [headers], maxheadercolwidths, numparses=numparses - )[0] + headers = _wrap_text_to_colwidths([headers], maxheadercolwidths, numparses=numparses)[0] # empty values in the first column of RST tables should be escaped (issue #82) # "" should be escaped as "\\ " or ".." 
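The normalization and wrapping steps above in one sketch: `headers="keys"` pulls column names straight from a mapping via `_normalize_tabular_data`, and `maxcolwidths` feeds `_wrap_text_to_colwidths` (a `None` entry leaves that column unwrapped). The import path follows this repository's layout:

```python
# Sketch of the vendored tabulate module.
from onnxslim.utils.tabulate import tabulate

stats = {"op": ["Conv", "Add"], "count": [12, 7]}
print(tabulate(stats, headers="keys", tablefmt="pipe"))
print(tabulate([["a_really_long_operator_name", "ok"]], maxcolwidths=[12, None]))
```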
@@ -2156,11 +2092,7 @@ def tabulate( has_invisible = _ansi_codes.search(plain_text) is not None enable_widechars = wcwidth is not None and WIDE_CHARS_MODE - if ( - not isinstance(tablefmt, TableFormat) - and tablefmt in multiline_formats - and _is_multiline(plain_text) - ): + if not isinstance(tablefmt, TableFormat) and tablefmt in multiline_formats and _is_multiline(plain_text): tablefmt = multiline_formats.get(tablefmt, tablefmt) is_multiline = True else: @@ -2172,17 +2104,13 @@ def tabulate( numparses = _expand_numparse(disable_numparse, len(cols)) coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] if isinstance(floatfmt, str): # old version - float_formats = len(cols) * [ - floatfmt - ] # just duplicate the string to use in each column + float_formats = len(cols) * [floatfmt] # just duplicate the string to use in each column else: # if floatfmt is list, tuple etc we have one per column float_formats = list(floatfmt) if len(float_formats) < len(cols): float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT]) if isinstance(intfmt, str): # old version - int_formats = len(cols) * [ - intfmt - ] # just duplicate the string to use in each column + int_formats = len(cols) * [intfmt] # just duplicate the string to use in each column else: # if intfmt is list, tuple etc we have one per column int_formats = list(intfmt) if len(int_formats) < len(cols): @@ -2195,9 +2123,7 @@ def tabulate( missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL]) cols = [ [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c] - for c, ct, fl_fmt, int_fmt, miss_v in zip( - cols, coltypes, float_formats, int_formats, missing_vals - ) + for c, ct, fl_fmt, int_fmt, miss_v in zip(cols, coltypes, float_formats, int_formats, missing_vals) ] # align columns @@ -2206,7 +2132,7 @@ def tabulate( aligns = [colglobalalign] * len(cols) else: # default aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] - # then specific alignements + # then specific alignments if colalign is not None: assert isinstance(colalign, Iterable) if isinstance(colalign, str): @@ -2219,9 +2145,7 @@ def tabulate( break elif align != "global": aligns[idx] = align - minwidths = ( - [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) - ) + minwidths = [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) cols = [ _align_column(c, a, minw, has_invisible, enable_widechars, is_multiline) for c, a, minw in zip(cols, aligns, minwidths) @@ -2236,7 +2160,7 @@ def tabulate( aligns_headers = [headersglobalalign] * len(t_cols) else: # default aligns_headers = aligns or [stralign] * len(headers) - # then specific header alignements + # then specific header alignments if headersalign is not None: assert isinstance(headersalign, Iterable) if isinstance(headersalign, str): @@ -2252,10 +2176,7 @@ def tabulate( aligns_headers[hidx] = aligns[hidx] elif align != "global": aligns_headers[hidx] = align - minwidths = [ - max(minw, max(width_fn(cl) for cl in c)) - for minw, c in zip(minwidths, t_cols) - ] + minwidths = [max(minw, max(width_fn(cl) for cl in c)) for minw, c in zip(minwidths, t_cols)] headers = [ _align_header(h, a, minw, width_fn(h), is_multiline, width_fn) for h, a, minw in zip(headers, aligns_headers, minwidths) @@ -2286,8 +2207,9 @@ def tabulate( def _expand_numparse(disable_numparse, column_count): """ - Return a list of bools of length `column_count` which indicates whether - number parsing should be used on 
each column.
+    Return a list of bools of length `column_count` which indicates whether number parsing should be used on each
+    column.
+
     If `disable_numparse` is a list of indices, each of those indices are False,
     and everything else is True.
     If `disable_numparse` is a bool, then the returned list is all the same.
@@ -2303,8 +2225,9 @@ def _expand_numparse(disable_numparse, column_count):
 
 def _expand_iterable(original, num_desired, default):
     """
-    Expands the `original` argument to return a return a list of
-    length `num_desired`. If `original` is shorter than `num_desired`, it will
+    Expands the `original` argument to return a list of length `num_desired`.
+
+    If `original` is shorter than `num_desired`, it will
     be padded with the value in `default`. If `original` is not a list to
     begin with (i.e. scalar value) a list of length `num_desired` completely
     populated with `default will be returned
@@ -2359,9 +2282,7 @@ def _align_cell_veritically(text_lines, num_lines, column_width, row_alignment):
     return text_lines + blank * delta_lines
 
 
-def _append_multiline_row(
-    lines, padded_multiline_cells, padded_widths, colaligns, rowfmt, pad, rowalign=None
-):
+def _append_multiline_row(lines, padded_multiline_cells, padded_widths, colaligns, rowfmt, pad, rowalign=None):
     colwidths = [w - 2 * pad for w in padded_widths]
     cells_lines = [c.splitlines() for c in padded_multiline_cells]
     nlines = max(map(len, cells_lines))  # number of lines in the row
@@ -2370,10 +2291,7 @@ def _append_multiline_row(
     #     (cl + [" " * w] * (nlines - len(cl))) for cl, w in zip(cells_lines, colwidths)
     # ]
-    cells_lines = [
-        _align_cell_veritically(cl, nlines, w, rowalign)
-        for cl, w in zip(cells_lines, colwidths)
-    ]
+    cells_lines = [_align_cell_veritically(cl, nlines, w, rowalign) for cl, w in zip(cells_lines, colwidths)]
     lines_cells = [[cl[i] for cl in cells_lines] for i in range(nlines)]
     for ln in lines_cells:
         padded_ln = _pad_row(ln, pad)
@@ -2399,21 +2317,18 @@ def _append_line(lines, colwidths, colaligns, linefmt):
 
 
 class JupyterHTMLStr(str):
-    """Wrap the string with a _repr_html_ method so that Jupyter
-    displays the HTML table"""
+    """Wrap the string with a _repr_html_ method so that Jupyter displays the HTML table."""
 
     def _repr_html_(self):
         return self
 
     @property
     def str(self):
-        """add a .str property so that the raw string is still accessible"""
+        """Add a .str property so that the raw string is still accessible."""
         return self
 
 
-def _format_table(
-    fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns
-):
+def _format_table(fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns):
     """Produce a plain-text representation of the table."""
     lines = []
     hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else []
@@ -2442,9 +2357,7 @@ def _format_table(
     if padded_rows and fmt.linebetweenrows and "linebetweenrows" not in hidden:
         # initial rows with a line below
         for row, ralign in zip(padded_rows[:-1], rowaligns):
-            append_row(
-                lines, row, padded_widths, colaligns, fmt.datarow, rowalign=ralign
-            )
+            append_row(lines, row, padded_widths, colaligns, fmt.datarow, rowalign=ralign)
             _append_line(lines, padded_widths, colaligns, fmt.linebetweenrows)
         # the last row without a line below
         append_row(
@@ -2457,11 +2370,7 @@ def _format_table(
         )
     else:
         separating_line = (
-            fmt.linebetweenrows
-            or fmt.linebelowheader
-            or fmt.linebelow
-            or fmt.lineabove
-            or Line("", "", "", "")
+            fmt.linebetweenrows or fmt.linebelowheader or fmt.linebelow or fmt.lineabove or Line("", "", "", "")
Line("", "", "", "") ) for row in padded_rows: # test to see if either the 1st column or the 2nd column (account for showindex) has @@ -2485,7 +2394,10 @@ def _format_table( class _CustomTextWrap(textwrap.TextWrapper): - """A custom implementation of CPython's textwrap.TextWrapper. This supports + """ + A custom implementation of CPython's textwrap.TextWrapper. + + This supports both wide characters (Korea, Japanese, Chinese) - including mixed string. For the most part, the `_handle_long_word` and `_wrap_chunks` functions were copy pasted out of the CPython baseline, and updated with our custom length @@ -2499,8 +2411,7 @@ def __init__(self, *args, **kwargs): @staticmethod def _len(item): - """Custom len that gets console column width for wide - and non-wide characters as well as ignores color codes""" + """Custom len that gets console column width for wide and non-wide characters as well as ignores color codes.""" stripped = _strip_ansi(item) if wcwidth: return wcwidth.wcswidth(stripped) @@ -2508,15 +2419,12 @@ def _len(item): return len(stripped) def _update_lines(self, lines, new_line): - """Adds a new line to the list of lines the text is being wrapped into - This function will also track any ANSI color codes in this string as well - as add any colors from previous lines order to preserve the same formatting + """Adds a new line to the list of lines the text is being wrapped into This function will also track any ANSI + color codes in this string as well as add any colors from previous lines order to preserve the same formatting as a single unwrapped string. """ code_matches = [x for x in _ansi_codes.finditer(new_line)] - color_codes = [ - code.string[code.span()[0] : code.span()[1]] for code in code_matches - ] + color_codes = [code.string[code.span()[0] : code.span()[1]] for code in code_matches] # Add color codes from earlier in the unwrapped line, and then track any new ones we add. new_line = "".join(self._active_codes) + new_line @@ -2527,7 +2435,7 @@ def _update_lines(self, lines, new_line): else: # A single reset code resets everything self._active_codes = [] - # Always ensure each line is color terminted if any colors are + # Always ensure each line is color terminated if any colors are # still active, otherwise colors will bleed into other cells on the console if len(self._active_codes) > 0: new_line = new_line + _ansi_color_reset_code @@ -2573,9 +2481,11 @@ def _handle_long_word(self, reversed_chunks, cur_line, cur_len, width): # devoted to the long word that we can't handle right now. def _wrap_chunks(self, chunks): - """_wrap_chunks(chunks : [string]) -> [string] - Wrap a sequence of text chunks and return a list of lines of - length 'self.width' or less. (If 'break_long_words' is false, + """ + _wrap_chunks(chunks : [string]) -> [string] Wrap a sequence of text chunks and return a list of lines of length + 'self.width' or less. + + (If 'break_long_words' is false, some lines may be longer than this.) 
         Chunks correspond roughly to words and the whitespace between them: each chunk is indivisible (modulo 'break_long_words'), but a line break can
@@ -2646,12 +2556,7 @@ def _wrap_chunks(self, chunks):
             if (
                 self.max_lines is None
                 or len(lines) + 1 < self.max_lines
-                or (
-                    not chunks
-                    or self.drop_whitespace
-                    and len(chunks) == 1
-                    and not chunks[0].strip()
-                )
+                or (not chunks or self.drop_whitespace and len(chunks) == 1 and not chunks[0].strip())
                 and cur_len <= width
             ):
                 # Convert current line back to a string and store it in
@@ -2659,10 +2564,7 @@ def _wrap_chunks(self, chunks):
                 self._update_lines(lines, indent + "".join(cur_line))
             else:
                 while cur_line:
-                    if (
-                        cur_line[-1].strip()
-                        and cur_len + self._len(self.placeholder) <= width
-                    ):
+                    if cur_line[-1].strip() and cur_len + self._len(self.placeholder) <= width:
                         cur_line.append(self.placeholder)
                         self._update_lines(lines, indent + "".join(cur_line))
                         break
@@ -2671,10 +2573,7 @@ def _wrap_chunks(self, chunks):
                 else:
                     if lines:
                         prev_line = lines[-1].rstrip()
-                        if (
-                            self._len(prev_line) + self._len(self.placeholder)
-                            <= self.width
-                        ):
+                        if self._len(prev_line) + self._len(self.placeholder) <= self.width:
                             lines[-1] = prev_line + self.placeholder
                             break
                     self._update_lines(lines, indent + self.placeholder.lstrip())
diff --git a/onnxslim/utils/utils.py b/onnxslim/utils/utils.py
index 4121763..087208c 100644
--- a/onnxslim/utils/utils.py
+++ b/onnxslim/utils/utils.py
@@ -1,24 +1,18 @@
+import logging
 from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
-
 import onnx
 
 from ..utils.font import GREEN, WHITE
 from ..utils.tabulate import SEPARATING_LINE, tabulate
 
-
-import logging
-
 # Configure logging
 logging.basicConfig(
     level=logging.ERROR,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S',
-    handlers=[
-        logging.FileHandler("app.log"),
-        logging.StreamHandler()
-    ]
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
 )
 
 # Create a logger
@@ -96,10 +90,7 @@ def gen_onnxruntime_input_data(
         if "data" in info:
             input_data_dict[name] = info["data"]
         else:
-            shapes = [
-                shape if (shape != -1 and not isinstance(shape, str)) else 1
-                for shape in info["shape"]
-            ]
+            shapes = [shape if (shape != -1 and not isinstance(shape, str)) else 1 for shape in info["shape"]]
             shapes = shapes if shapes else [1]
             dtype = info["dtype"]
 
@@ -112,14 +103,10 @@ def gen_onnxruntime_input_data(
     return input_data_dict
 
 
-def onnxruntime_inference(
-    model: onnx.ModelProto, input_data: dict
-) -> Dict[str, np.array]:
+def onnxruntime_inference(model: onnx.ModelProto, input_data: dict) -> Dict[str, np.ndarray]:
     import onnxruntime as rt
 
-    sess = rt.InferenceSession(
-        model.SerializeToString(), providers=["CPUExecutionProvider"]
-    )
+    sess = rt.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
     onnx_output = sess.run(None, input_data)
 
     output_names = [output.name for output in sess.get_outputs()]
@@ -128,12 +115,8 @@ def onnxruntime_inference(
     return onnx_output
 
 
-def print_model_info_as_table(
-    model_name: str, model_info_list: List[Dict], elapsed_time: float = 0.0
-):
-    assert (
-        len(model_info_list) > 0
-    ), "model_info_list must contain more than one model info"
+def print_model_info_as_table(model_name: str, model_info_list: List[Dict], elapsed_time: float = 0.0):
+    assert len(model_info_list) > 0, "model_info_list must contain at least one model info"
 
     final_op_info = []
     if len(model_info_list) == 1:
@@ -142,15 +125,11 @@ def print_model_info_as_table(
         final_op_info.append(["Op Set ", model_info_list[0]["op_set"]])
     else:
         final_op_info.append(
-            ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]]
-            + [""] * (len(model_info_list) - 2)
+            ["Model Name", model_name, "Op Set: " + model_info_list[0]["op_set"]] + [""] * (len(model_info_list) - 2)
         )
 
         final_op_info.append([SEPARATING_LINE])
 
-        final_op_info.append(
-            ["Model Info", "Original Model"]
-            + ["Slimmed Model"] * (len(model_info_list) - 1)
-        )
+        final_op_info.append(["Model Info", "Original Model"] + ["Slimmed Model"] * (len(model_info_list) - 1))
 
         final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
 
     all_inputs = list(model_info_list[0]["op_input_info"].keys())
@@ -164,11 +143,7 @@ def print_model_info_as_table(
             input_info_list.append(inputs_shape)
         final_op_info.append(input_info_list)
 
-    all_outputs = set(
-        op_type
-        for model_info in model_info_list
-        for op_type in model_info.get("op_output_info", {})
-    )
+    all_outputs = set(op_type for model_info in model_info_list for op_type in model_info.get("op_output_info", {}))
 
     for outputs in all_outputs:
         output_info_list = [
@@ -181,11 +156,7 @@ def print_model_info_as_table(
 
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
 
-    all_ops = set(
-        op_type
-        for model_info in model_info_list
-        for op_type in model_info.get("op_type_counts", {})
-    )
+    all_ops = set(op_type for model_info in model_info_list for op_type in model_info.get("op_type_counts", {}))
     sorted_ops = list(all_ops)
     sorted_ops.sort()
     for op in sorted_ops:
@@ -200,10 +171,7 @@ def print_model_info_as_table(
         final_op_info.append(op_info_list)
 
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
-    final_op_info.append(
-        ["Model Size"]
-        + [format_bytes(model_info["model_size"]) for model_info in model_info_list]
-    )
+    final_op_info.append(["Model Size"] + [format_bytes(model_info["model_size"]) for model_info in model_info_list])
     final_op_info.append([SEPARATING_LINE] * (len(model_info_list) + 1))
     final_op_info.append(["Elapsed Time"] + [f"{elapsed_time:.2f} s"])
     lines = tabulate(
@@ -214,11 +182,7 @@ def print_model_info_as_table(
     ).split("\n")
 
     time_row = lines[-2].split("|")
-    time_row[-3] = (
-        time_row[-2][: len(time_row[-2]) // 2 + 1]
-        + time_row[-3]
-        + time_row[-2][len(time_row[-2]) // 2 :]
-    )
+    time_row[-3] = time_row[-2][: len(time_row[-2]) // 2 + 1] + time_row[-3] + time_row[-2][len(time_row[-2]) // 2 :]
     time_row.pop(-2)
     lines[-2] = "|".join(time_row)
     output = "\n".join([line if line != "| \x01 |" else lines[0] for line in lines])
diff --git a/tests/test_onnx_nets.py b/tests/test_onnx_nets.py
index 63d9f2a..d76a8e0 100644
--- a/tests/test_onnx_nets.py
+++ b/tests/test_onnx_nets.py
@@ -7,7 +7,6 @@
 import torch
 import torchvision.models as models
 
-
 FUSE = True
 PRETRAINED = False
 
diff --git a/tests/test_onnxslim.py b/tests/test_onnxslim.py
index c5f1f77..4317705 100644
--- a/tests/test_onnxslim.py
+++ b/tests/test_onnxslim.py
@@ -1,7 +1,6 @@
 import subprocess
 
 import pytest
-
 from utils import download_onnx_from_url
 
 
@@ -20,9 +19,7 @@
 )
 class TestOnnxModel:
     def test_onnx_model(self, request, name):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx"
-        )
+        filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx")
         command = f"onnxslim {filename} {name}_slim.onnx"
         result = subprocess.run(command, shell=True, capture_output=True, text=True)
         output = result.stderr.strip()
@@ -32,23 +29,18 @@ def test_onnx_model(self, request, name):
 
     def test_onnxslim_python_api(self, request, name):
         import onnx
+
         from onnxslim import slim
 
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx"
-        )
+        filename = download_onnx_from_url(f"http://120.224.26.32:15030/aifarm/onnx/{name}.onnx")
         model_slim = slim(filename)
         onnx.save(model_slim, f"{name}_slim.onnx")
 
 
 class TestFeat:
     def test_input_shape_modification(self, request):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx"
-        )
-        command = (
-            f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768"
-        )
+        filename = download_onnx_from_url("http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx")
+        command = f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768"
         result = subprocess.run(command, shell=True, capture_output=True, text=True)
         output = result.stderr.strip()
         # Assert the expected return code
@@ -56,9 +48,7 @@ def test_input_shape_modification(self, request):
         assert result.returncode == 0
 
     def test_fp162fp32_conversion(self, request):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx"
-        )
+        filename = download_onnx_from_url("http://120.224.26.32:15030/aifarm/onnx/UNetModel-fp16.onnx")
         command = f"onnxslim {filename} UNetModel-fp16_slim.onnx --input_shapes cc:1,1,768 --dtype fp32"
         result = subprocess.run(command, shell=True, capture_output=True, text=True)
         output = result.stderr.strip()
@@ -67,9 +57,7 @@ def test_fp162fp32_conversion(self, request):
         assert result.returncode == 0
 
     def test_output_modification(self, request):
-        filename = download_onnx_from_url(
-            f"http://120.224.26.32:15030/aifarm/onnx/yolov5m.onnx"
-        )
+        filename = download_onnx_from_url("http://120.224.26.32:15030/aifarm/onnx/yolov5m.onnx")
         command = f"onnxslim {filename} yolov5m_slim.onnx --outputs 591 739 443"
         result = subprocess.run(command, shell=True, capture_output=True, text=True)
         output = result.stderr.strip()
diff --git a/tests/utils.py b/tests/utils.py
index 5902f3f..5333712 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,7 +8,6 @@
 import tempfile
 import warnings
 import zipfile
-
 from urllib.error import HTTPError
 from urllib.parse import urlparse  # noqa: F401
 from urllib.request import Request, urlopen
@@ -46,9 +45,7 @@ def update(self, n):
         if self.total is None:
             sys.stderr.write("\r{0:.1f} bytes".format(self.n))
         else:
-            sys.stderr.write(
-                "\r{0:.1f}%".format(100 * self.n / float(self.total))
-            )
+            sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total)))
             sys.stderr.flush()
 
     def close(self):
@@ -98,9 +95,7 @@ def _remove_if_exists(path):
 
 
 def _git_archive_link(repo_owner, repo_name, branch):
-    return "https://github.com/{}/{}/archive/{}.zip".format(
-        repo_owner, repo_name, branch
-    )
+    return "https://github.com/{}/{}/archive/{}.zip".format(repo_owner, repo_name, branch)
 
 
 def _load_attr_from_module(module, func_name):
@@ -236,9 +231,7 @@ def _check_dependencies(m):
     if dependencies is not None:
         missing_deps = [pkg for pkg in dependencies if not _check_module_exists(pkg)]
         if len(missing_deps):
-            raise RuntimeError(
-                "Missing dependencies: {}".format(", ".join(missing_deps))
-            )
+            raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
 
 
 def _load_entry_from_hubconf(m, model):
@@ -306,9 +299,7 @@ def list(github, force_reload=False, skip_validation=False):
 
     Example:
        >>> entrypoints = 
torch.hub.list('pytorch/vision', force_reload=True) """ - repo_dir = _get_cache_or_reload( - github, force_reload, verbose=True, skip_validation=skip_validation - ) + repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation) sys.path.insert(0, repo_dir) @@ -318,11 +309,7 @@ def list(github, force_reload=False, skip_validation=False): sys.path.remove(repo_dir) # We take functions starts with '_' as internal helper functions - entrypoints = [ - f - for f in dir(hub_module) - if callable(getattr(hub_module, f)) and not f.startswith("_") - ] + entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith("_")] return entrypoints @@ -346,9 +333,7 @@ def help(github, model, force_reload=False, skip_validation=False): Example: >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True)) """ - repo_dir = _get_cache_or_reload( - github, force_reload, verbose=True, skip_validation=skip_validation - ) + repo_dir = _get_cache_or_reload(github, force_reload, verbose=True, skip_validation=skip_validation) sys.path.insert(0, repo_dir) @@ -424,14 +409,10 @@ def load( source = source.lower() if source not in ("github", "local"): - raise ValueError( - f'Unknown source: "{source}". Allowed values: "github" | "local".' - ) + raise ValueError(f'Unknown source: "{source}". Allowed values: "github" | "local".') if source == "github": - repo_or_dir = _get_cache_or_reload( - repo_or_dir, force_reload, verbose, skip_validation - ) + repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose, skip_validation) model = _load_local(repo_or_dir, model, *args, **kwargs) return model @@ -470,7 +451,8 @@ def _load_local(hubconf_dir, model, *args, **kwargs): def download_url_to_file(url, dst, hash_prefix=None, progress=True): - r"""Download object at the given URL to a local path. + r""" + Download object at the given URL to a local path. 
Args: url (string): URL of the object to download @@ -482,7 +464,6 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): Example: >>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file') - """ file_size = None req = Request(url, headers={"User-Agent": "torch.hub"}) @@ -525,11 +506,7 @@ def download_url_to_file(url, dst, hash_prefix=None, progress=True): if hash_prefix is not None: digest = sha256.hexdigest() if digest[: len(hash_prefix)] != hash_prefix: - raise RuntimeError( - 'invalid hash value (expected "{}", got "{}")'.format( - hash_prefix, digest - ) - ) + raise RuntimeError('invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest)) shutil.move(f.name, dst) finally: f.close() @@ -575,9 +552,7 @@ def _legacy_zip_load(filename, model_dir, map_location): return torch.load(extracted_file, map_location=map_location) -def download_onnx_from_url( - url, model_dir=None, progress=True, check_hash=False, file_name=None -): +def download_onnx_from_url(url, model_dir=None, progress=True, check_hash=False, file_name=None): if model_dir is None: hub_dir = get_dir() model_dir = os.path.join(hub_dir, "onnx") From 30b34f432170b113542b80b1a7a9790b4fa419d5 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 30 May 2024 12:49:29 +0200 Subject: [PATCH 3/3] Update format.yml --- .github/workflows/format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index dc66e7b..9459732 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -7,7 +7,7 @@ name: Ultralytics Actions on: push: branches: [main] - pull_request: + pull_request_target: branches: [main] types: [opened, closed, synchronize]
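For reference, the trigger block this final hunk produces should read roughly as follows (a sketch assembled from the hunk and its context, not a verbatim copy of the file):

    on:
      push:
        branches: [main]
      pull_request_target:
        branches: [main]
        types: [opened, closed, synchronize]

Unlike plain pull_request, a pull_request_target run executes in the context of the base repository, so the auto-generated GITHUB_TOKEN keeps its base-repo permissions even for pull requests opened from forks, which is presumably what lets the formatting job push its fixes back to the PR branch.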