-
Notifications
You must be signed in to change notification settings - Fork 23
/
quantize_model.py
63 lines (50 loc) · 1.99 KB
/
quantize_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
import onnx
import quantize
def _decode_json_option(parser, flag, raw):
    """Decode the JSON text supplied for *flag*.

    Returns None when *raw* is empty (the option was not given). Malformed
    JSON is reported through ``parser.error``, which prints the usage
    message and exits with status 2 instead of surfacing a raw traceback.
    """
    if not raw:
        return None
    try:
        return json.loads(raw)
    except ValueError as err:  # json.JSONDecodeError subclasses ValueError
        parser.error('{} is not valid JSON: {}'.format(flag, err))


def main(argv=None):
    """Quantize an ONNX model from the command line.

    Parses the command-line options, loads the model named by the ``model``
    positional argument with ``onnx.load``, passes it to
    ``quantize.quantize`` together with the remaining options as keyword
    arguments, and writes the result to ``output`` with ``onnx.save``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behavior.

    Raises:
        SystemExit: On invalid command-line usage, including malformed JSON
            in ``--input_quantization_params`` or
            ``--output_quantization_params``.
    """
    parser = argparse.ArgumentParser(
        description='Quantize model with specified parameters')
    parser.add_argument('--no_per_channel', '-t',
                        action='store_true', default=False)
    parser.add_argument('--nbits', type=int, default=8)
    parser.add_argument('--quantization_mode', default='Integer',
                        choices=('Integer', 'QLinear'))
    parser.add_argument('--static', '-s', action='store_true', default=False)
    parser.add_argument('--asymmetric_input_types',
                        action='store_true', default=False)
    parser.add_argument('--input_quantization_params', default='')
    parser.add_argument('--output_quantization_params', default='')
    parser.add_argument('model')
    parser.add_argument('output')
    args = parser.parse_args(argv)

    # Validate the JSON options first so bad input fails fast, before the
    # model file is read or the quantize module is touched.
    args.input_quantization_params = _decode_json_option(
        parser, '--input_quantization_params',
        args.input_quantization_params)
    args.output_quantization_params = _decode_json_option(
        parser, '--output_quantization_params',
        args.output_quantization_params)

    # The CLI exposes the negated flag; quantize() is given `per_channel`.
    args.per_channel = not args.no_per_channel
    del args.no_per_channel

    if args.quantization_mode == 'QLinear':
        args.quantization_mode = quantize.QuantizationMode.QLinearOps
    else:
        args.quantization_mode = quantize.QuantizationMode.IntegerOps

    # Load the onnx model
    model_file = args.model
    model = onnx.load(model_file)
    # `model` and `output` are file paths, not quantize() options, so they
    # are removed before the namespace is forwarded as keyword arguments.
    del args.model
    output_file = args.output
    del args.output

    # Quantize
    print('Quantize config: {}'.format(vars(args)))
    quantized_model = quantize.quantize(model, **vars(args))

    print('Saving "{}" to "{}"'.format(model_file, output_file))
    # Save the quantized model
    onnx.save(quantized_model, output_file)
# Run the command-line interface only when this file is executed as a
# script, so importing the module has no side effects.
if __name__ == "__main__":
    main()