Skip to content

Commit

Permalink
Support for onnx.ModelProto input
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Apr 11, 2022
1 parent e04cdd0 commit 7d7be5b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 19 deletions.
30 changes: 25 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ A very simple tool for situations where optimization with onnx-simplifier would

# Key concept
- [x] If INPUT OP name and OUTPUT OP name are specified, the onnx graph within the range of the specified OP name is extracted and .onnx is generated.
- [ ] Change backend to onnx-graphsurgeon so that onnx.ModelProto can be specified as input.
- [x] Change backend to `onnx.utils.Extractor.extract_model` so that onnx.ModelProto can be specified as input.

## 1. Setup
### 1-1. HostPC
Expand Down Expand Up @@ -75,7 +75,8 @@ extraction(
input_onnx_file_path: str,
input_op_names: List[str],
output_op_names: List[str],
output_onnx_file_path: Union[str, NoneType] = ''
output_onnx_file_path: Union[str, NoneType] = '',
onnx_graph: Union[onnx.onnx_ml_pb2.ModelProto, NoneType] = None
) -> onnx.onnx_ml_pb2.ModelProto

Parameters
Expand All @@ -98,6 +99,11 @@ extraction(
If not specified, .onnx is not output.
Default: ''

onnx_graph: Optional[onnx.ModelProto]
onnx.ModelProto.
Either input_onnx_file_path or onnx_graph must be specified.
onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph.

Returns
-------
extracted_graph: onnx.ModelProto
Expand All @@ -114,6 +120,7 @@ $ sne4onnx \
```

## 5. In-script Execution
### 5-1. Use ONNX files
```python
from sne4onnx import extraction

Expand All @@ -124,6 +131,17 @@ extracted_graph = extraction(
output_onnx_file_path='output.onnx',
)
```
### 5-2. Use onnx.ModelProto
```python
from sne4onnx import extraction

extracted_graph = extraction(
input_op_names=['aaa', 'bbb', 'ccc'],
output_op_names=['ddd', 'eee', 'fff'],
output_onnx_file_path='output.onnx',
onnx_graph=graph,
)
```

## 6. Samples
### 6-1. Pre-extraction
Expand All @@ -147,6 +165,8 @@ $ sne4onnx \

## 7. Reference
1. https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md
2. https://github.com/PINTO0309/snd4onnx
3. https://github.com/PINTO0309/scs4onnx
4. https://github.com/PINTO0309/snc4onnx
2. https://docs.nvidia.com/deeplearning/tensorrt/onnx-graphsurgeon/docs/index.html
3. https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon
4. https://github.com/PINTO0309/snd4onnx
5. https://github.com/PINTO0309/scs4onnx
6. https://github.com/PINTO0309/snc4onnx
2 changes: 1 addition & 1 deletion sne4onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from sne4onnx.onnx_network_extraction import extraction, main

__version__ = '1.0.3'
__version__ = '1.0.4'
40 changes: 27 additions & 13 deletions sne4onnx/onnx_network_extraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#! /usr/bin/env python

import os
import sys
from argparse import ArgumentParser
import onnx
from typing import Optional, List
Expand Down Expand Up @@ -36,6 +36,7 @@ def extraction(
input_op_names: List[str],
output_op_names: List[str],
output_onnx_file_path: Optional[str] = '',
onnx_graph: Optional[onnx.ModelProto] = None,
) -> onnx.ModelProto:

"""
Expand All @@ -59,28 +60,41 @@ def extraction(
If not specified, .onnx is not output.\n\
Default: ''
onnx_graph: Optional[onnx.ModelProto]
onnx.ModelProto.\n\
Either input_onnx_file_path or onnx_graph must be specified.\n\
onnx_graph If specified, ignore input_onnx_file_path and process onnx_graph.
Returns
-------
extracted_graph: onnx.ModelProto
Extracted onnx ModelProto
"""

tmp_onnx_file = ''
if not output_onnx_file_path:
tmp_onnx_file = 'extracted.onnx'
if not input_onnx_file_path and not onnx_graph:
print(
f'{Color.RED}ERROR:{Color.RESET} '+
f'One of input_onnx_file_path or onnx_graph must be specified.'
)
sys.exit(1)

# Load
graph = None
if not onnx_graph:
graph = onnx.load(input_onnx_file_path)
else:
tmp_onnx_file = output_onnx_file_path
graph = onnx_graph

onnx.utils.extract_model(
input_onnx_file_path,
tmp_onnx_file,
# Extract
extractor = onnx.utils.Extractor(graph)
extracted_graph = extractor.extract_model(
input_op_names,
output_op_names
output_op_names,
)

extracted_graph = onnx.load(tmp_onnx_file)
if not output_onnx_file_path:
os.remove(tmp_onnx_file)
# Save
if output_onnx_file_path:
onnx.save(extracted_graph, output_onnx_file_path)

return extracted_graph

Expand Down Expand Up @@ -132,4 +146,4 @@ def main():


if __name__ == '__main__':
main()
main()

0 comments on commit 7d7be5b

Please sign in to comment.