Add lowerings for PT2E QDQ ops
PiperOrigin-RevId: 702509772
chunnienc authored and copybara-github committed Dec 4, 2024
1 parent a055e62 commit 2df5bd2
Showing 7 changed files with 361 additions and 13 deletions.
48 changes: 48 additions & 0 deletions ai_edge_torch/_convert/test/test_convert.py
@@ -21,10 +21,12 @@
import ai_edge_torch
from ai_edge_torch import config
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch.quantize import pt2e_quantizer
from ai_edge_torch.testing import model_coverage
import numpy as np
import torch
from torch import nn
from torch.ao.quantization import quantize_pt2e
import torchvision

from absl.testing import absltest as googletest
@@ -506,6 +508,52 @@ def forward(self, x):
model_coverage.compare_tflite_torch(edge_model, torch_module, args)
)

def test_convert_resnet18_pt2e_per_layer(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = torch._export.capture_pre_autograd_graph(m, args)

# Step 2: Insert observers or fake quantize modules
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=False)
)
m = quantize_pt2e.prepare_pt2e(m, quantizer)

# Step 3: Quantize the model
m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)

# pylint: disable=broad-except
try:
ai_edge_torch.convert(m, args)
except Exception as err:
self.fail(f"PT2E conversion failed: {err}")
# pylint: enable=broad-except

def test_convert_resnet18_pt2e_per_channel(self):
# Step 1: export resnet18
args = (torch.randn(1, 3, 224, 224),)
m = torchvision.models.resnet18().eval()
m = torch._export.capture_pre_autograd_graph(m, args)

# Step 2: Insert observers or fake quantize modules
quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
)
m = quantize_pt2e.prepare_pt2e(m, quantizer)
# Step 3: Calibrate by running the example inputs; otherwise the per-channel
# quant params may be left as scalar scale/zero_point values.
m(*args)
# Step 4: Quantize the model
m = quantize_pt2e.convert_pt2e(m, fold_quantize=False)

# pylint: disable=broad-except
try:
ai_edge_torch.convert(m, args)
except Exception as err:
self.fail(f"PT2E conversion failed: {err}")
# pylint: enable=broad-except


if __name__ == "__main__":
googletest.main()
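
Editor's note: the two tests above exercise the per-layer and per-channel PT2E paths end to end. For reference, a minimal standalone sketch of the per-channel flow is shown below; the final export call and output path are assumptions about the public ai_edge_torch API and are not part of this change.

# Minimal PT2E per-channel quantization + conversion sketch (editor's example).
import ai_edge_torch
from ai_edge_torch.quantize import pt2e_quantizer
import torch
from torch.ao.quantization import quantize_pt2e
import torchvision

args = (torch.randn(1, 3, 224, 224),)
model = torchvision.models.resnet18().eval()
model = torch._export.capture_pre_autograd_graph(model, args)

quantizer = pt2e_quantizer.PT2EQuantizer().set_global(
    pt2e_quantizer.get_symmetric_quantization_config(is_per_channel=True)
)
model = quantize_pt2e.prepare_pt2e(model, quantizer)
model(*args)  # Calibrate so per-channel scale/zero_point become 1-D tensors.
model = quantize_pt2e.convert_pt2e(model, fold_quantize=False)

edge_model = ai_edge_torch.convert(model, args)
edge_model.export("/tmp/resnet18_pt2e_int8.tflite")  # Assumed API; not part of this commit.
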
38 changes: 38 additions & 0 deletions ai_edge_torch/odml_torch/export.py
@@ -258,6 +258,43 @@ def rewrite_arange(node: torch.fx.Node):
rewrite_arange(node)


# TODO(b/331481564) Make this a ai_edge_torch FX pass.
def _convert_q_dq_per_channel_args_to_list(
exported_program: torch.export.ExportedProgram,
):
"""Resolve tensor inputs to Q/DQ ops as static number list for lowering.
This pass makes the ExportedProgram in a non-executable state. This pass must
be run after all run_decompositions calls.
"""
placeholder_nodes = [
n for n in exported_program.graph.nodes if n.op == "placeholder"
]
export_flat_args = _torch_future.graph_module_flat_inputs(
exported_program, *exported_program.example_inputs
)

placeholder_tensor = {
n: tensor for n, tensor in zip(placeholder_nodes, export_flat_args)
}

graph_module = exported_program.graph_module
for node in graph_module.graph.nodes:
if node.target in (
torch.ops.quantized_decomposed.quantize_per_channel.default,
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
torch.ops.quantized_decomposed.dequantize_per_channel.default,
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
):
input, scale_node, zero_point_node = node.args[:3]
scale = placeholder_tensor[scale_node]
zero_point = placeholder_tensor[zero_point_node]

scale = scale.detach().numpy().tolist()
zero_point = zero_point.detach().numpy().tolist()
node.args = (input, scale, zero_point, *node.args[3:])

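# Editor's sketch (illustrative, not part of the change): for a per-channel
# dequantize node, the pass above rewrites tensor placeholder args into plain
# Python lists, roughly
#   before: dequantize_per_channel(x, %scales, %zero_points, 0, -128, 127, torch.int8)
#   after:  dequantize_per_channel(x, [0.1, 0.2, ...], [0, 0, ...], 0, -128, 127, torch.int8)
# so the lowerings in lowerings/_quantized_decomposed.py receive static numbers.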

def exported_program_to_mlir(
exported_program: torch.export.ExportedProgram,
) -> MlirLowered:
@@ -270,6 +307,7 @@ def exported_program_to_mlir(
exported_program = _torch_future.safe_run_decompositions(
exported_program, lowerings.decompositions()
)
_convert_q_dq_per_channel_args_to_list(exported_program)

with export_utils.create_ir_context() as context, ir.Location.unknown():

15 changes: 2 additions & 13 deletions ai_edge_torch/odml_torch/export_utils.py
@@ -14,9 +14,9 @@
# ==============================================================================
"""Utilities for ODML Torch export."""

import functools
import re
from typing import Sequence, cast
from ai_edge_torch.odml_torch.lowerings import utils as lowering_utils
import jax._src.interpreters.mlir
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
@@ -47,7 +47,6 @@ def create_ir_context():
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
context = jax._src.interpreters.mlir.make_ir_context()
context.allow_unregistered_dialects = True

return context


@@ -135,17 +134,7 @@ def build_ir_attr(val):
return ir.StringAttr.get(str(val))


def torch_dtype_to_ir_element_type(dtype):
ty_get = {
torch.double: ir.F64Type.get,
torch.float32: ir.F32Type.get,
torch.half: ir.F16Type.get,
torch.long: functools.partial(ir.IntegerType.get_signless, 64),
torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
}[dtype]
return ty_get()
torch_dtype_to_ir_element_type = lowering_utils.torch_dtype_to_ir_element_type


def ir_element_type_to_torch_dtype(ty):
1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -17,6 +17,7 @@
from . import _convolution
from . import _jax_lowerings
from . import _layer_norm
from . import _quantized_decomposed
from . import context
from . import registry
from . import utils
174 changes: 174 additions & 0 deletions ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
@@ -0,0 +1,174 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lowerings for PT2E torch.ops.quantized_decomposed ops."""
from typing import Union, cast

from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import utils
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import torch
import torch.ao.quantization.fx._decomposed
import torch.utils._pytree as pytree

from . import registry

lower = registry.lower
LoweringContext = context.LoweringContext


def _uniform_quantized_type(
stored_type: str | ir.Type,
expressed_type: str | ir.Type,
*,
scale: float | list[float] | tuple[float, ...],
zero_point: float | list[float] | tuple[float, ...],
storage_type_min: int | None = None,
storage_type_max: int | None = None,
channel_axis: int | None = None,
channel_axis_size: int | None = None,
):
"""Polyfill for quant.UniformQuantizedType."""
if storage_type_min is not None and storage_type_max is not None:
storage_min_max = f"<{storage_type_min}:{storage_type_max}>"
else:
storage_min_max = ""

if channel_axis is not None:
# Per-channel quantization
# https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-channel-quantization
assert isinstance(scale, (list, tuple))
assert isinstance(zero_point, (list, tuple))

if len(scale) == 1:
scale *= channel_axis_size
if len(zero_point) == 1:
zero_point *= channel_axis_size

assert len(scale) == len(zero_point) == channel_axis_size
scale_zp_strs = []
for s, zp in zip(scale, zero_point):
scale_zp_strs.append(f"{s}:{zp}")
scale_zp = "{" + ",".join(scale_zp_strs) + "}"
return ir.Type.parse(
f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type}:{channel_axis},{scale_zp}>"
)
else:
# Per-layer quantization
# https://mlir.llvm.org/docs/Dialects/QuantDialect/#per-layer-quantization
scale = pytree.tree_flatten([scale])[0][-1]
zero_point = pytree.tree_flatten([zero_point])[0][-1]
scale_zp = f"{scale}:{zero_point}"
return ir.Type.parse(
f"!quant.uniform<{stored_type}{storage_min_max}:{expressed_type},{scale_zp}>"
)

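# Editor's note (illustrative): the helper above produces quant dialect type
# strings such as
#   per-layer:   !quant.uniform<i8<-128:127>:f32,0.015:0>
#   per-channel: !quant.uniform<i8<-128:127>:f32:0,{0.015:0,0.009:0}>
# which are embedded in the backend_config of the custom_call ops emitted below.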

# The quant dialect is not registered in the Python MLIR pybinding used by
# odml-torch. Therefore, stablehlo.uniform_quantize/uniform_dequantize ops and
# quant types are represented as stablehlo.custom_call ops so that MLIR
# verification and VHLO serialization succeed before reaching the converter.
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.


# Schema:
# - quantized_decomposed::quantize_per_tensor(Tensor input, float scale,
# int zero_point, int quant_min, int quant_max,
# ScalarType dtype) -> Tensor
# - quantized_decomposed::quantize_per_tensor.tensor(Tensor input,
# Tensor scale, Tensor zero_point, int quant_min, int quant_max,
# ScalarType dtype) -> Tensor
#
# Tensor-valued scale and zero_point args are converted to plain Python lists
# before lowering (see _convert_q_dq_per_channel_args_to_list in export.py).
@lower(torch.ops.quantized_decomposed.quantize_per_tensor)
def _quantize_per_tensor(
lctx: LoweringContext,
input: ir.Value,
scale: Union[float, list[float]],
zero_point: Union[float, list[float]],
quant_min: int,
quant_max: int,
dtype: torch.dtype,
):
input_ty = cast(ir.RankedTensorType, input.type)
qty = _uniform_quantized_type(
utils.torch_dtype_to_ir_element_type(dtype),
input_ty.element_type,
scale=scale,
zero_point=zero_point,
storage_type_min=quant_min,
storage_type_max=quant_max,
)
return stablehlo.custom_call(
call_target_name="odml_torch.uniform_quantize",
inputs=[input],
result=[input_ty],
backend_config=ir.StringAttr.get(
str(ir.RankedTensorType.get(input_ty.shape, qty))
),
)

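# Editor's sketch (illustrative, assumed printed form): the lowering above
# emits roughly
#   %q = stablehlo.custom_call @odml_torch.uniform_quantize(%input)
#     {backend_config = "tensor<1x3x224x224x!quant.uniform<i8<-128:127>:f32,0.015:0>>"}
#     : (tensor<1x3x224x224xf32>) -> tensor<1x3x224x224xf32>
# i.e. the quantized type travels only inside backend_config, since the quant
# dialect is unavailable in this pybinding.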

# Schema:
# - quantized_decomposed::quantize_per_channel(Tensor input, Tensor scales,
# Tensor zero_points, int axis, int quant_min, int quant_max,
# ScalarType dtype) -> Tensor
#
# Tensor-valued scale and zero_point args are converted to plain Python lists
# before lowering (see _convert_q_dq_per_channel_args_to_list in export.py).
@lower(torch.ops.quantized_decomposed.quantize_per_channel)
def _quantize_per_channel(
lctx: LoweringContext,
input: ir.Value,
scale: list[float],
zero_point: list[float],
axis: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
):
input_ty = cast(ir.RankedTensorType, input.type)
qty = _uniform_quantized_type(
utils.torch_dtype_to_ir_element_type(dtype),
input_ty.element_type,
scale=scale,
zero_point=zero_point,
channel_axis=axis,
channel_axis_size=input_ty.shape[axis],
storage_type_min=quant_min,
storage_type_max=quant_max,
)
return stablehlo.custom_call(
call_target_name="odml_torch.uniform_quantize",
inputs=[input],
result=[input_ty],
backend_config=ir.StringAttr.get(
str(ir.RankedTensorType.get(input_ty.shape, qty))
),
)


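# Note: the dequantize lowerings below do not consume the scale/zero_point
# arguments at this level; the result tensor type is rebuilt from the node's
# tensor_meta (shape and dtype) and the op is emitted as an
# odml_torch.uniform_dequantize custom_call.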
@lower(torch.ops.quantized_decomposed.dequantize_per_tensor)
@lower(torch.ops.quantized_decomposed.dequantize_per_channel)
def _dequantize(lctx: LoweringContext, input: ir.Value, *args, **kwargs):
result_meta = lctx.node.meta.get("tensor_meta")
result_elty = utils.torch_dtype_to_ir_element_type(result_meta.dtype)

return stablehlo.custom_call(
call_target_name="odml_torch.uniform_dequantize",
inputs=[input],
result=[ir.RankedTensorType.get(result_meta.shape, result_elty)],
)
16 changes: 16 additions & 0 deletions ai_edge_torch/odml_torch/lowerings/utils.py
@@ -14,13 +14,29 @@
# ==============================================================================
"""Utilities for building MLIR lowerings."""

import functools
import numbers
from typing import Any
from typing import Optional

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo as stablehlo
import numpy as np
import torch


def torch_dtype_to_ir_element_type(dtype):
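"""Returns the MLIR element type for a torch dtype (e.g. torch.int8 -> i8)."""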
ty_get = {
torch.double: ir.F64Type.get,
torch.float32: ir.F32Type.get,
torch.half: ir.F16Type.get,
torch.long: functools.partial(ir.IntegerType.get_signless, 64),
torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
torch.int8: functools.partial(ir.IntegerType.get_signless, 8),
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
}[dtype]
return ty_get()


def splat(val, ty, shape=tuple(), *, loc: Optional[Any] = None):