diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py index 23cf9379..09381dae 100644 --- a/ai_edge_torch/odml_torch/export.py +++ b/ai_edge_torch/odml_torch/export.py @@ -238,6 +238,9 @@ def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram): def in_i32(x: int): return -2147483648 <= x <= 2147483647 + def to_int32(x: torch.Tensor): + return torch.ops.aten._to_copy.default(x, dtype=torch.int32) + def rewrite_arange(node: torch.fx.Node): tensor_meta = node.meta.get("tensor_meta", None) if not tensor_meta: @@ -249,7 +252,7 @@ def rewrite_arange(node: torch.fx.Node): if not (in_i32(start) and in_i32(end)): return op = node.target - node.target = lambda *args, **kwargs: op(*args, **kwargs).type(torch.int32) + node.target = lambda *args, **kwargs: to_int32(op(*args, **kwargs)) graph_module = exported_program.graph_module for node in graph_module.graph.nodes: @@ -305,9 +308,8 @@ def exported_program_to_mlir( _convert_i64_to_i32(exported_program) - exported_program = _torch_future.safe_run_decompositions( - exported_program, lowerings.decompositions() - ) + # No decompositions but just retracing/canonicalization. + exported_program = _torch_future.safe_run_decompositions(exported_program, {}) # Passes below mutate the exported program to a state not executable by torch. # Do not call run_decompositions after applying the passes. diff --git a/ai_edge_torch/odml_torch/lowerings/decomp.py b/ai_edge_torch/odml_torch/lowerings/decomp.py index 51831d34..4c8ea69e 100644 --- a/ai_edge_torch/odml_torch/lowerings/decomp.py +++ b/ai_edge_torch/odml_torch/lowerings/decomp.py @@ -55,6 +55,10 @@ def decompositions(): ], ) + # Override noop aten op decompositions for faster run_decompositions. + decompositions[torch.ops.aten.alias.default] = lambda x: x + decompositions[torch.ops.aten.detach.default] = lambda x: x + # Override _safe_softmax decompositions with regular softmax. 
# _safe_softmax introduces additional check-select ops to guard extreme # input values to softmax, which could make the converted model inefficient