Skip to content

Commit

Permalink
fix pt2e quant ops lowering
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 706945163
  • Loading branch information
chunnienc authored and copybara-github committed Dec 17, 2024
1 parent fc9c986 commit 6abeb94
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
4 changes: 4 additions & 0 deletions ai_edge_torch/odml_torch/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,13 @@ def exported_program_to_mlir(
  )
  
  _convert_i64_to_i32(exported_program)
+
  exported_program = _torch_future.safe_run_decompositions(
  exported_program, lowerings.decompositions()
  )
+
+ # Passes below mutate the exported program to a state not executable by torch.
+ # Do not call run_decompositions after applying the passes.
  _convert_q_dq_per_channel_args_to_list(exported_program)
  
  with export_utils.create_ir_context() as context, ir.Location.unknown():
Expand Down
7 changes: 5 additions & 2 deletions ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def _uniform_quantized_type(
  assert isinstance(scale, (list, tuple))
  assert isinstance(zero_point, (list, tuple))
  
+ scale = list(scale)
+ zero_point = list(zero_point)
+
  if len(scale) == 1:
- scale *= channel_axis_size
+ scale = scale * channel_axis_size
  if len(zero_point) == 1:
- zero_point *= channel_axis_size
+ zero_point = zero_point * channel_axis_size
  
  assert len(scale) == len(zero_point) == channel_axis_size
  scale_zp_strs = []
Expand Down

0 comments on commit 6abeb94

Please sign in to comment.