# ---------------------------------------------------------------------------
# onnxslim/core/graph_rewriter.py
# ---------------------------------------------------------------------------

# A count token of the form "<digits>+", e.g. "1+". Compiled once because
# every pattern line is parsed with it.
_COARSE_COUNT_RE = re.compile(r"(\d+)\+")


def get_node_users(node):
    """Return the nodes that consume any output tensor of ``node``.

    Args:
        node: Object with an ``outputs`` sequence of tensors, each of which
            has its own ``outputs`` sequence (the nodes reading that tensor).

    Returns:
        list: Consumers in output order. A node that reads several outputs of
        ``node`` appears once per output it reads.
    """
    return [user for tensor in node.outputs for user in tensor.outputs]


def get_node_feeds(node):
    """Return what feeds each input of ``node``, in input order.

    For every input tensor:

    * a ``Constant``, or a tensor with no producer, is returned as-is;
    * otherwise every producer node of the tensor is returned, except that a
      tensor produced by a ``Split`` node is returned itself in place of the
      ``Split`` node.

    Args:
        node: Object with an ``inputs`` sequence of tensors; each tensor
            exposes ``inputs`` (its producer nodes).

    Returns:
        list: A mix of tensors and nodes, as described above.
    """
    feeds = []
    for tensor in node.inputs:
        if isinstance(tensor, Constant) or not tensor.inputs:
            feeds.append(tensor)
            continue
        for producer in tensor.inputs:
            # NOTE(review): presumably Split is special-cased because it has
            # several outputs, so the specific tensor matters - confirm.
            feeds.append(tensor if producer.op == "Split" else producer)
    return feeds


class NodeDescriptor:
    """One parsed line of a textual pattern.

    A line is a whitespace-separated list of tokens::

        <op> <name> <input_num> <output_num> <input names...> <output names...>

    ``input_num`` / ``output_num`` are either plain integers or ``N+``; the
    ``+`` form sets ``coarse_input_num`` / ``coarse_output_num``.

    Attributes are stored as ``op``, ``name``, ``input_num``,
    ``coarse_input_num``, ``output_num``, ``coarse_output_num``,
    ``input_names`` and ``output_names``.
    """

    def __init__(self, node_spec):
        """Parse ``node_spec`` (a list of string tokens).

        Raises:
            ValueError: if ``node_spec`` is not a list, has fewer than four
                tokens, has a malformed count, or lists a number of names
                that disagrees with the declared counts.
        """
        if not isinstance(node_spec, list):
            raise ValueError('node_spec must be a list')
        if len(node_spec) < 4:
            raise ValueError(f'node_spec must have at least 4 elements {node_spec}')

        self.op = node_spec[0]
        self.name = node_spec[1]
        self.input_num, self.coarse_input_num = self._parse_count(node_spec[2])
        self.output_num, self.coarse_output_num = self._parse_count(node_spec[3])
        self.input_names = node_spec[4:4 + self.input_num]
        self.output_names = node_spec[4 + self.input_num:]

        # Explicit checks instead of ``assert`` so that malformed patterns are
        # still rejected when Python runs with -O.
        if len(self.input_names) != self.input_num:
            raise ValueError(f'{self.name} {len(self.input_names)} != {self.input_num}')
        if len(self.output_names) != self.output_num:
            raise ValueError(f'{self.name} {len(self.output_names)} != {self.output_num}')

    @staticmethod
    def _parse_count(io_spec):
        """Return ``(count, is_coarse)`` for a count token such as ``"2"`` or ``"1+"``."""
        if io_spec.isdigit():
            return int(io_spec), False
        # fullmatch, not search: tokens with stray characters around the
        # "N+" part (e.g. "x1+y") are rejected instead of silently accepted.
        coarse = _COARSE_COUNT_RE.fullmatch(io_spec)
        if coarse is None:
            raise ValueError(f'input_num and output_num must be integers {io_spec}')
        return int(coarse.group(1)), True

    def __repr__(self):
        return f'name: {self.name}, type: {self.op}, input_num: {self.input_num}, output_num: {self.output_num}, input_names: {self.input_names}, output_names: {self.output_names}'

    # The former ``def __dict__(self)`` method was removed: defining it hid the
    # instance's attribute dictionary, so ``vars(obj)`` returned a bound method
    # and ``copy.copy(obj)`` failed.


