Skip to content

Commit

Permalink
fix tied weight duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
inisis committed Jun 8, 2024
1 parent 4688a64 commit f7774a4
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 10 deletions.
37 changes: 28 additions & 9 deletions onnxslim/onnx_graphsurgeon/exporters/onnx_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from typing import List, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -164,14 +164,14 @@ def export_value_info_proto(tensor: Tensor, do_type_check: bool) -> onnx.ValueIn
return onnx_tensor

@staticmethod
def export_attributes(attrs: dict) -> List[onnx.AttributeProto]:
def export_attributes(attrs: dict, subgraph_tensor_map) -> List[onnx.AttributeProto]:
onnx_attrs: List[onnx.AttributeProto] = []
for key, val in attrs.items():
if isinstance(val, Tensor):
val = OnnxExporter.export_tensor_proto(val)
elif isinstance(val, Graph):
# Subgraphs don't need to have types specified for their tensors.
val = OnnxExporter.export_graph(val, do_type_check=False)
val = OnnxExporter.export_graph(val, subgraph_tensor_map=subgraph_tensor_map, do_type_check=False)
elif isinstance(val, Node.AttributeRef):
onnx_attr = onnx.AttributeProto()
onnx_attr.name = key
Expand All @@ -196,7 +196,7 @@ def export_attributes(attrs: dict) -> List[onnx.AttributeProto]:
return onnx_attrs

@staticmethod
def export_node(node: Node) -> onnx.NodeProto:
def export_node(node: Node, subgraph_tensor_map) -> onnx.NodeProto:
# Cannot pass in attrs directly as make_node will change the order
onnx_node = onnx.helper.make_node(
node.op,
Expand All @@ -205,7 +205,7 @@ def export_node(node: Node) -> onnx.NodeProto:
name=node.name,
domain=node.domain,
)
onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs))
onnx_node.attribute.extend(OnnxExporter.export_attributes(node.attrs, subgraph_tensor_map))
return onnx_node

@staticmethod
Expand Down Expand Up @@ -259,7 +259,10 @@ def export_function(func: Function) -> onnx.FunctionProto:
)

@staticmethod
def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto:
def export_graph(graph: Graph,
tensor_map:"OrderedDict[str, Tensor]" = None,
subgraph_tensor_map: "OrderedDict[str, Tensor]" = None,
do_type_check=True) -> onnx.GraphProto:
"""
Export an onnx-graphsurgeon Graph to an ONNX GraphProto.
Expand All @@ -270,10 +273,14 @@ def export_graph(graph: Graph, do_type_check=True) -> onnx.GraphProto:
Defaults to True.
"""
check_duplicate_node_names(graph.nodes, level=G_LOGGER.WARNING)
nodes = [OnnxExporter.export_node(node) for node in graph.nodes]
nodes = [OnnxExporter.export_node(node, subgraph_tensor_map) for node in graph.nodes]
inputs = [OnnxExporter.export_value_info_proto(inp, do_type_check) for inp in graph.inputs]
outputs = [OnnxExporter.export_value_info_proto(out, do_type_check) for out in graph.outputs]
tensor_map = graph.tensors()
if tensor_map is None:
tensor_map = graph.tensors()
tensor_map = misc.unique_dicts(tensor_map, subgraph_tensor_map)
else:
tensor_map = misc.combine_dicts(tensor_map, subgraph_tensor_map)
initializer = [
OnnxExporter.export_tensor_proto(tensor)
for tensor in tensor_map.values()
Expand Down Expand Up @@ -327,7 +334,19 @@ def export_onnx(graph: Graph, do_type_check=True, **kwargs) -> "onnx.ModelProto"
Returns:
onnx.ModelProto: A corresponding ONNX model.
"""
onnx_graph = OnnxExporter.export_graph(graph, do_type_check=do_type_check)
sub_graphs = graph.subgraphs(recursive=True)

graph_constants_list = []
for sub_graph in sub_graphs:
graph_constants = {name: tensor for name, tensor in sub_graph.tensors().items() if isinstance(tensor, Constant)}
graph_constants_list.append(graph_constants)

if len(graph_constants_list) == 0:
intersection = None
else:
intersection = {k: graph_constants_list[0][k] for k in graph_constants_list[0] if all(k in d for d in graph_constants_list[1:])}

onnx_graph = OnnxExporter.export_graph(graph, tensor_map=intersection, subgraph_tensor_map=intersection, do_type_check=do_type_check)
onnx_functions = [OnnxExporter.export_function(func) for func in graph.functions]
kwargs["functions"] = onnx_functions

Expand Down
2 changes: 1 addition & 1 deletion onnxslim/onnx_graphsurgeon/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,7 @@ def all_tensors_const(tensors):
graph_constants.update({out.name: out for out in node.outputs})
return graph_constants

graph_constants = {name: tensor for name, tensor in clone_tensors.items() if (isinstance(tensor, Constant) and len(tensor.outputs) == 1)}
graph_constants = {name: tensor for name, tensor in clone_tensors.items() if (isinstance(tensor, Constant) and not any([t.op == "Gather" for t in tensor.outputs]))}
graph_constants = update_foldable_outputs(graph_constants)

# Pass 4: Shape Folding
Expand Down
12 changes: 12 additions & 0 deletions onnxslim/onnx_graphsurgeon/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def combine_dicts(dict0, dict1):
return combined


def unique_dicts(dict0, dict1):
"""
Substract two dictionaries. Values in the second will be substracted from the first.
"""
if not dict1:
return dict0

unique_dict = {k: v for k, v in dict0.items() if k not in dict1}

return unique_dict


def is_dynamic_dimension(dim):
"""Check if a dimension is dynamic (non-integer or negative)."""
return not isinstance(dim, int) or dim < 0
Expand Down
41 changes: 41 additions & 0 deletions onnxslim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,44 @@ def check_result(raw_onnx_output, slimmed_onnx_output):
logger.warning("Model output mismatch after slimming.")
logger.warning("Please check the model carefully.")
return


data_type_sizes = {
onnx.TensorProto.FLOAT: 4,
onnx.TensorProto.DOUBLE: 8,
onnx.TensorProto.INT32: 4,
onnx.TensorProto.INT64: 8,
onnx.TensorProto.UINT8: 1,
onnx.TensorProto.INT8: 1,
onnx.TensorProto.UINT16: 2,
onnx.TensorProto.INT16: 2,
onnx.TensorProto.BOOL: 1,
}


def calculate_tensor_size(tensor):
shape = tensor.dims
num_elements = np.prod(shape) if shape else 0
element_size = data_type_sizes.get(tensor.data_type, 0)
return num_elements * element_size


def get_model_size_and_initializer_size(model):
initializer_size = 0
for tensor in model.graph.initializer:
tensor_size = calculate_tensor_size(tensor)
initializer_size += tensor_size

print('model size', model.ByteSize())
print('initializer size', initializer_size)


def get_model_subgraph_size(model):
graph = model.graph
for node in graph.node:
for attr in node.attribute:
ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}
if attr.type in ATTR_TYPE_MAPPING:
attr_str = ATTR_TYPE_MAPPING[attr.type]
if attr_str == "GRAPH":
print('subgraph', attr.g.ByteSize())

0 comments on commit f7774a4

Please sign in to comment.