Skip to content

Commit

Permalink
Merge pull request #3 from PINTO0309/support_custom_domain
Browse files Browse the repository at this point in the history
Support for models with custom domains and elimination of critical bugs
  • Loading branch information
PINTO0309 authored Jan 2, 2023
2 parents 0e63f63 + f1d3cbd commit fdb1cae
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 9 deletions.
2 changes: 1 addition & 1 deletion sne4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sne4onnx.onnx_network_extraction import extraction, main

__version__ = '1.0.10'
__version__ = '1.0.11'
54 changes: 46 additions & 8 deletions sne4onnx/onnx_network_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class Color:
BG_DEFAULT = '\033[49m'
RESET = '\033[0m'

ONNX_STANDARD_DOMAINS = [
'ai.onnx',
'ai.onnx.ml',
'',
]


def extraction(
input_op_names: List[str],
Expand Down Expand Up @@ -103,20 +109,39 @@ def extraction(
if not onnx_graph:
onnx_graph = onnx.load(input_onnx_file_path)
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)

# Acquisition of Node with custom domain
custom_domain_check_onnx_nodes = []
custom_domain_check_onnx_nodes = \
custom_domain_check_onnx_nodes + \
[
node for node in onnx_graph.graph.node \
if node.domain not in ONNX_STANDARD_DOMAINS
]

graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()

# Check if Graph contains a custom domain (custom module)
contains_custom_domain = len(
[
domain \
for domain in graph.import_domains \
if domain.domain not in ONNX_STANDARD_DOMAINS
]
) > 0

# Extraction of input OP and output OP
graph_node_inputs = [
graph_nodes \
for graph_nodes in graph.nodes \
for graph_nodes_input in graph_nodes.inputs \
graph_node \
for graph_node in graph.nodes \
for graph_nodes_input in graph_node.inputs \
if graph_nodes_input.name in input_op_names
]
graph_node_outputs = [
graph_nodes \
for graph_nodes in graph.nodes \
for graph_nodes_output in graph_nodes.outputs \
graph_node \
for graph_node in graph.nodes \
for graph_nodes_output in graph_node.outputs \
if graph_nodes_output.name in output_op_names
]

Expand All @@ -128,8 +153,10 @@ def extraction(
input_tmp = []
for graph_node in graph_node_inputs:
for graph_node_input in graph_node.inputs:
# if graph_node_input.shape and graph_node_input.name not in [i.name for i in input_tmp]:
if graph_node_input.shape and graph_node_input not in [i for i in input_tmp]:
if graph_node_input.shape \
and graph_node_input not in [i for i in input_tmp] \
and hasattr(graph_node_input, 'name') \
and graph_node_input.name in [i for i in input_op_names]:
input_tmp.append(graph_node_input)
graph.inputs = input_tmp

Expand All @@ -155,10 +182,21 @@ def extraction(
'Be sure to open the .onnx file to verify the certainty of the geometry.'
)

## 5. Restore a node's custom domain
if contains_custom_domain:
extracted_graph_nodes = extracted_graph.graph.node
for extracted_graph_node in extracted_graph_nodes:
for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes:
if extracted_graph_node.name == custom_domain_check_onnx_node.name:
extracted_graph_node.domain = custom_domain_check_onnx_node.domain

# Save
if output_onnx_file_path:
onnx.save(extracted_graph, output_onnx_file_path)

if not non_verbose:
print(f'{Color.GREEN}INFO:{Color.RESET} Finish!')

return extracted_graph


Expand Down

0 comments on commit fdb1cae

Please sign in to comment.