class Pattern:
    """A textual graph pattern: one node description per non-blank line."""

    def __init__(self, pattern):
        """Store the pattern text and parse it into ``self.nodes``."""
        self.pattern = pattern
        self.nodes = self.parse_nodes()

    def parse_nodes(self):
        """Return a ``NodeDescriptor`` for every non-blank line of the pattern."""
        token_lists = (line.split() for line in self.pattern.split('\n'))
        return [NodeDescriptor(tokens) for tokens in token_lists if tokens]

    def match(self, node):
        # NOTE(review): ``self.pattern`` is the pattern text handed to
        # ``__init__``; a ``str`` has no ``match`` method, so with the string
        # patterns used by the matchers this call raises AttributeError.
        # Left unchanged because the intended behavior is not visible here.
        return self.pattern.match(node)

    def __repr__(self):
        return self.pattern


class PatternMatcher:
    """Match a ``Pattern`` against a graph, walking backwards from its output.

    On a match, every pattern node reached during the walk is bound as an
    attribute named after the pattern node (for example ``self.conv_0``), and
    ``self.output`` is set to the ``outputs`` of the node the walk started at.
    Attributes may also be set during a walk that ultimately fails, so they
    are only meaningful right after ``match`` has returned ``True``.
    """

    def __init__(self, pattern, priority):
        """Store ``pattern`` and ``priority`` and index pattern nodes by name."""
        self.pattern = pattern
        self.priority = priority
        self.pattern_dict = {node.name: node for node in pattern.nodes}

    def get_match_point(self):
        """Return the descriptor of the node feeding the pattern's ``output`` entry."""
        return self.pattern_dict[self.pattern_dict['output'].input_names[0]]

    def match(self, node):
        """Return ``True`` if the pattern ends at ``node`` and ``parameter_check`` passes."""
        if self._match_node(node, self.get_match_point()):
            setattr(self, 'output', node.outputs)
            if self.parameter_check():
                return True

        return False

    def _match_node(self, node, pattern_node):
        """Recursively compare ``node`` (and what feeds it) with ``pattern_node``."""
        # An ``input`` pattern entry accepts anything.
        if pattern_node.op == 'input':
            return True

        # Tensors (graph inputs, constants) have no ``op`` and can only
        # satisfy an ``input`` entry.
        if not hasattr(node, 'op'):
            return False

        if node.op != pattern_node.op:
            return False

        setattr(self, pattern_node.name, node)

        node_feeds = get_node_feeds(node)
        expected_names = pattern_node.input_names
        if pattern_node.coarse_input_num:
            # As implemented, "N+" requires strictly MORE feeds than the
            # names listed in the pattern; equality is rejected.
            if len(node_feeds) <= len(expected_names):
                return False
        elif len(node_feeds) != len(expected_names):
            # The counts need %s placeholders; without them the logging
            # module reports a formatting error when DEBUG is enabled.
            logger.debug('len(node_feeds) != len(pattern_node.input_names): %s != %s',
                         len(node_feeds), len(expected_names))
            return False

        # "?" is a wildcard position. zip() stops at the shorter sequence, so
        # with an "N+" count any extra feeds are left unchecked.
        for node_feed, feed_name in zip(node_feeds, expected_names):
            if feed_name == "?":
                continue
            feed_pattern = self.pattern_dict[feed_name]
            if not self._match_node(node_feed, feed_pattern):
                return False
            setattr(self, feed_pattern.name, node_feed)

        return True

    # NOTE: the class does not use ABCMeta, so this decorator does not prevent
    # instantiation; the NotImplementedError below is the effective guard.
    @abstractmethod
    def rewrite(self):
        """Build the replacement for the matched sub-graph; subclasses must override."""
        raise NotImplementedError('rewrite method must be implemented')

    def parameter_check(self):
        """Hook for extra checks after a structural match; accepts by default."""
        return True


