diff --git a/README.md b/README.md index e71a4c6..c41a771 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ usage: --input_op_names INPUT_OP_NAMES --output_op_names OUTPUT_OP_NAMES [--output_onnx_file_path OUTPUT_ONNX_FILE_PATH] + [--non_verbose] optional arguments: -h, --help @@ -43,16 +44,17 @@ optional arguments: --input_op_names INPUT_OP_NAMES List of OP names to specify for the input layer of the model. - Specify the name of the OP, separated by commas. - e.g. --input_op_names aaa,bbb,ccc + e.g. --input_op_names aaa bbb ccc --output_op_names OUTPUT_OP_NAMES List of OP names to specify for the output layer of the model. - Specify the name of the OP, separated by commas. - e.g. --output_op_names ddd,eee,fff + e.g. --output_op_names ddd eee fff --output_onnx_file_path OUTPUT_ONNX_FILE_PATH Output onnx file path. If not specified, extracted.onnx is output. + + --non_verbose + Do not show all information logs. Only error logs are displayed. ``` ## 3. In-script Usage @@ -68,19 +70,18 @@ extraction( output_op_names: List[str], input_onnx_file_path: Union[str, NoneType] = '', onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None, - output_onnx_file_path: Union[str, NoneType] = '' + output_onnx_file_path: Union[str, NoneType] = '', + non_verbose: Optional[bool] = False ) -> onnx.onnx_ml_pb2.ModelProto Parameters ---------- input_op_names: List[str] List of OP names to specify for the input layer of the model. - Specify the name of the OP, separated by commas. e.g. ['aaa','bbb','ccc'] output_op_names: List[str] List of OP names to specify for the output layer of the model. - Specify the name of the OP, separated by commas. e.g. ['ddd','eee','fff'] input_onnx_file_path: Optional[str] @@ -98,6 +99,10 @@ extraction( If not specified, .onnx is not output. Default: '' + non_verbose: Optional[bool] + Do not show all information logs. Only error logs are displayed. + Default: False + Returns ------- extracted_graph: onnx.ModelProto @@ -108,8 +113,8 @@ extraction( ```bash $ sne4onnx \ --input_onnx_file_path input.onnx \ ---input_op_names aaa,bbb,ccc \ ---output_op_names ddd,eee,fff \ +--input_op_names aaa bbb ccc \ +--output_op_names ddd eee fff \ --output_onnx_file_path output.onnx ``` @@ -147,8 +152,8 @@ extracted_graph = extraction( ```bash $ sne4onnx \ --input_onnx_file_path hitnet_sf_finalpass_720x1280.onnx \ ---input_op_names 0,1 \ ---output_op_names 497,785 \ +--input_op_names 0 1 \ +--output_op_names 497 785 \ --output_onnx_file_path hitnet_sf_finalpass_720x960_head.onnx ``` diff --git a/sne4onnx/__init__.py b/sne4onnx/__init__.py index 7283581..be33c5b 100644 --- a/sne4onnx/__init__.py +++ b/sne4onnx/__init__.py @@ -1,3 +1,3 @@ from sne4onnx.onnx_network_extraction import extraction, main -__version__ = '1.0.5' +__version__ = '1.0.6' diff --git a/sne4onnx/onnx_network_extraction.py b/sne4onnx/onnx_network_extraction.py index 546a4d9..93c100a 100644 --- a/sne4onnx/onnx_network_extraction.py +++ b/sne4onnx/onnx_network_extraction.py @@ -3,6 +3,7 @@ import sys from argparse import ArgumentParser import onnx +import onnx_graphsurgeon as gs from typing import Optional, List class Color: @@ -37,6 +38,7 @@ def extraction( input_onnx_file_path: Optional[str] = '', onnx_graph: Optional[onnx.ModelProto] = None, output_onnx_file_path: Optional[str] = '', + non_verbose: Optional[bool] = False, ) -> onnx.ModelProto: """ @@ -44,12 +46,10 @@ def extraction( ---------- input_op_names: List[str] List of OP names to specify for the input layer of the model.\n\ - Specify the name of the OP, separated by commas.\n\ e.g. ['aaa','bbb','ccc'] output_op_names: List[str] List of OP names to specify for the output layer of the model.\n\ - Specify the name of the OP, separated by commas.\n\ e.g. ['ddd','eee','fff'] input_onnx_file_path: Optional[str] @@ -67,6 +67,10 @@ def extraction( If not specified, .onnx is not output.\n\ Default: '' + non_verbose: Optional[bool] + Do not show all information logs. Only error logs are displayed.\n\ + Default: False + Returns ------- extracted_graph: onnx.ModelProto @@ -80,19 +84,55 @@ def extraction( ) sys.exit(1) + if not input_op_names: + print( + f'{Color.RED}ERROR:{Color.RESET} '+ + f'One or more input_op_names must be specified.' + ) + sys.exit(1) + + if not output_op_names: + print( + f'{Color.RED}ERROR:{Color.RESET} '+ + f'One or more output_op_names must be specified.' + ) + sys.exit(1) + # Load graph = None if not onnx_graph: - graph = onnx.load(input_onnx_file_path) - else: - graph = onnx_graph - - # Extract - extractor = onnx.utils.Extractor(graph) - extracted_graph = extractor.extract_model( - input_op_names, - output_op_names, - ) + onnx_graph = onnx.load(input_onnx_file_path) + onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph) + graph = gs.import_onnx(onnx_graph) + graph.cleanup().toposort() + + # Extraction of input OP and output OP + graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names] + graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names] + + # Init graph INPUT/OUTPUT + graph.inputs.clear() + graph.outputs.clear() + + # Update graph INPUT/OUTPUT + graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape] + graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs] + + # Cleanup + graph.cleanup().toposort() + + # Shape Estimation + extracted_graph = None + try: + extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph)) + except Exception as e: + extracted_graph = gs.export_onnx(graph) + if not non_verbose: + print( + f'{Color.YELLOW}WARNING:{Color.RESET} '+ + 'The input shape of the next OP does not match the output shape. '+ + 'Be sure to open the .onnx file to verify the certainty of the geometry.' + ) # Save if output_onnx_file_path: @@ -112,20 +152,20 @@ def main(): parser.add_argument( '--input_op_names', type=str, + nargs='+', required=True, help="\ List of OP names to specify for the input layer of the model. \ - Specify the name of the OP, separated by commas. \ - e.g. --input_op_names aaa,bbb,ccc" + e.g. --input_op_names aaa bbb ccc" ) parser.add_argument( '--output_op_names', type=str, + nargs='+', required=True, help="\ List of OP names to specify for the output layer of the model. \ - Specify the name of the OP, separated by commas. \ - e.g. --output_op_names ddd,eee,fff" + e.g. --output_op_names ddd eee fff" ) parser.add_argument( '--output_onnx_file_path', @@ -133,17 +173,30 @@ def main(): default='extracted.onnx', help='Output onnx file path. If not specified, extracted.onnx is output.' ) + parser.add_argument( + '--non_verbose', + action='store_true', + help='Do not show all information logs. Only error logs are displayed.' + ) args = parser.parse_args() - input_op_names = args.input_op_names.strip(' ,').replace(' ','').split(',') - output_op_names = args.output_op_names.strip(' ,').replace(' ','').split(',') + input_onnx_file_path = args.input_onnx_file_path + input_op_names = args.input_op_names + output_op_names = args.output_op_names + output_onnx_file_path = args.output_onnx_file_path + non_verbose = args.non_verbose + + # Load + onnx_graph = onnx.load(input_onnx_file_path) # Model extraction extracted_graph = extraction( - input_onnx_file_path=args.input_onnx_file_path, + input_onnx_file_path=None, input_op_names=input_op_names, output_op_names=output_op_names, - output_onnx_file_path=args.output_onnx_file_path, + onnx_graph=onnx_graph, + output_onnx_file_path=output_onnx_file_path, + non_verbose=non_verbose, )