inference_on_device.py
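
"""Run ONNX model inference on a target device via onnxruntime.

Supports two targets: TI (TIDLExecutionProvider with CPU fallback, or plain
CPU when given a bare .onnx file) and Jetson (TensorrtExecutionProvider with
FP16, plus INT8 when the model contains QDQ nodes). Reads pickled input
samples from a directory, writes pickled results, and prints timing stats
every 100 files.
"""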
import argparse
import os
import pickle
from pathlib import Path
from time import perf_counter

import numpy as np
import onnx
import onnxruntime as ort


def has_quant_dequant_node(onnx_model: onnx.ModelProto) -> bool:
    """Checks if ONNX model has quant-dequant nodes."""
    for node in onnx_model.graph.node:
        if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
            return True
    return False
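
# Example usage of the helper above (a minimal sketch; "model.onnx" is a
# hypothetical path, not part of this repo):
#   model = onnx.load("model.onnx")
#   if has_quant_dequant_node(model):
#       print("model is QDQ-quantized")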


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_path",
        type=str,
        required=True,
        help="path to the model: either an .onnx file or a compiled artifacts directory",
    )
    parser.add_argument(
        "-i",
        "--input_data_dir",
        type=str,
        default="./pickle_data",
        help="path to the input data directory",
    )
    parser.add_argument(
        "-o",
        "--output_data_dir",
        type=str,
        default="./out_pickle",
        help="path to the output data directory",
    )
    parser.add_argument("-d", "--debug_level", type=int, default=0, help="debug level for the TI device")
    parser.add_argument(
        "--device", required=True, choices=["jetson", "ti"], help="device name; only 'jetson' and 'ti' are supported"
    )
    args = parser.parse_args()

    model_path = Path(args.model_path)
    if args.device == "ti":
        if model_path.is_dir():
            # A directory means pre-compiled TIDL artifacts: offload supported
            # subgraphs to the TIDL accelerator, fall back to CPU for the rest.
            providers = ["TIDLExecutionProvider", "CPUExecutionProvider"]
            tidl_provider_options = {
                "platform": "J7",
                "version": "7.2",
                "debug_level": args.debug_level,
                "max_num_subgraphs": 16,
                "ti_internal_nc_flag": 1601,
                "tidl_tools_path": os.environ["TIDL_TOOLS_PATH"],
                "artifacts_folder": str(model_path),
            }
            provider_options = [tidl_provider_options, {}]
            onnx_path = list(model_path.glob("*.onnx"))
            if len(onnx_path) == 0:
                raise ValueError("cannot find model onnx in artifacts directory")
            if len(onnx_path) > 1:
                raise ValueError("artifacts directory must contain only one onnx")
            onnx_path = onnx_path[0]
        else:
            # A plain .onnx file runs on CPU only.
            providers = ["CPUExecutionProvider"]
            provider_options = [{}]
            onnx_path = model_path
        ort_session_options = ort.SessionOptions()
        session = ort.InferenceSession(
            str(onnx_path),
            providers=providers,
            provider_options=provider_options,
            sess_options=ort_session_options,
        )
    elif args.device == "jetson":
        onnx_model = onnx.load(args.model_path)
        trt_ep_options = {"trt_fp16_enable": True}
        if has_quant_dequant_node(onnx_model):
            # QDQ nodes indicate a quantized model, so let TensorRT build INT8 engines.
            trt_ep_options["trt_int8_enable"] = True
        session = ort.InferenceSession(
            args.model_path,
            providers=[
                ("TensorrtExecutionProvider", trt_ep_options),
            ],
        )
    else:
        # Guard for completeness; argparse `choices` already restricts --device.
        raise ValueError("Only 'jetson' and 'ti' are supported for --device argument.")

    inputs = session.get_inputs()
    outputs = session.get_outputs()
    if len(inputs) != 1:
        raise NotImplementedError("Case with multiple inputs is not implemented")
    input_name = inputs[0].name
    output_names = [output.name for output in outputs]

    input_data_dir = Path(args.input_data_dir)
    output_data_dir = Path(args.output_data_dir)
    if input_data_dir == output_data_dir:
        raise ValueError("--input_data_dir and --output_data_dir cannot be the same")
    output_data_dir.mkdir(parents=True, exist_ok=True)

    stats = []
    t_0 = perf_counter()
    for i, input_data_path in enumerate(input_data_dir.glob("*.pickle")):
        with input_data_path.open("rb") as data_file:
            data = pickle.load(data_file)
        # Each pickle is assumed to hold a single unbatched sample; add a batch axis.
        data = np.expand_dims(data, axis=0)
        t_1 = perf_counter()
        result = session.run(output_names=output_names, input_feed={input_name: data})
        t_2 = perf_counter()
        output_data_path = output_data_dir / f"result_{input_data_path.name}"
        with output_data_path.open("wb") as result_file:
            pickle.dump(result, result_file)
        stats.append(t_2 - t_1)
        if (i + 1) % 100 == 0:
            total_time = perf_counter() - t_0
            total_inference_time = sum(stats)
            print("TIME STATS:")
            print(f" AVG TIME PER FILE = {total_time * 1000.0 / len(stats)} ms")
            print(f" AVG TIME PER RUN = {total_inference_time * 1000.0 / len(stats)} ms")
            print(f" TOTAL TIME = {total_time} s")
            print()
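
# Example invocations (a sketch; all paths and env values below are hypothetical,
# only the flags and the TIDL_TOOLS_PATH variable come from this script):
#
#   TIDL_TOOLS_PATH=/path/to/tidl_tools python inference_on_device.py \
#       --model_path ./artifacts --device ti -i ./pickle_data -o ./out_pickle
#
#   python inference_on_device.py --model_path model.onnx --device jetson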