diff --git a/sne4onnx/__init__.py b/sne4onnx/__init__.py index 8e0584a..573cb3f 100644 --- a/sne4onnx/__init__.py +++ b/sne4onnx/__init__.py @@ -1,3 +1,3 @@ from sne4onnx.onnx_network_extraction import extraction, main -__version__ = '1.0.10' +__version__ = '1.0.11' diff --git a/sne4onnx/onnx_network_extraction.py b/sne4onnx/onnx_network_extraction.py index 03c37ef..e63d739 100644 --- a/sne4onnx/onnx_network_extraction.py +++ b/sne4onnx/onnx_network_extraction.py @@ -31,6 +31,12 @@ class Color: BG_DEFAULT = '\033[49m' RESET = '\033[0m' +ONNX_STANDARD_DOMAINS = [ + 'ai.onnx', + 'ai.onnx.ml', + '', +] + def extraction( input_op_names: List[str], @@ -103,20 +109,39 @@ def extraction( if not onnx_graph: onnx_graph = onnx.load(input_onnx_file_path) onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph) + + # Acquisition of Node with custom domain + custom_domain_check_onnx_nodes = [] + custom_domain_check_onnx_nodes = \ + custom_domain_check_onnx_nodes + \ + [ + node for node in onnx_graph.graph.node \ + if node.domain not in ONNX_STANDARD_DOMAINS + ] + graph = gs.import_onnx(onnx_graph) graph.cleanup().toposort() + # Check if Graph contains a custom domain (custom module) + contains_custom_domain = len( + [ + domain \ + for domain in graph.import_domains \ + if domain.domain not in ONNX_STANDARD_DOMAINS + ] + ) > 0 + # Extraction of input OP and output OP graph_node_inputs = [ - graph_nodes \ - for graph_nodes in graph.nodes \ - for graph_nodes_input in graph_nodes.inputs \ + graph_node \ + for graph_node in graph.nodes \ + for graph_nodes_input in graph_node.inputs \ if graph_nodes_input.name in input_op_names ] graph_node_outputs = [ - graph_nodes \ - for graph_nodes in graph.nodes \ - for graph_nodes_output in graph_nodes.outputs \ + graph_node \ + for graph_node in graph.nodes \ + for graph_nodes_output in graph_node.outputs \ if graph_nodes_output.name in output_op_names ] @@ -128,8 +153,10 @@ def extraction( input_tmp = [] for graph_node in graph_node_inputs: for graph_node_input in graph_node.inputs: - # if graph_node_input.shape and graph_node_input.name not in [i.name for i in input_tmp]: - if graph_node_input.shape and graph_node_input not in [i for i in input_tmp]: + if graph_node_input.shape \ + and graph_node_input not in [i for i in input_tmp] \ + and hasattr(graph_node_input, 'name') \ + and graph_node_input.name in [i for i in input_op_names]: input_tmp.append(graph_node_input) graph.inputs = input_tmp @@ -155,10 +182,21 @@ def extraction( 'Be sure to open the .onnx file to verify the certainty of the geometry.' ) + ## 5. Restore a node's custom domain + if contains_custom_domain: + extracted_graph_nodes = extracted_graph.graph.node + for extracted_graph_node in extracted_graph_nodes: + for custom_domain_check_onnx_node in custom_domain_check_onnx_nodes: + if extracted_graph_node.name == custom_domain_check_onnx_node.name: + extracted_graph_node.domain = custom_domain_check_onnx_node.domain + # Save if output_onnx_file_path: onnx.save(extracted_graph, output_onnx_file_path) + if not non_verbose: + print(f'{Color.GREEN}INFO:{Color.RESET} Finish!') + return extracted_graph