From 6abeb94ef7bfa92a4105644a1112b7b12b9212e2 Mon Sep 17 00:00:00 2001
From: Chun-nien Chan
Date: Mon, 16 Dec 2024 21:40:30 -0800
Subject: [PATCH] fix pt2e quant ops lowering

PiperOrigin-RevId: 706945163
---
 ai_edge_torch/odml_torch/export.py                      | 4 ++++
 .../odml_torch/lowerings/_quantized_decomposed.py       | 7 +++++--
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py
index bbacdefa..dffaebea 100644
--- a/ai_edge_torch/odml_torch/export.py
+++ b/ai_edge_torch/odml_torch/export.py
@@ -304,9 +304,13 @@ def exported_program_to_mlir(
   )
 
   _convert_i64_to_i32(exported_program)
+
   exported_program = _torch_future.safe_run_decompositions(
       exported_program, lowerings.decompositions()
   )
+
+  # Passes below mutate the exported program to a state not executable by torch.
+  # Do not call run_decompositions after applying the passes.
   _convert_q_dq_per_channel_args_to_list(exported_program)
 
   with export_utils.create_ir_context() as context, ir.Location.unknown():
diff --git a/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py b/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
index de88fa58..95b5d3b6 100644
--- a/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
+++ b/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
@@ -52,10 +52,13 @@ def _uniform_quantized_type(
     assert isinstance(scale, (list, tuple))
     assert isinstance(zero_point, (list, tuple))
 
+    scale = list(scale)
+    zero_point = list(zero_point)
+
     if len(scale) == 1:
-      scale *= channel_axis_size
+      scale = scale * channel_axis_size
     if len(zero_point) == 1:
-      zero_point *= channel_axis_size
+      zero_point = zero_point * channel_axis_size
 
     assert len(scale) == len(zero_point) == channel_axis_size
     scale_zp_strs = []