From f7774a4362c5a1f545a3d81c1f0c5637f1b202c0 Mon Sep 17 00:00:00 2001 From: inisis Date: Sat, 8 Jun 2024 04:30:59 +0000 Subject: [PATCH] fix tied weight duplication --- .../exporters/onnx_exporter.py | 37 +++++++++++++---- onnxslim/onnx_graphsurgeon/ir/graph.py | 2 +- onnxslim/onnx_graphsurgeon/util/misc.py | 12 ++++++ onnxslim/utils.py | 41 +++++++++++++++++++ 4 files changed, 82 insertions(+), 10 deletions(-) diff --git a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py index edfe3c6..3da0b85 100644 --- a/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py +++ b/onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +from collections import OrderedDict from typing import List, Sequence, Union import numpy as np @@ -164,14 +164,14 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn return onnx_tensor @staticmethod - def export_attributes(attrs: dict) -> List[onnx.AttributeProto]: + def export_attributes(attrs: dict, subgraph_tensor_map) -> List[onnx.AttributeProto]: onnx_attrs: List[onnx.AttributeProto] = [] for key, val in attrs.items(): if isinstance(val, Tensor): val = OnnxExporter.export_tensor_proto(val) elif isinstance(val, Graph): # Subgraphs don't need to have types specified for their tensors. 
- val = OnnxExporter.export_graph(val, do_type_check=False) + val = OnnxExporter.export_graph(val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False) elif isinstance(val, Node.AttributeRef): onnx_attr = onnx.AttributeProto() onnx_attr.name = key @@ -196,7 +196,7 @@ def export_attributes(attrs: dict) -> List[onnx.AttributeProto]: return onnx_attrs @staticmethod - def export_node(node: Node) -> onnx.NodeProto: + def export_node(node: Node, subgraph_tensor_map) -> onnx.NodeProto: # Cannot pass in attrs directly as make_node will change the order onnx_node = onnx.helper.make_node( node.op, @@ -205,7 +205,7 @@ def export_node(node: Node) -> onnx.NodeProto: name=node.name, domain=node.domain, ) - onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs)) + onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map)) return onnx_node @staticmethod @@ -259,7 +259,10 @@ def export_function(func: Function) -> onnx.FunctionProto: ) @staticmethod - def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: + def export_graph(graph: Graph, + tensor_map:"OrderedDict[str, Tensor]" = None, + subgraph_tensor_map: "OrderedDict[str, Tensor]" = None, + do_type_check=True) -> onnx.GraphProto: """ Export an onnx-graphsurgeon Graph to an ONNX GraphProto. @@ -270,10 +273,14 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto: Defaults to True. 
""" check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING) - nodes = [OnnxExporter.export_node(node) for node in graph.nodes] + nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes] inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs] outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs] - tensor_map = graph.tensors() + if tensor_map is None: + tensor_map = graph.tensors() + tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map) + else: + tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map) initializer = [ OnnxExporter.export_tensor_proto(tensor) for tensor in tensor_map.values() @@ -327,7 +334,19 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto" Returns: onnx.ModelProto: A corresponding ONNX model. """ - onnx_graph = OnnxExporter.export_graph(graph, do_type_check=do_type_check) + sub_graphs = graph.subgraphs(recursive=True) + + graph_constants_list = [] + for sub_graph in sub_graphs: + graph_constants = {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)} + graph_constants_list.append(graph_constants) + + if len(graph_constants_list) == 0: + intersection = None + else: + intersection = {k: graph_constants_list[0][k] for k in graph_constants_list[0] if all(k in d for d in graph_constants_list[1:])} + + onnx_graph = OnnxExporter.export_graph(graph, tensor_map=intersection, subgraph_tensor_map=intersection, do_type_check=do_type_check) onnx_functions = [OnnxExporter.export_function(func) for func in graph.functions] kwargs["functions"] = onnx_functions diff --git a/onnxslim/onnx_graphsurgeon/ir/graph.py b/onnxslim/onnx_graphsurgeon/ir/graph.py index 4b9f848..03219cd 100644 --- a/onnxslim/onnx_graphsurgeon/ir/graph.py +++ b/onnxslim/onnx_graphsurgeon/ir/graph.py @@ -957,7 +957,7 @@ def all_tensors_const(tensors): graph_constants.update({out.name: out for 
out in node.outputs}) return graph_constants - graph_constants = {name: tensor for name, tensor in clone_tensors.items() if (isinstance(tensor, Constant) and len(tensor.outputs) == 1)} + graph_constants = {name: tensor for name, tensor in clone_tensors.items() if (isinstance(tensor, Constant) and not any([t.op == "Gather" for t in tensor.outputs]))} graph_constants = update_foldable_outputs(graph_constants) # Pass 4: Shape Folding diff --git a/onnxslim/onnx_graphsurgeon/util/misc.py b/onnxslim/onnx_graphsurgeon/util/misc.py index 1a70f73..e43f00c 100644 --- a/onnxslim/onnx_graphsurgeon/util/misc.py +++ b/onnxslim/onnx_graphsurgeon/util/misc.py @@ -69,6 +69,18 @@ def combine_dicts(dict0, dict1): return combined +def unique_dicts(dict0, dict1): + """ + Subtract two dictionaries. Values in the second will be subtracted from the first. + """ + if not dict1: + return dict0 + + unique_dict = {k: v for k, v in dict0.items() if k not in dict1} + + return unique_dict + + def is_dynamic_dimension(dim): """Check if a dimension is dynamic (non-integer or negative).""" return not isinstance(dim, int) or dim < 0 diff --git a/onnxslim/utils.py b/onnxslim/utils.py index 1a0251a..72f31dc 100644 --- a/onnxslim/utils.py +++ b/onnxslim/utils.py @@ -439,3 +439,44 @@ def check_result(raw_onnx_output, slimmed_onnx_output): logger.warning("Model output mismatch after slimming.") logger.warning("Please check the model carefully.") return + + +data_type_sizes = { + onnx.TensorProto.FLOAT: 4, + onnx.TensorProto.DOUBLE: 8, + onnx.TensorProto.INT32: 4, + onnx.TensorProto.INT64: 8, + onnx.TensorProto.UINT8: 1, + onnx.TensorProto.INT8: 1, + onnx.TensorProto.UINT16: 2, + onnx.TensorProto.INT16: 2, + onnx.TensorProto.BOOL: 1, +} + + +def calculate_tensor_size(tensor): + shape = tensor.dims + num_elements = np.prod(shape) if shape else 0 + element_size = data_type_sizes.get(tensor.data_type, 0) + return num_elements * element_size + + +def get_model_size_and_initializer_size(model): + 
initializer_size = 0 + for tensor in model.graph.initializer: + tensor_size = calculate_tensor_size(tensor) + initializer_size += tensor_size + + print('model size', model.ByteSize()) + print('initializer size', initializer_size) + + +def get_model_subgraph_size(model): + graph = model.graph + for node in graph.node: + for attr in node.attribute: + ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()} + if attr.type in ATTR_TYPE_MAPPING: + attr_str = ATTR_TYPE_MAPPING[attr.type] + if attr_str == "GRAPH": + print('subgraph', attr.g.ByteSize())