From bb140e8739b2d3801ca3a481c2c129f8f2e8e83d Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Wed, 11 Dec 2024 16:24:59 -0800 Subject: [PATCH] use jax bridged einsum lowering PiperOrigin-RevId: 705278809 --- .../odml_torch/lowerings/_jax_lowerings.py | 30 +++++++++++++++++-- ai_edge_torch/odml_torch/lowerings/decomp.py | 8 ++++- .../odml_torch/test/test_core_aten_ops.py | 1 + 3 files changed, 36 insertions(+), 3 deletions(-) diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py index 57dd4bd3..f2d1ca4e 100644 --- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py @@ -16,12 +16,15 @@ import logging from ai_edge_torch.odml_torch import jax_bridge +from ai_edge_torch.odml_torch.lowerings import context +from ai_edge_torch.odml_torch.lowerings import registry +import jax.numpy as jnp +from jax._src.lib.mlir import ir import torch import torch_xla2.ops.jaten # Import to load torch_xla2 ops import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops -from . import registry - +LoweringContext = context.LoweringContext @functools.cache def _log_usage(op): @@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs): @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"]) def _aten_copy(self, src, **kwargs): return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src) + + +# Schema: +# - aten::einsum(str equation, Tensor[] tensors, *, int[]? 
path=None) +# -> Tensor +# Torch Reference: +# - https://pytorch.org/docs/stable/generated/torch.einsum.html +# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255 +@registry.lower(torch.ops.aten.einsum.default) +def _aten_einsum_default( + lctx: LoweringContext, + equation: str, + tensors: list[ir.Value], + path=None, +): + _log_usage(torch.ops.aten.einsum.default) + + @jax_bridge.wrap + def jax_lowering(operands): + # Ignore the input path and let JAX determine the path. + return jnp.einsum(equation, *operands, optimize="optimal") + + return jax_lowering(lctx, tuple(tensors)) diff --git a/ai_edge_torch/odml_torch/lowerings/decomp.py b/ai_edge_torch/odml_torch/lowerings/decomp.py index 5dabb293..51831d34 100644 --- a/ai_edge_torch/odml_torch/lowerings/decomp.py +++ b/ai_edge_torch/odml_torch/lowerings/decomp.py @@ -46,7 +46,13 @@ def decompositions(): torch._decomp.remove_decompositions( decompositions, - [torch.ops.aten.roll], + [ + torch.ops.aten.roll, + # Torch's default einsum impl/decompositions are less efficient and less + # well optimized by the converter than JAX's impl. Disable einsum + # decomposition to use the JAX bridge for a more efficient lowering. + torch.ops.aten.einsum.default, + ], ) # Override _safe_softmax decompositions with regular softmax. 
diff --git a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py index d0692371..9e35e125 100644 --- a/ai_edge_torch/odml_torch/test/test_core_aten_ops.py +++ b/ai_edge_torch/odml_torch/test/test_core_aten_ops.py @@ -229,6 +229,7 @@ def _run_export_and_compare( ("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()), ("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}), ("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}), + ("aten_einsum", torch.ops.aten.einsum, ("abcd,abed->abce", [rnd(torch.float32, (1, 4, 2, 16)), rnd(torch.float32, (1, 4, 32, 16))]), dict()), ("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()), ("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()), ("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),