Fix q8 quantization for models > 2GB
xenova committed Jul 4, 2024
1 parent b411e9f commit 04a334a
Showing 1 changed file with 27 additions and 12 deletions.
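Why this fixes models larger than 2GB (a reading of the change, not stated in the commit message itself): protobuf caps a single serialized message at 2GB, so ONNX models above that size must be written in the external-data format. The old code handed quantize_dynamic file paths and let it save through the plain protobuf path, which fails at that size; the new code quantizes the already-loaded ModelProto with ONNXQuantizer and saves via optimum's check_and_save_model, which can fall back to external data. A minimal sketch of external-data saving with the onnx package (file names are hypothetical):

import onnx

model = onnx.load("model.onnx")  # hypothetical input path

# Weights beyond the 2GB protobuf cap must live outside the .onnx file.
onnx.save(
    model,
    "model_quantized.onnx",
    save_as_external_data=True,           # write tensors to a side file
    all_tensors_to_one_file=True,
    location="model_quantized.onnx_data",
)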
scripts/quantize.py
@@ -10,7 +10,9 @@
 from transformers import HfArgumentParser
 from optimum.onnx.graph_transformations import check_and_save_model
 
-from onnxruntime.quantization import quantize_dynamic, QuantType
+from onnxruntime.quantization import QuantType, QuantizationMode
+from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer
+from onnxruntime.quantization.registry import IntegerOpsRegistry
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
 from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
 from onnxconverter_common import float16
@@ -125,20 +127,33 @@ def quantize_q8(
 ):
     """
     Quantize the weights of the model from float32 to int8/uint8
+
+    Uses unsigned ints for activation values, signed ints for weights, per
+    https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
+    it is faster on most CPU architectures
     """
 
-    # Uses unsigned ints for activation values, signed ints for weights, per
-    # https://onnxruntime.ai/docs/performance/quantization.html#data-type-selection
-    # it is faster on most CPU architectures
-    quantize_dynamic(
-        model_input=model,
-        model_output=save_path,
-        weight_type=weight_type,
-        per_channel=per_channel,
-        reduce_range=reduce_range,
-        extra_options=dict(EnableSubgraph=True),
+    quantizer = ONNXQuantizer(
+        model,
+        per_channel,
+        reduce_range,
+        mode=QuantizationMode.IntegerOps,
+        static=False,
+        weight_qType=weight_type,
+        activation_qType=QuantType.QUInt8,  # dynamic activation only supports uint8
+        tensors_range=None,
+        nodes_to_quantize=[],
+        nodes_to_exclude=[],
+        op_types_to_quantize=list(IntegerOpsRegistry.keys()),
+        extra_options=dict(
+            EnableSubgraph=True,
+            MatMulConstBOnly=True,
+        ),
     )
+
+    quantizer.quantize_model()
+    check_and_save_model(quantizer.model, save_path)
 
 
 def quantize_fp16(
     model: onnx.ModelProto,
@@ -295,7 +310,7 @@ def main():
         weight_type = QuantType.QUInt8
 
     quantize_q8(
-        model_path,
+        model,
         save_path,
         per_channel=quantization_args.per_channel,
         reduce_range=quantization_args.reduce_range,
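With the signature otherwise unchanged, the updated quantize_q8 now expects its first argument to be the loaded ModelProto rather than a file path. A usage sketch under assumed file names and flag values (none of these come from the commit):

import onnx
from onnxruntime.quantization import QuantType

model = onnx.load("model.onnx")  # hypothetical path; may exceed 2GB

quantize_q8(
    model,                         # in-memory ModelProto, not a file path
    "model_quantized.onnx",        # hypothetical output path
    per_channel=True,
    reduce_range=False,
    weight_type=QuantType.QInt8,   # signed weights, per the docstring
)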
