diff --git a/onnxslim/core/__init__.py b/onnxslim/core/__init__.py index 1a3f1e4..6e6e2b5 100644 --- a/onnxslim/core/__init__.py +++ b/onnxslim/core/__init__.py @@ -8,9 +8,8 @@ import onnxslim.third_party.onnx_graphsurgeon as gs from onnxslim.core.optimization import optimize_model -from onnxslim.core.utils import delete_node from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx -from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant +from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference from onnxslim.utils import save @@ -173,7 +172,7 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto: if node.op == "Cast": inp_dtype = [input.dtype for input in node.inputs][0] if inp_dtype in [np.float16, np.float32]: - delete_node(node) + node.replace_all_uses_with(node.inputs[0]) else: outp_dtype = [output.dtype for output in node.outputs][0] if outp_dtype == np.float16: diff --git a/onnxslim/core/optimization/__init__.py b/onnxslim/core/optimization/__init__.py index d413ffd..7161432 100644 --- a/onnxslim/core/optimization/__init__.py +++ b/onnxslim/core/optimization/__init__.py @@ -5,7 +5,6 @@ import onnx import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import get_node_feeds from onnxslim.core.pattern.registry import get_fusion_patterns from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph @@ -81,7 +80,7 @@ def get_previous_node_by_type(node, op_type, trajectory=None): """Recursively find and return the first preceding node of a specified type in the computation graph.""" if trajectory is None: trajectory = [] - node_feeds = get_node_feeds(node) + node_feeds = node.feeds for node_feed in node_feeds: trajectory.append(node_feed) if node_feed.op == op_type: diff --git a/onnxslim/core/optimization/dead_node_elimination.py b/onnxslim/core/optimization/dead_node_elimination.py index e686ac4..43aaf62 100644 --- a/onnxslim/core/optimization/dead_node_elimination.py +++ b/onnxslim/core/optimization/dead_node_elimination.py @@ -3,7 +3,6 @@ import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.utils import delete_node from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable @@ -20,28 +19,28 @@ def dead_node_elimination(graph, is_subgraph=False): for node in graph.nodes: if node.op in {"Identity", "Dropout"}: if not is_subgraph: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Pad": if len(node.inputs) > 1 and isinstance(node.inputs[1], Constant): pad_value = node.inputs[1].values.tolist() pad_value = pad_value if isinstance(pad_value, list) else [pad_value] if all(value == 0 for value in pad_value): - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Cast": inp_dtype = [dtype_to_onnx(input.dtype) for input in node.inputs][0] if inp_dtype == node.attrs["to"]: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Reshape": if (node.inputs[0].shape and len(node.inputs[0].shape) == 1) and ( node.outputs[0].shape and len(node.outputs[0].shape) == 1 ): - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") else: node_output_shape = node.outputs[0].shape @@ -61,7 +60,7 @@ def dead_node_elimination(graph, is_subgraph=False): idx, constant_variable = get_constant_variable(node, return_idx=True) if np.all(constant_variable.values == 1): var_idx = 0 if idx == 1 else 1 - delete_node(node, var_idx) + node.replace_all_uses_with(node.feeds[var_idx]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Add": if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or ( @@ -71,10 +70,10 @@ def dead_node_elimination(graph, is_subgraph=False): value = constant_variable.values var_idx = 0 if idx == 1 else 1 if value.ndim == 0 and value == 0: - delete_node(node, var_idx) + node.replace_all_uses_with(node.feeds[var_idx]) logger.debug(f"removing {node.op} op: {node.name}") elif np.all(value == 0) and (node.inputs[var_idx].shape == node.outputs[0].shape): - delete_node(node, var_idx) + node.replace_all_uses_with(node.feeds[var_idx]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Expand": # tests/test_onnx_nets.py::TestTimmClass::test_timm[lambda_resnet26rpt_256] @@ -82,14 +81,14 @@ def dead_node_elimination(graph, is_subgraph=False): constant_variable = node.inputs[1] value = constant_variable.values if node.inputs[0].shape == node.outputs[0].shape: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif value.ndim == 0 and value == 1: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Concat": if len(node.inputs) == 1: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") else: for input in node.inputs: @@ -100,20 +99,20 @@ def dead_node_elimination(graph, is_subgraph=False): constant_variable = node.inputs[1] value = constant_variable.values if value.ndim == 0 and value == 0: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif np.all(value == 0) and (node.inputs[0].shape == node.outputs[0].shape): - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif node.op == "Div": if isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable): constant_variable = node.inputs[1] value = constant_variable.values if value.ndim == 0 and value == 1: - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") elif np.all(value == 1) and (node.inputs[0].shape == node.outputs[0].shape): - delete_node(node) + node.replace_all_uses_with(node.feeds[0]) logger.debug(f"removing {node.op} op: {node.name}") diff --git a/onnxslim/core/optimization/subexpression_elimination.py b/onnxslim/core/optimization/subexpression_elimination.py index 23215ae..ec370f8 100644 --- a/onnxslim/core/optimization/subexpression_elimination.py +++ b/onnxslim/core/optimization/subexpression_elimination.py @@ -1,6 +1,5 @@ import logging -from onnxslim.core.pattern import get_node_users from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable logger = logging.getLogger("onnxslim") @@ -17,7 +16,7 @@ def get_node_key(node): return "_".join(input_names) if input_names else None def replace_node_references(existing_node, to_be_removed_node): - users = get_node_users(to_be_removed_node) + users = to_be_removed_node.users for user in users: for inp in user.inputs: if inp in to_be_removed_node.outputs: diff --git a/onnxslim/core/pattern/__init__.py b/onnxslim/core/pattern/__init__.py index 32bf0d5..ba929c6 100644 --- a/onnxslim/core/pattern/__init__.py +++ b/onnxslim/core/pattern/__init__.py @@ -8,29 +8,6 @@ logger = logging.getLogger("onnxslim") -def get_node_users(node): - """Retrieve the list of nodes that use the outputs of the given node.""" - users = [] - for output in node.outputs: # output is a Variable - if output.is_output: - users.append(output) - users.extend(iter(output.outputs)) - return users - - -def get_node_feeds(node): - """Retrieve the list of nodes that provide inputs to the given node.""" - feeds = [] - for input in node.inputs: - if len(input.inputs) == 0 and not isinstance(input, Constant): - feeds.append(input) - elif isinstance(input, Constant): - feeds.append(input) - else: - feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs) - return feeds - - def get_name(name): """Sanitizes the input string by replacing illegal characters with underscores and prefixing with an underscore if numeric. @@ -142,7 +119,7 @@ def match_(node, pattern_node): if node.op == pattern_node.op: setattr(self, pattern_node.name, node) - node_feeds = get_node_feeds(node) + node_feeds = node.feeds if pattern_node.coarse_input_num: if len(node_feeds) < len(pattern_node.input_names): return False @@ -207,8 +184,8 @@ def generate(self): for node in nodes: if node.op != "Constant": name = get_name(node.name) - feeds = get_node_feeds(node) - users = get_node_users(node) + feeds = node.feeds + users = node.users template.append( " ".join( [node.op, name, str(len(feeds)), str(len(users))] diff --git a/onnxslim/core/pattern/elimination/reshape.py b/onnxslim/core/pattern/elimination/reshape.py index 617f12f..26da8ab 100644 --- a/onnxslim/core/pattern/elimination/reshape.py +++ b/onnxslim/core/pattern/elimination/reshape.py @@ -1,7 +1,7 @@ import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -33,7 +33,7 @@ def rewrite(self, opset=11): node = self.reshape_1 first_reshape_node = node.i(0) first_reshape_node_inputs = list(first_reshape_node.inputs) - first_reshape_node_users = get_node_users(first_reshape_node) + first_reshape_node_users = first_reshape_node.users if len(first_reshape_node_users) == 1: second_reshape_node = node diff --git a/onnxslim/core/pattern/elimination/slice.py b/onnxslim/core/pattern/elimination/slice.py index 14bec9c..867ce62 100644 --- a/onnxslim/core/pattern/elimination/slice.py +++ b/onnxslim/core/pattern/elimination/slice.py @@ -1,7 +1,7 @@ import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -29,7 +29,7 @@ def rewrite(self, opset=11): first_slice_node = self.slice_0 first_slice_node_inputs = list(first_slice_node.inputs) if all(isinstance(input, gs.Constant) for input in first_slice_node_inputs[1:]): - first_slice_node_users = get_node_users(first_slice_node) + first_slice_node_users = first_slice_node.users if all( user.op == "Slice" and all(isinstance(input, gs.Constant) for input in list(user.inputs)[1:]) for user in first_slice_node_users diff --git a/onnxslim/core/pattern/elimination/unsqueeze.py b/onnxslim/core/pattern/elimination/unsqueeze.py index c4fcf8d..a7f35c4 100644 --- a/onnxslim/core/pattern/elimination/unsqueeze.py +++ b/onnxslim/core/pattern/elimination/unsqueeze.py @@ -1,7 +1,7 @@ import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -27,7 +27,7 @@ def rewrite(self, opset=11): """Rewrites an elimination pattern for unsqueeze nodes by optimizing nested slice operations.""" match_case = {} node_unsqueeze_0 = self.unsqueeze_0 - users_node_unsqueeze_0 = get_node_users(node_unsqueeze_0) + users_node_unsqueeze_0 = node_unsqueeze_0.users node_unsqueeze_1 = self.unsqueeze_1 if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape: if opset < 13 or ( diff --git a/onnxslim/core/pattern/fusion/convadd.py b/onnxslim/core/pattern/fusion/convadd.py index dbf243d..96bc028 100644 --- a/onnxslim/core/pattern/fusion/convadd.py +++ b/onnxslim/core/pattern/fusion/convadd.py @@ -1,5 +1,5 @@ import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -25,7 +25,7 @@ def rewrite(self, opset=11): match_case = {} conv_node = self.conv_0 conv_weight = list(conv_node.inputs)[1] - conv_node_users = get_node_users(conv_node) + conv_node_users = conv_node.users node = self.add_0 if ( len(conv_node_users) == 1 diff --git a/onnxslim/core/pattern/fusion/convbn.py b/onnxslim/core/pattern/fusion/convbn.py index f4aa0e0..4801600 100644 --- a/onnxslim/core/pattern/fusion/convbn.py +++ b/onnxslim/core/pattern/fusion/convbn.py @@ -1,7 +1,7 @@ import numpy as np import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -27,7 +27,7 @@ def rewrite(self, opset=11): """Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer.""" match_case = {} conv_transpose_node = self.conv_0 - conv_transpose_node_users = get_node_users(conv_transpose_node) + conv_transpose_node_users = conv_transpose_node.users node = self.bn_0 if len(conv_transpose_node_users) == 1: conv_transpose_weight = conv_transpose_node.inputs[1].values diff --git a/onnxslim/core/pattern/fusion/gemm.py b/onnxslim/core/pattern/fusion/gemm.py index 0ef1a0f..750f4f7 100644 --- a/onnxslim/core/pattern/fusion/gemm.py +++ b/onnxslim/core/pattern/fusion/gemm.py @@ -2,7 +2,7 @@ import onnxslim.third_party.onnx_graphsurgeon as gs from onnxslim.core.optimization.dead_node_elimination import get_constant_variable -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -35,7 +35,7 @@ def rewrite(self, opset=11): input_variable = ( matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1] ) - users = get_node_users(matmul_node) + users = matmul_node.users if len(users) == 1 and matmul_bias_variable and len(matmul_bias_variable.shape) == 2: if ( input_variable.shape diff --git a/onnxslim/core/pattern/fusion/padconv.py b/onnxslim/core/pattern/fusion/padconv.py index 42bf176..df4132f 100644 --- a/onnxslim/core/pattern/fusion/padconv.py +++ b/onnxslim/core/pattern/fusion/padconv.py @@ -1,5 +1,5 @@ import onnxslim.third_party.onnx_graphsurgeon as gs -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -33,7 +33,7 @@ def rewrite(self, opset=11): match_case = {} conv_node = self.conv_0 pad_node = self.pad_0 - pad_node_users = get_node_users(pad_node) + pad_node_users = pad_node.users pad_inputs = len(pad_node.inputs) if pad_inputs < 3 or ( diff --git a/onnxslim/core/pattern/fusion/reduce.py b/onnxslim/core/pattern/fusion/reduce.py index 29f31d0..3322080 100644 --- a/onnxslim/core/pattern/fusion/reduce.py +++ b/onnxslim/core/pattern/fusion/reduce.py @@ -1,4 +1,4 @@ -from onnxslim.core.pattern import Pattern, PatternMatcher, get_node_users +from onnxslim.core.pattern import Pattern, PatternMatcher from onnxslim.core.pattern.registry import register_fusion_pattern @@ -25,7 +25,7 @@ def rewrite(self, opset=11): match_case = {} node = self.unsqueeze_0 reduce_node = self.reduce_0 - reduce_node_node_users = get_node_users(reduce_node) + reduce_node_node_users = reduce_node.users if len(reduce_node_node_users) == 1: unsqueeze_node = node diff --git a/onnxslim/core/utils.py b/onnxslim/core/utils.py deleted file mode 100644 index 33700db..0000000 --- a/onnxslim/core/utils.py +++ /dev/null @@ -1,32 +0,0 @@ -from onnxslim.core.pattern import get_node_feeds, get_node_users -from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Variable - - -def delete_node(node, input_var_idx=0, output_var_idx=0): - """Delete a node from the computation graph while re-linking its input and output to maintain graph integrity.""" - input_variable = node.inputs[input_var_idx] - node_variable = node.outputs[output_var_idx] - next_nodes = get_node_users(node) - - output_var = None - for next_node in next_nodes: - if isinstance(next_node, Variable) and next_node.is_output: - output_var = next_node - break - - if output_var: - feeds = get_node_feeds(node) - feed = feeds[0] - if not isinstance(feed, (Variable, Constant)): - feed.outputs.remove(node.inputs[input_var_idx]) - feed.outputs.append(node.outputs[output_var_idx]) - for user in list(node.inputs[input_var_idx].outputs): - for i, input in enumerate(user.inputs): - if input == node.inputs[input_var_idx]: - user.inputs[i] = node.outputs[output_var_idx] - node.outputs.clear() - else: - for next_node in next_nodes: - index = next_node.inputs.index(node_variable) - next_node.inputs.pop(index) - next_node.inputs.insert(index, input_variable) diff --git a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py index fccf357..b1cef6c 100644 --- a/onnxslim/third_party/onnx_graphsurgeon/ir/node.py +++ b/onnxslim/third_party/onnx_graphsurgeon/ir/node.py @@ -17,9 +17,9 @@ from collections import OrderedDict from dataclasses import dataclass -from typing import Dict, List +from typing import Dict, List, Union -from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor +from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER from onnxslim.third_party.onnx_graphsurgeon.util import misc @@ -215,3 +215,57 @@ def __eq__(self, other): outputs_match = misc.sequences_equal(self.outputs, other.outputs) return self.domain == other.domain if outputs_match else False + + @property + def users(self): + users = [] + for output in self.outputs: # output is a Variable + if output.is_output: + users.append(output) + users.extend(iter(output.outputs)) + return users + + @property + def feeds(self): + """Retrieve the list of nodes that provide inputs to the given node.""" + feeds = [] + for input in self.inputs: + if len(input.inputs) == 0 and not isinstance(input, Constant): + feeds.append(input) + elif isinstance(input, Constant): + feeds.append(input) + else: + feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs) + return feeds + + def replace_all_uses_with(self, node: Union["Node", "Tensor"], input_var_idx=0, output_var_idx=0): + """Replace all uses of this node with the given node.""" + if isinstance(node, Node): + input_var = node.outputs[output_var_idx] + else: + input_var = node + + output_var = None + for output in self.outputs: + if isinstance(output, Variable) and output.is_output: + output_var = output + break + + if output_var: + feed = self.feeds[0] + if not isinstance(feed, (Variable, Constant)): + index = feed.outputs.index(self.inputs[input_var_idx]) + feed.outputs.pop(index) + feed.outputs.insert(index, self.outputs[output_var_idx]) + for user in list(self.inputs[input_var_idx].outputs): + # do not use index here, because index will only return the first index of the input + for i, input in enumerate(user.inputs): + if input == self.inputs[input_var_idx]: + user.inputs[i] = self.outputs[output_var_idx] + self.outputs.clear() + else: + for output in self.outputs: + for node_ in output.outputs: + index = node_.inputs.index(output) + node_.inputs.pop(index) + node_.inputs.insert(index, input_var) diff --git a/tests/test_modelzoo.py b/tests/test_modelzoo.py index 6412ded..fe1c956 100644 --- a/tests/test_modelzoo.py +++ b/tests/test_modelzoo.py @@ -86,6 +86,8 @@ def test_layer_normalization_2d_axis0_expanded_ver18(self, request): with tempfile.TemporaryDirectory() as tempdir: slim(filename, os.path.join(tempdir, f"{name}_slim.onnx"), model_check=True) + summary = summarize_model(os.path.join(tempdir, f"{name}_slim.onnx"), tag=request.node.name) + assert summary.op_type_counts["Reshape"] == 1 def test_padconv(self, request): name = request.node.originalname[len("test_") :]