Skip to content

Commit

Permalink
update python code
Browse files Browse the repository at this point in the history
  • Loading branch information
keosu committed Mar 7, 2024
1 parent b01884a commit 0698bbd
Show file tree
Hide file tree
Showing 4 changed files with 1,079 additions and 43 deletions.
249 changes: 214 additions & 35 deletions demos/onnx/basic.ipynb
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ONNX 模型基本信息"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "python"
}
},
"metadata": {},
"outputs": [],
"source": [
"from functools import reduce\n",
Expand Down Expand Up @@ -71,11 +74,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "python"
}
},
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -84,27 +83,13 @@
"def infer_onnx(model_path):\n",
" # Load the ONNX model\n",
" sess = ort.InferenceSession(model_path)\n",
" onnx_to_numpy_dtype = {\n",
" \"tensor(float16)\": np.float16,\n",
" \"tensor(float)\": np.float32,\n",
" \"tensor(double)\": np.float64,\n",
" \"tensor(int8)\": np.int8,\n",
" \"tensor(int16)\": np.int16,\n",
" \"tensor(int32)\": np.int32,\n",
" \"tensor(int64)\": np.int64,\n",
" \"tensor(uint8)\": np.uint8,\n",
" \"tensor(uint16)\": np.uint16,\n",
" \"tensor(uint32)\": np.uint32,\n",
" \"tensor(uint64)\": np.uint64,\n",
" # 'tensor(bool)': np.bool,\n",
" }\n",
"\n",
" \n",
" feeds = {}\n",
" for input in sess.get_inputs():\n",
" # Get input name and shape\n",
" input_name = input.name\n",
" input_shape = input.shape\n",
" numpy_type = onnx_to_numpy_dtype.get(input.type, None)\n",
" numpy_type = TENSOR_TYPE_TO_NP_TYPE[input.type]\n",
" # Generate random input\n",
" input_data = np.random.random_sample(input_shape).astype(numpy_type)\n",
"\n",
Expand All @@ -115,21 +100,215 @@
"\n",
" print([a.shape for a in result])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ONNX 模型构造"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx \n",
"from onnx import helper \n",
"from onnx import TensorProto \n",
" \n",
"# input and output \n",
"a = helper.make_tensor_value_info('a', TensorProto.FLOAT, [10, 10]) \n",
"x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 10]) \n",
"b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [10, 10]) \n",
"output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [10, 10]) \n",
" \n",
"# Mul \n",
"mul = helper.make_node('Mul', ['a', 'x'], ['c']) \n",
" \n",
"# Add \n",
"add = helper.make_node('Add', ['c', 'b'], ['output']) \n",
" \n",
"# graph and model \n",
"graph = helper.make_graph([mul, add], 'linear_func', [a, x, b], [output]) \n",
"model = helper.make_model(graph) \n",
" \n",
"# save model \n",
"onnx.checker.check_model(model) \n",
"print(model) \n",
"onnx.save(model, 'linear_func.onnx')\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnxruntime \n",
"import numpy as np \n",
" \n",
"sess = onnxruntime.InferenceSession('linear_func.onnx') \n",
"a = np.random.rand(10, 10).astype(np.float32) \n",
"b = np.random.rand(10, 10).astype(np.float32) \n",
"x = np.random.rand(10, 10).astype(np.float32) \n",
" \n",
"output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] \n",
" \n",
"assert np.allclose(output, a * x + b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ONNX 模型修改\n",
"\n",
"**NOTE**: 下面的例子将Add改成Sub,一般如果不是类似的OP,直接修改op_type可能会有问题"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx \n",
"model = onnx.load('linear_func.onnx') \n",
" \n",
"node = model.graph.node \n",
"node[1].op_type = 'Sub' \n",
" \n",
"onnx.checker.check_model(model) \n",
"print(model)\n",
"onnx.save(model, 'linear_func_2.onnx')\n",
"\n",
"import onnxruntime \n",
"import numpy as np \n",
" \n",
"sess = onnxruntime.InferenceSession('linear_func_2.onnx') \n",
"a = np.random.rand(10, 10).astype(np.float32) \n",
"b = np.random.rand(10, 10).astype(np.float32) \n",
"x = np.random.rand(10, 10).astype(np.float32) \n",
" \n",
"output = sess.run(['output'], {'a': a, 'b': b, 'x': x})[0] \n",
" \n",
"assert np.allclose(output, a * x - b)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ONNX 子图提取"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import onnx \n",
" \n",
"onnx.utils.extract_model('linear_func_2.onnx', 'mul.onnx', ['a','x'], ['c'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Test\n",
"from pprint import pprint\n",
"from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE,NP_TYPE_TO_TENSOR_TYPE\n",
"\n",
"pprint(TENSOR_TYPE_TO_NP_TYPE)\n",
"pprint(NP_TYPE_TO_TENSOR_TYPE)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓\n",
"┃<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\"> Date </span>┃<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\"> Title </span>┃<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\"> Production Budget </span>┃<span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\"> Box Office </span>┃\n",
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩\n",
"│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> Dev 20, 2019 </span>│ Star Wars: The Rise of Skywalker │ $275,000,000 │ $375,126,118 │\n",
"│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> May 25, 2018 </span>│ <span style=\"color: #800000; text-decoration-color: #800000\">Solo</span>: A Star Wars Story │ $275,000,000 │ $393,151,347 │\n",
"│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> Dec 15, 2017 </span>│ Star Wars Ep. VIII: The Last Jedi │ $262,000,000 │ <span style=\"font-weight: bold\">$1,332,539,889</span> │\n",
"└──────────────┴───────────────────────────────────┴───────────────────┴────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1;32m \u001b[0m\u001b[1;32mDate \u001b[0m\u001b[1;32m \u001b[0m┃\u001b[1;32m \u001b[0m\u001b[1;32mTitle \u001b[0m\u001b[1;32m \u001b[0m┃\u001b[1;32m \u001b[0m\u001b[1;32mProduction Budget\u001b[0m\u001b[1;32m \u001b[0m┃\u001b[1;32m \u001b[0m\u001b[1;32m Box Office\u001b[0m\u001b[1;32m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━┩\n",
"│\u001b[2m \u001b[0m\u001b[2mDev 20, 2019\u001b[0m\u001b[2m \u001b[0m│ Star Wars: The Rise of Skywalker │ $275,000,000 │ $375,126,118 │\n",
"│\u001b[2m \u001b[0m\u001b[2mMay 25, 2018\u001b[0m\u001b[2m \u001b[0m│ \u001b[31mSolo\u001b[0m: A Star Wars Story │ $275,000,000 │ $393,151,347 │\n",
"│\u001b[2m \u001b[0m\u001b[2mDec 15, 2017\u001b[0m\u001b[2m \u001b[0m│ Star Wars Ep. VIII: The Last Jedi │ $262,000,000 │ \u001b[1m$1,332,539,889\u001b[0m │\n",
"└──────────────┴───────────────────────────────────┴───────────────────┴────────────────┘\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from rich.console import Console\n",
"from rich.table import Column, Table\n",
"\n",
"console = Console()\n",
"\n",
"table = Table(show_header=True, header_style=\"bold green\")\n",
"table.add_column(\"Date\", style=\"dim\", width=12)\n",
"table.add_column(\"Title\")\n",
"table.add_column(\"Production Budget\", justify=\"right\")\n",
"table.add_column(\"Box Office\", justify=\"right\")\n",
"table.add_row(\n",
" \"Dev 20, 2019\", \"Star Wars: The Rise of Skywalker\", \"$275,000,000\", \"$375,126,118\"\n",
")\n",
"table.add_row(\n",
" \"May 25, 2018\",\n",
" \"[red]Solo[/red]: A Star Wars Story\",\n",
" \"$275,000,000\",\n",
" \"$393,151,347\",\n",
")\n",
"table.add_row(\n",
" \"Dec 15, 2017\",\n",
" \"Star Wars Ep. VIII: The Last Jedi\",\n",
" \"$262,000,000\",\n",
" \"[bold]$1,332,539,889[/bold]\",\n",
")\n",
"\n",
"console.print(table)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Rust",
"language": "rust",
"name": "rust"
"display_name": "main",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": "rust",
"file_extension": ".rs",
"mimetype": "text/rust",
"name": "Rust",
"pygment_lexer": "rust",
"version": ""
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 0698bbd

Please sign in to comment.