From f2575bf0678c6f6ece3a771ccceee155bdbc9550 Mon Sep 17 00:00:00 2001
From: Chun-nien Chan
Date: Tue, 3 Dec 2024 13:47:59 -0800
Subject: [PATCH] Fix regression due to torch 2.6.0

PiperOrigin-RevId: 702456988
---
 .../generative/test/test_quantize.py          |  5 ++++
 ai_edge_torch/odml_torch/export.py            | 14 +++++------
 ai_edge_torch/odml_torch/export_utils.py      |  6 ++---
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py  |  4 +--
 ai_edge_torch/odml_torch/lowerings/_basic.py  |  4 +--
 .../odml_torch/test/test_core_aten_ops.py     | 25 ++++++++-----------
 6 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py
index ca7f9ebf..f95daa91 100644
--- a/ai_edge_torch/generative/test/test_quantize.py
+++ b/ai_edge_torch/generative/test/test_quantize.py
@@ -91,6 +91,11 @@ def test_verify_valid_recipes(
 class TestQuantizeConvert(parameterized.TestCase):
   """Test conversion with quantization."""
 
+  def setUp(self):
+    super().setUp()
+    torch.manual_seed(0)
+    torch._dynamo.reset()
+
   def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py
index 0f410fa9..1026e6a1 100644
--- a/ai_edge_torch/odml_torch/export.py
+++ b/ai_edge_torch/odml_torch/export.py
@@ -35,9 +35,7 @@
 LoweringContext = lowerings.context.LoweringContext
 
 
-def _build_flat_inputs(
-    ctx: ir.Context, exported_program: torch.export.ExportedProgram
-):
+def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
   """Build flattened inputs and metadata from exported program's signature."""
   placeholder_nodes = [
       n for n in exported_program.graph.nodes if n.op == "placeholder"
@@ -49,9 +47,11 @@
   ir_inputs = []
   tensor_metas = []
   for node, arg in zip(placeholder_nodes, export_flat_args):
-    tensor_meta = node.meta.get("tensor_meta")
+    tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
     if tensor_meta is None:
-      raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
+      raise RuntimeError(
+          f"{type(arg)} (for {node.name}) does not have tensor meta"
+      )
     tensor_metas.append(tensor_meta)
 
     # Assume all dynamic dimensions are unbounded.
@@ -63,7 +63,7 @@ def _build_flat_inputs(
     ir_inputs.append(
         ir.RankedTensorType.get(
            shape,
-            export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
+            export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
        )
    )
   return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
@@ -277,7 +277,7 @@ def exported_program_to_mlir(
   lctx = LoweringContext(context, module)
   interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
   ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
-      context, exported_program
+      exported_program
  )
 
   # HACK: OSS MLIR pybinding could mysteriously transform func.func under
diff --git a/ai_edge_torch/odml_torch/export_utils.py b/ai_edge_torch/odml_torch/export_utils.py
index df30350b..8b9d8ae8 100644
--- a/ai_edge_torch/odml_torch/export_utils.py
+++ b/ai_edge_torch/odml_torch/export_utils.py
@@ -135,7 +135,7 @@ def build_ir_attr(val):
   return ir.StringAttr.get(str(val))
 
 
-def torch_dtype_to_ir_element_type(ctx, dtype):
+def torch_dtype_to_ir_element_type(dtype):
   ty_get = {
       torch.double: ir.F64Type.get,
       torch.float32: ir.F32Type.get,
@@ -144,8 +144,8 @@ def torch_dtype_to_ir_element_type(ctx, dtype):
       torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
       torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
       torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
-  }.get(dtype)
-  return ty_get(ctx)
+  }[dtype]
+  return ty_get()
 
 
 def ir_element_type_to_torch_dtype(ty):
diff --git a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
index 1cea1f47..bdf5c63d 100644
--- a/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
+++ b/ai_edge_torch/odml_torch/jax_bridge/_wrap.py
@@ -163,9 +163,7 @@ def sanitize_result_elty(result, aval):
     if aval is None:
       return result
 
-    target_elty = export_utils.torch_dtype_to_ir_element_type(
-        lctx.ir_context, aval.dtype
-    )
+    target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
     if result.type.element_type == target_elty:
       return result
     return stablehlo.convert(
diff --git a/ai_edge_torch/odml_torch/lowerings/_basic.py b/ai_edge_torch/odml_torch/lowerings/_basic.py
index aa4454fb..808de8b4 100644
--- a/ai_edge_torch/odml_torch/lowerings/_basic.py
+++ b/ai_edge_torch/odml_torch/lowerings/_basic.py
@@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
   if not non_empty_tensors:
     return utils.splat(
         0,
-        export_utils.torch_dtype_to_ir_element_type(
-            lctx.ir_context, out_aval.dtype
-        ),
+        export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
         out_aval.shape,
     )
 
diff --git a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
index 82cfccc8..d0692371 100644
--- a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
+++ b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -22,31 +22,28 @@
 
 
 def export_without_scalar_inputs(model, args, kwargs):
-  export_args = []
-  keys = []
-  for key, arg in [*enumerate(args), *kwargs.items()]:
+  flatten_args, treespec = pytree.tree_flatten([args, kwargs])
+
+  export_args = []
+  indices = []
+  for i, arg in enumerate(flatten_args):
     if isinstance(arg, torch.Tensor):
       export_args.append(arg)
-      keys.append(key)
+      indices.append(i)
 
   class ModuleWrapper(torch.nn.Module):
 
     def __init__(self, func, original_args, original_kwargs):
       super().__init__()
-      self.original_args = [*original_args]
-      self.original_kwargs = original_kwargs.copy()
+      self.original_args = list(flatten_args)
       self.func = func
 
     def forward(self, *export_args):
-      args = [*self.original_args]
-      kwargs = self.original_kwargs.copy()
-
-      for key, arg in zip(keys, export_args):
-        if isinstance(key, int):
-          args[key] = arg
-        else:
-          kwargs[key] = arg
+      flatten_args = self.original_args.copy()
+      for i, arg in zip(indices, export_args):
+        flatten_args[i] = arg
+      args, kwargs = pytree.tree_unflatten(flatten_args, treespec)
       return self.func(*args, **kwargs)
 
   export_args = tuple(export_args)
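
Why the export.py change works: under torch 2.6.0, torch.export can leave
node.meta["tensor_meta"] unset on placeholder nodes, while node.meta["val"]
holds a FakeTensor exposing the same .shape and .dtype. A minimal standalone
sketch of the fallback; get_placeholder_meta is a hypothetical name used for
illustration, not an ai_edge_torch API:

    import torch

    def get_placeholder_meta(node: torch.fx.Node):
      # Prefer the TensorMetadata that older torch versions attach; fall
      # back to the FakeTensor stored under "val". Both expose .shape and
      # .dtype, which is all _build_flat_inputs reads.
      meta = node.meta.get("tensor_meta") or node.meta.get("val")
      if meta is None:
        raise RuntimeError(f"{node.name} does not have tensor meta")
      return meta

Why dropping ctx from torch_dtype_to_ir_element_type works: MLIR Python type
getters can resolve their context implicitly from an enclosing context
manager, so threading one through is redundant. A sketch, assuming the MLIR
Python bindings are importable as mlir.ir:

    from mlir import ir

    with ir.Context():
      # .get() with no arguments binds to the context active in this block.
      f32 = ir.F32Type.get()
      i32 = ir.IntegerType.get_signless(32)

Note also that indexing the dtype map with [dtype] makes an unsupported dtype
fail with an explicit KeyError instead of a TypeError from calling None.

The test helper now round-trips arguments through torch's pytree utility (a
private API, imported as pytree in the test) instead of tracking positional
and keyword slots separately. A sketch of the flatten/unflatten pattern:

    import torch
    from torch.utils import _pytree as pytree

    args, kwargs = (torch.ones(2), 3), {"flag": True}
    flat, spec = pytree.tree_flatten([args, kwargs])
    # Swap tensor leaves in place by index, then rebuild the structure.
    flat = [x + 1 if isinstance(x, torch.Tensor) else x for x in flat]
    args, kwargs = pytree.tree_unflatten(flat, spec)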