Skip to content

Commit

Permalink
Fix regression due to torch 2.6.0
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702456988
  • Loading branch information
chunnienc authored and copybara-github committed Dec 3, 2024
1 parent b517bd0 commit f2575bf
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 30 deletions.
5 changes: 5 additions & 0 deletions ai_edge_torch/generative/test/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def test_verify_valid_recipes(
class TestQuantizeConvert(parameterized.TestCase):
"""Test conversion with quantization."""

def setUp(self):
super().setUp()
torch.manual_seed(0)
torch._dynamo.reset()

def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
Expand Down
14 changes: 7 additions & 7 deletions ai_edge_torch/odml_torch/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@
LoweringContext = lowerings.context.LoweringContext


def _build_flat_inputs(
ctx: ir.Context, exported_program: torch.export.ExportedProgram
):
def _build_flat_inputs(exported_program: torch.export.ExportedProgram):
"""Build flattened inputs and metadata from exported program's signature."""
placeholder_nodes = [
n for n in exported_program.graph.nodes if n.op == "placeholder"
Expand All @@ -49,9 +47,11 @@ def _build_flat_inputs(
ir_inputs = []
tensor_metas = []
for node, arg in zip(placeholder_nodes, export_flat_args):
tensor_meta = node.meta.get("tensor_meta")
tensor_meta = node.meta.get("tensor_meta") or node.meta.get("val")
if tensor_meta is None:
raise RuntimeError(f"{type(arg)} (for {node.name}) is not a tensor")
raise RuntimeError(
f"{type(arg)} (for {node.name}) does not have tensor meta"
)

tensor_metas.append(tensor_meta)
# Assume all dynamic dimensions are unbounded.
Expand All @@ -63,7 +63,7 @@ def _build_flat_inputs(
ir_inputs.append(
ir.RankedTensorType.get(
shape,
export_utils.torch_dtype_to_ir_element_type(ctx, tensor_meta.dtype),
export_utils.torch_dtype_to_ir_element_type(tensor_meta.dtype),
)
)
return tuple(ir_inputs), tuple(export_flat_args), tuple(tensor_metas)
Expand Down Expand Up @@ -277,7 +277,7 @@ def exported_program_to_mlir(
lctx = LoweringContext(context, module)
interpreter = LoweringInterpreter(exported_program.graph_module, lctx)
ir_flat_inputs, export_flat_args, tensor_metas = _build_flat_inputs(
context, exported_program
exported_program
)

# HACK: OSS MLIR pybinding could mysteriously transform func.func under
Expand Down
6 changes: 3 additions & 3 deletions ai_edge_torch/odml_torch/export_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def build_ir_attr(val):
return ir.StringAttr.get(str(val))


def torch_dtype_to_ir_element_type(ctx, dtype):
def torch_dtype_to_ir_element_type(dtype):
ty_get = {
torch.double: ir.F64Type.get,
torch.float32: ir.F32Type.get,
Expand All @@ -144,8 +144,8 @@ def torch_dtype_to_ir_element_type(ctx, dtype):
torch.int32: functools.partial(ir.IntegerType.get_signless, 32),
torch.int16: functools.partial(ir.IntegerType.get_signless, 16),
torch.bool: functools.partial(ir.IntegerType.get_signless, 1),
}.get(dtype)
return ty_get(ctx)
}[dtype]
return ty_get()


def ir_element_type_to_torch_dtype(ty):
Expand Down
4 changes: 1 addition & 3 deletions ai_edge_torch/odml_torch/jax_bridge/_wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,7 @@ def sanitize_result_elty(result, aval):
if aval is None:
return result

target_elty = export_utils.torch_dtype_to_ir_element_type(
lctx.ir_context, aval.dtype
)
target_elty = export_utils.torch_dtype_to_ir_element_type(aval.dtype)
if result.type.element_type == target_elty:
return result
return stablehlo.convert(
Expand Down
4 changes: 1 addition & 3 deletions ai_edge_torch/odml_torch/lowerings/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,7 @@ def _aten_cat(lctx: LoweringContext, tensors, dim=0):
if not non_empty_tensors:
return utils.splat(
0,
export_utils.torch_dtype_to_ir_element_type(
lctx.ir_context, out_aval.dtype
),
export_utils.torch_dtype_to_ir_element_type(out_aval.dtype),
out_aval.shape,
)

Expand Down
25 changes: 11 additions & 14 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,28 @@


def export_without_scalar_inputs(model, args, kwargs):
export_args = []
keys = []

for key, arg in [*enumerate(args), *kwargs.items()]:
flatten_args, treespec = pytree.tree_flatten([args, kwargs])

export_args = []
indices = []
for i, arg in enumerate(flatten_args):
if isinstance(arg, torch.Tensor):
export_args.append(arg)
keys.append(key)
indices.append(i)

class ModuleWrapper(torch.nn.Module):

def __init__(self, func, original_args, original_kwargs):
super().__init__()
self.original_args = [*original_args]
self.original_kwargs = original_kwargs.copy()
self.original_args = list(flatten_args)
self.func = func

def forward(self, *export_args):
args = [*self.original_args]
kwargs = self.original_kwargs.copy()

for key, arg in zip(keys, export_args):
if isinstance(key, int):
args[key] = arg
else:
kwargs[key] = arg
flatten_args = self.original_args.copy()
for i, arg in zip(indices, export_args):
flatten_args[i] = arg
args, kwargs = pytree.tree_unflatten(flatten_args, treespec)
return self.func(*args, **kwargs)

export_args = tuple(export_args)
Expand Down

0 comments on commit f2575bf

Please sign in to comment.