# ---------------------------------------------------------------------------
# onnxslim/core/optimizer.py - fusion pattern registry
# ---------------------------------------------------------------------------

DEFAULT_FUSION_PATTERNS = OrderedDict()


def register_fusion_pattern(fusion_pattern):
    """Register ``fusion_pattern`` in ``DEFAULT_FUSION_PATTERNS`` under its ``name``.

    Args:
        fusion_pattern: Object exposing a ``name`` attribute used as the key.

    Raises:
        RuntimeError: if a pattern with the same name is already registered.
    """
    layer_type = fusion_pattern.name

    if layer_type in DEFAULT_FUSION_PATTERNS:
        # Previously a bare ``raise``; outside an ``except`` block that
        # produces "RuntimeError: No active exception to reraise". The
        # exception type is kept, the message now names the duplicate.
        raise RuntimeError(f'fusion pattern {layer_type!r} is already registered')
    DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern
following a padding operation to update padding attributes for fusion - purposes. - """ - """ - x - | - Pad - | - Conv - """ - # fmt: on - match = {} - if node.op == "Conv" and node.i(0).op == "Pad": - pad_node = node.i(0) - if isinstance(pad_node.inputs[1], Constant): - pad_value = pad_node.inputs[1].values.tolist() - input_variable = node.i(0).inputs[0] - input_variable.outputs.remove(pad_node) - - pad_variable = node.i(0).outputs[0] # pad output variable - index = node.inputs.index(pad_variable) - node.inputs.pop(index) - node.inputs.insert(index, input_variable) - - inputs = list(node.inputs) - outputs = list(node.outputs) - attrs = node.attrs - - node.inputs.clear() - node.outputs.clear() - pad_node.inputs.clear() - pad_node.outputs.clear() - conv_pads = attrs["pads"] - len_conv_pads = len(conv_pads) // 2 - - len_pads = len(pad_value) // 2 - pads = pad_value[len_pads - len_conv_pads : len_pads] + pad_value[len_pads + len_conv_pads :] - - pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] - attrs["pads"] = pads - match[node.name] = { - "op": "Conv", - "inputs": inputs, - "outputs": outputs, - "name": node.name, - "attrs": node.attrs, - "domain": None, - } +class PadConvMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 pad_0 + Pad pad_0 1+ 1 input conv_0 + Conv conv_0 1+ 1 pad_0 output + output output 1 0 conv_0 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionPadConv" + + def parameter_check(self): + pad_node = self.pad_0 + if not isinstance(pad_node.inputs[1], Constant): + return False + + return True - return match + def rewrite(self): + match_case = {} + node = self.conv_0 + pad_node = self.pad_0 + input_variable = self.pad_0.inputs[0] + pad_value = pad_node.inputs[1].values.tolist() + input_variable.outputs.remove(pad_node) -@register_fusion_pattern("FusionConvBN") -def find_conv_transpose_nodes(node, opset): - # fmt: off - """X | 
Conv/ConvTranspose | BatchNormalization.""" - # fmt: on - match = {} - if node.op == "BatchNormalization" and node.i(0).op in { - "ConvTranspose", - "Conv", - }: - conv_transpose_node = node.i(0) + pad_variable = pad_node.outputs[0] # pad output variable + index = node.inputs.index(pad_variable) + node.inputs.pop(index) + node.inputs.insert(index, input_variable) + + inputs = list(node.inputs) + outputs = list(node.outputs) + attrs = node.attrs + + node.inputs.clear() + node.outputs.clear() + pad_node.inputs.clear() + pad_node.outputs.clear() + conv_pads = attrs["pads"] + len_conv_pads = len(conv_pads) // 2 + + len_pads = len(pad_value) // 2 + pads = pad_value[len_pads - len_conv_pads : len_pads] + pad_value[len_pads + len_conv_pads :] + + pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)] + attrs["pads"] = pads + + match_case[node.name] = { + "op": "Conv", + "inputs": inputs, + "outputs": outputs, + "name": node.name, + "attrs": node.attrs, + "domain": None, + } + + return match_case + +register_fusion_pattern(PadConvMatcher(1)) + +class ConvBatchNormMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 conv_0 + Conv conv_0 3 1 input ? ? bn_0 + BatchNormalization bn_0 5 1 conv_0 ? ? ? ? 
output + output output 1 0 bn_0 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionConvBN" + + def rewrite(self): + match_case = {} + conv_transpose_node = self.conv_0 conv_transpose_node_users = get_node_users(conv_transpose_node) + node = self.bn_0 if len(conv_transpose_node_users) == 1: conv_transpose_weight = conv_transpose_node.inputs[1].values bn_node = node @@ -282,7 +299,7 @@ def find_conv_transpose_nodes(node, opset): bn_node.inputs.clear() bn_node.outputs.clear() - match[conv_transpose_node.name] = { + match_case[conv_transpose_node.name] = { "op": conv_transpose_node.op, "inputs": inputs, "outputs": outputs, @@ -291,23 +308,28 @@ def find_conv_transpose_nodes(node, opset): "domain": None, } - return match + return match_case +register_fusion_pattern(ConvBatchNormMatcher(1)) -@register_fusion_pattern("EliminationSlice") -def find_slice_nodes(node, opset): - """Identify and combine consecutive 'Slice' nodes in a computational graph for optimization purposes.""" - """ - x - | - Slice - | - Slice - """ - # fmt: on - match = {} - if node.op == "Slice" and node.i(0).op == "Slice": - first_slice_node = node.i(0) +class SlicePatternMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 slice_0 + Slice slice_0 5 1 input ? ? ? ? slice_1 + Slice slice_1 5 1 slice_0 ? ? ? ? 
output + output output 1 0 slice_1 + ''') # to check here slice_0 + super().__init__(pattern, priority) + + @property + def name(self): + return "EliminationSlice" + + def rewrite(self): + match_case = {} + first_slice_node = self.slice_0 first_slice_node_inputs = list(first_slice_node.inputs) if all(isinstance(input, Constant) for input in first_slice_node_inputs[1:]): first_slice_node_users = get_node_users(first_slice_node) @@ -365,7 +387,7 @@ def find_slice_nodes(node, opset): second_slice_node.outputs.clear() if len(first_slice_node_users) == 1: - match[first_slice_node.name] = { + match_case[first_slice_node.name] = { "op": "Slice", "inputs": inputs, "outputs": outputs, @@ -374,7 +396,7 @@ def find_slice_nodes(node, opset): "domain": None, } else: - match[second_slice_node.name] = { + match_case[second_slice_node.name] = { "op": "Slice", "inputs": inputs, "outputs": outputs, @@ -383,30 +405,33 @@ def find_slice_nodes(node, opset): "domain": None, } - return match + return match_case +register_fusion_pattern(SlicePatternMatcher(1)) -@register_fusion_pattern("EliminationReshape") -def find_reshape_nodes(node, opset): - """Identify consecutive 'Reshape' nodes in the computational graph for potential fusion, returning a matching - dictionary when criteria are met. - """ - """ - x - | - Reshape - | - Reshape - """ - # fmt: on - match = {} - if node.op == "Reshape" and node.i(0).op == "Reshape": +class ReshapePatternMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 reshape_0 + Reshape reshape_0 2 1 input ? reshape_1 + Reshape reshape_1 2 1 reshape_0 ? 
output + output output 1 0 reshape_1 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "EliminationReshape" + + def rewrite(self): + match_case = {} + node = self.reshape_1 first_reshape_node = node.i(0) first_reshape_node_inputs = list(first_reshape_node.inputs) first_reshape_node_users = get_node_users(first_reshape_node) if len(first_reshape_node_users) == 1: second_reshape_node = node - def check_constant_mergeable(reshape_node): if isinstance(reshape_node.inputs[1], Constant): input_shape = reshape_node.inputs[0].shape @@ -428,7 +453,7 @@ def check_constant_mergeable(reshape_node): second_reshape_node.inputs.clear() second_reshape_node.outputs.clear() - match[first_reshape_node.name] = { + match_case[first_reshape_node.name] = { "op": "Reshape", "inputs": inputs, "outputs": outputs, @@ -437,66 +462,29 @@ def check_constant_mergeable(reshape_node): "domain": None, } - return match - - -# @register_fusion_pattern("EliminationTranspose") -def find_slice_nodes(node, opset): - """Identifies and processes patterns of consecutive Transpose nodes in a computational graph.""" - """ - x - | - Transpose - | - Transpose - """ - if node.op == "Transpose": - previous_nodes = get_previous_node_by_type(node, "Transpose") - if previous_nodes: - if len(previous_nodes) == 1: - delete_node(node) - delete_node(previous_nodes[-1]) - else: - delete_node(node) - previous_transpose_node = previous_nodes[-1] - last_node = previous_nodes[-2] - slice_axis = gs.Constant( - f"{last_node.name}_slice_axis", - values=np.array([2]).astype(np.int64), - ) - last_node.inputs.pop(3) - last_node.inputs.insert(3, slice_axis) - previous_transpose_node_variable = previous_transpose_node.outputs[0] # pad output variable - previous_transpose_node_variable.outputs.remove(last_node) - last_node.inputs.insert(0, previous_transpose_node.inputs[0]) - for node in previous_nodes: - for output in node.outputs: - if isinstance(output, Constant): - continue - output.shape = 
None - - return {} - - -@register_fusion_pattern("FusionGemm") -def find_matmul_add_nodes(node, opset): - """Identifies and returns a pattern match for MatMul followed by Add operations for optimization in a computational - graph. - """ - """ - x - | - MatMul - | - Add - """ - # fmt: on - match = {} - if node.op == "Add" and ( - (isinstance(node.inputs[1], Constant) and node.i(0).op == "MatMul") - or (isinstance(node.inputs[0], Constant) and node.i(1).op == "MatMul") - ): - matmul_node = node.i(0) if isinstance(node.inputs[1], Constant) else node.i(1) + return match_case + +register_fusion_pattern(ReshapePatternMatcher(1)) + +class MatMulAddPatternMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 matmul_0 + MatMul matmul_0 2 1 input ? add_0 + Add add_0 2 1 matmul_0 ? output + output output 1 0 add_0 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionGemm" + + def rewrite(self): + match_case = {} + node = self.add_0 + matmul_node = self.matmul_0 matmul_bias_variable = get_constant_variable(matmul_node) input_variable = matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], Constant) else matmul_node.inputs[1] users = get_node_users(matmul_node) @@ -520,7 +508,7 @@ def find_matmul_add_nodes(node, opset): ) outputs = [reshape_out_variable] - match.update( + match_case.update( { matmul_node.name + "_pre_reshape": { "op": "Reshape", @@ -550,7 +538,7 @@ def find_matmul_add_nodes(node, opset): gemm_out_variable = gs.Variable(matmul_node.name + "_gemm_out", dtype=output_variable.dtype) outputs = [gemm_out_variable] - match.update( + match_case.update( { matmul_node.name: { "op": "Gemm", @@ -583,7 +571,7 @@ def find_matmul_add_nodes(node, opset): add_node.inputs.clear() add_node.outputs.clear() - match.update( + match_case.update( { matmul_node.name + "_post_reshape": { "op": "Reshape", @@ -617,7 +605,7 @@ def find_matmul_add_nodes(node, opset): outputs = list(add_node.outputs) 
add_node.inputs.clear() add_node.outputs.clear() - match.update( + match_case.update( { matmul_node.name: { "op": "Gemm", @@ -634,70 +622,70 @@ def find_matmul_add_nodes(node, opset): } } ) - return match - - -# @register_fusion_pattern("FusionGelu") -def find_gelu_nodes(node, opset): - """Identifies GELU (Gaussian Error Linear Unit) activation pattern nodes in a computational graph based on given - conditions. - """ - """ - x - / \ - | Div - | | - | Erf - | | - | Add - \ / - Mul - | - Mul - """ - # fmt: on - match = {} - if node.op == "Mul" and ( - node.i(0).op == "Mul" - and node.i(0).i(1).op == "Add" - and node.i(0).i(1).i(0).op == "Erf" - and node.i(0).i(1).i(0).i(0).op == "Div" - ): - input_variable = node.i(0).i(1).i(0).i(0).inputs[0] - mul_node = node.i(0) - div_node = node.i(0).i(1).i(0).i(0) + return match_case + +register_fusion_pattern(MatMulAddPatternMatcher(1)) + +class GeluPatternMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 2 mul_0 div_0 + Div div_0 2 1 input ? erf_0 + Erf erf_0 1 1 div_0 add_0 + Add add_0 2 1 erf_0 ? mul_0 + Mul mul_0 2 1 input add_0 mul_1 + Mul mul_1 2 1 mul_0 ? output + output output 1 0 mul_1 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionGelu" + + def rewrite(self): + match_case = {} + input_variable = self.div_0.inputs[0] + mul_node = self.mul_0 + div_node = self.div_0 input_variable.outputs.remove(mul_node) input_variable.outputs.remove(div_node) - output_variable = node.outputs[0] + output_variable = self.mul_1.outputs[0] output_variable.inputs.clear() - match[node.name] = { + + match_case[self.mul_1.name] = { "op": "Gelu", "inputs": [input_variable], "outputs": [output_variable], "domain": None, } - return match - - -@register_fusion_pattern("FusionReduce") -def find_slice_nodes(node, opset): - """Find and return a dictionary of matching 'ReduceSum' followed by 'Unsqueeze' nodes that match specific conditions - in the graph. 
- """ - """ - x - | - ReduceSum - | - Unsqueeze - """ - # fmt: on - match = {} - if node.op == "Unsqueeze" and node.i(0).op == "ReduceSum": - reduce_node = node.i(0) + return match_case + +# register_fusion_pattern(GeluPatternMatcher(1)) + +class ReducePatternMatcher(PatternMatcher): + def __init__(self, priority): + pattern = Pattern( + ''' + input input 0 1 reduce_0 + ReduceSum reduce_0 1 1 input unsqueeze_0 + Unsqueeze unsqueeze_0 1 1 reduce_0 output + output output 1 0 unsqueeze_0 + ''') + super().__init__(pattern, priority) + + @property + def name(self): + return "FusionReduce" + + def rewrite(self, opset=11): + match_case = {} + node = self.unsqueeze_0 + reduce_node = self.reduce_0 reduce_node_node_users = get_node_users(reduce_node) if len(reduce_node_node_users) == 1: unsqueeze_node = node @@ -714,7 +702,7 @@ def find_slice_nodes(node, opset): unsqueeze_node.inputs.clear() unsqueeze_node.outputs.clear() attrs["keepdims"] = 1 - match[reduce_node.name] = { + match_case[reduce_node.name] = { "op": reduce_node.op, "inputs": inputs, "outputs": outputs, @@ -723,8 +711,9 @@ def find_slice_nodes(node, opset): "domain": None, } - return match + return match_case +register_fusion_pattern(ReducePatternMatcher(1)) @gs.Graph.register() def replace_custom_layer( @@ -753,18 +742,18 @@ def find_matches(graph: Graph, fusion_patterns: dict): counter = Counter() for node in reversed(graph.nodes): if node.name not in match_map: - for layer_type, func in fusion_patterns.items(): - with contextlib.suppress(IndexError): - matches = func(node, opset) - if matches: - logger.debug(f"matched pattern {layer_type}") - for _, match in matches.items(): - if "op" not in match: - match.update({"op": layer_type}) - if "name" not in match: - match.update({"name": f"{layer_type.lower()}_{counter[layer_type]}"}) - counter.update([layer_type]) - match_map.update(matches) + for layer_type, pattern_matcher in fusion_patterns.items(): + match = pattern_matcher.match(node) + if match: + match_case 
= pattern_matcher.rewrite() + logger.debug(f"matched pattern {layer_type}") + for _, match in match_case.items(): + if "op" not in match: + match.update({"op": layer_type}) + if "name" not in match: + match.update({"name": f"{layer_type.lower()}_{counter[layer_type]}"}) + counter.update([layer_type]) + match_map.update(match_case) return match_map