Skip to content

Commit

Permalink
Significantly faster processing
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed May 8, 2022
1 parent 9499494 commit a677e1f
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 32 deletions.
27 changes: 16 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ usage:
--input_op_names INPUT_OP_NAMES
--output_op_names OUTPUT_OP_NAMES
[--output_onnx_file_path OUTPUT_ONNX_FILE_PATH]
[--non_verbose]

optional arguments:
-h, --help
Expand All @@ -43,16 +44,17 @@ optional arguments:

--input_op_names INPUT_OP_NAMES
List of OP names to specify for the input layer of the model.
Specify the name of the OP, separated by commas.
e.g. --input_op_names aaa,bbb,ccc
e.g. --input_op_names aaa bbb ccc

--output_op_names OUTPUT_OP_NAMES
List of OP names to specify for the output layer of the model.
Specify the name of the OP, separated by commas.
e.g. --output_op_names ddd,eee,fff
e.g. --output_op_names ddd eee fff

--output_onnx_file_path OUTPUT_ONNX_FILE_PATH
Output onnx file path. If not specified, extracted.onnx is output.

--non_verbose
Do not show all information logs. Only error logs are displayed.
```

## 3. In-script Usage
Expand All @@ -68,19 +70,18 @@ extraction(
output_op_names: List[str],
input_onnx_file_path: Union[str, NoneType] = '',
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None,
output_onnx_file_path: Union[str, NoneType] = ''
output_onnx_file_path: Union[str, NoneType] = '',
non_verbose: Optional[bool] = False
) -> onnx.onnx_ml_pb2.ModelProto

Parameters
----------
input_op_names: List[str]
List of OP names to specify for the input layer of the model.
Specify the name of the OP, separated by commas.
e.g. ['aaa','bbb','ccc']

output_op_names: List[str]
List of OP names to specify for the output layer of the model.
Specify the name of the OP, separated by commas.
e.g. ['ddd','eee','fff']

input_onnx_file_path: Optional[str]
Expand All @@ -98,6 +99,10 @@ extraction(
If not specified, .onnx is not output.
Default: ''

non_verbose: Optional[bool]
Do not show all information logs. Only error logs are displayed.
Default: False

Returns
-------
extracted_graph: onnx.ModelProto
Expand All @@ -108,8 +113,8 @@ extraction(
```bash
$ sne4onnx \
--input_onnx_file_path input.onnx \
--input_op_names aaa,bbb,ccc \
--output_op_names ddd,eee,fff \
--input_op_names aaa bbb ccc \
--output_op_names ddd eee fff \
--output_onnx_file_path output.onnx
```

Expand Down Expand Up @@ -147,8 +152,8 @@ extracted_graph = extraction(
```bash
$ sne4onnx \
--input_onnx_file_path hitnet_sf_finalpass_720x1280.onnx \
--input_op_names 0,1 \
--output_op_names 497,785 \
--input_op_names 0 1 \
--output_op_names 497 785 \
--output_onnx_file_path hitnet_sf_finalpass_720x960_head.onnx
```

Expand Down
2 changes: 1 addition & 1 deletion sne4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sne4onnx.onnx_network_extraction import extraction, main

__version__ = '1.0.5'
__version__ = '1.0.6'
93 changes: 73 additions & 20 deletions sne4onnx/onnx_network_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from argparse import ArgumentParser
import onnx
import onnx_graphsurgeon as gs
from typing import Optional, List

class Color:
Expand Down Expand Up @@ -37,19 +38,18 @@ def extraction(
input_onnx_file_path: Optional[str] = '',
onnx_graph: Optional[onnx.ModelProto] = None,
output_onnx_file_path: Optional[str] = '',
non_verbose: Optional[bool] = False,
) -> onnx.ModelProto:

"""
Parameters
----------
input_op_names: List[str]
List of OP names to specify for the input layer of the model.\n\
Specify the name of the OP, separated by commas.\n\
e.g. ['aaa','bbb','ccc']
output_op_names: List[str]
List of OP names to specify for the output layer of the model.\n\
Specify the name of the OP, separated by commas.\n\
e.g. ['ddd','eee','fff']
input_onnx_file_path: Optional[str]
Expand All @@ -67,6 +67,10 @@ def extraction(
If not specified, .onnx is not output.\n\
Default: ''
non_verbose: Optional[bool]
Do not show all information logs. Only error logs are displayed.\n\
Default: False
Returns
-------
extracted_graph: onnx.ModelProto
Expand All @@ -80,19 +84,55 @@ def extraction(
)
sys.exit(1)

if not input_op_names:
print(
f'{Color.RED}ERROR:{Color.RESET} '+
f'One or more input_op_names must be specified.'
)
sys.exit(1)

if not output_op_names:
print(
f'{Color.RED}ERROR:{Color.RESET} '+
f'One or more output_op_names must be specified.'
)
sys.exit(1)

# Load
graph = None
if not onnx_graph:
graph = onnx.load(input_onnx_file_path)
else:
graph = onnx_graph

# Extract
extractor = onnx.utils.Extractor(graph)
extracted_graph = extractor.extract_model(
input_op_names,
output_op_names,
)
onnx_graph = onnx.load(input_onnx_file_path)
onnx_graph = onnx.shape_inference.infer_shapes(onnx_graph)
graph = gs.import_onnx(onnx_graph)
graph.cleanup().toposort()

# Extraction of input OP and output OP
graph_node_inputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_input in graph_nodes.inputs if graph_nodes_input.name in input_op_names]
graph_node_outputs = [graph_nodes for graph_nodes in graph.nodes for graph_nodes_output in graph_nodes.outputs if graph_nodes_output.name in output_op_names]

# Init graph INPUT/OUTPUT
graph.inputs.clear()
graph.outputs.clear()

# Update graph INPUT/OUTPUT
graph.inputs = [graph_node_input for graph_node in graph_node_inputs for graph_node_input in graph_node.inputs if graph_node_input.shape]
graph.outputs = [graph_node_output for graph_node in graph_node_outputs for graph_node_output in graph_node.outputs]

# Cleanup
graph.cleanup().toposort()

# Shape Estimation
extracted_graph = None
try:
extracted_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph))
except Exception as e:
extracted_graph = gs.export_onnx(graph)
if not non_verbose:
print(
f'{Color.YELLOW}WARNING:{Color.RESET} '+
'The input shape of the next OP does not match the output shape. '+
'Be sure to open the .onnx file to verify the certainty of the geometry.'
)

# Save
if output_onnx_file_path:
Expand All @@ -112,38 +152,51 @@ def main():
parser.add_argument(
'--input_op_names',
type=str,
nargs='+',
required=True,
help="\
List of OP names to specify for the input layer of the model. \
Specify the name of the OP, separated by commas. \
e.g. --input_op_names aaa,bbb,ccc"
e.g. --input_op_names aaa bbb ccc"
)
parser.add_argument(
'--output_op_names',
type=str,
nargs='+',
required=True,
help="\
List of OP names to specify for the output layer of the model. \
Specify the name of the OP, separated by commas. \
e.g. --output_op_names ddd,eee,fff"
e.g. --output_op_names ddd eee fff"
)
parser.add_argument(
'--output_onnx_file_path',
type=str,
default='extracted.onnx',
help='Output onnx file path. If not specified, extracted.onnx is output.'
)
parser.add_argument(
'--non_verbose',
action='store_true',
help='Do not show all information logs. Only error logs are displayed.'
)
args = parser.parse_args()

input_op_names = args.input_op_names.strip(' ,').replace(' ','').split(',')
output_op_names = args.output_op_names.strip(' ,').replace(' ','').split(',')
input_onnx_file_path = args.input_onnx_file_path
input_op_names = args.input_op_names
output_op_names = args.output_op_names
output_onnx_file_path = args.output_onnx_file_path
non_verbose = args.non_verbose

# Load
onnx_graph = onnx.load(input_onnx_file_path)

# Model extraction
extracted_graph = extraction(
input_onnx_file_path=args.input_onnx_file_path,
input_onnx_file_path=None,
input_op_names=input_op_names,
output_op_names=output_op_names,
output_onnx_file_path=args.output_onnx_file_path,
onnx_graph=onnx_graph,
output_onnx_file_path=output_onnx_file_path,
non_verbose=non_verbose,
)


Expand Down

0 comments on commit a677e1f

Please sign in to comment.