use jax bridged einsum lowering
PiperOrigin-RevId: 705278809
chunnienc authored and copybara-github committed Dec 12, 2024
1 parent 924801e commit bb140e8
Showing 3 changed files with 36 additions and 3 deletions.
30 changes: 28 additions & 2 deletions ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -16,12 +16,15 @@
 import logging
 
 from ai_edge_torch.odml_torch import jax_bridge
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+import jax.numpy as jnp
+from jax._src.lib.mlir import ir
 import torch
 import torch_xla2.ops.jaten  # Import to load torch_xla2 ops
 import torch_xla2.ops.ops_registry  # Import to load torch_xla2 ops
 
-from . import registry
-
+LoweringContext = context.LoweringContext
 
 @functools.cache
 def _log_usage(op):
@@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
 @lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
 def _aten_copy(self, src, **kwargs):
   return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
+
+
+# Schema:
+#   - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
+#       -> Tensor
+# Torch Reference:
+#   - https://pytorch.org/docs/stable/generated/torch.einsum.html
+#   - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
+@registry.lower(torch.ops.aten.einsum.default)
+def _aten_einsum_default(
+    lctx: LoweringContext,
+    equation: str,
+    tensors: list[ir.Value],
+    path=None,
+):
+  _log_usage(torch.ops.aten.einsum.default)
+
+  @jax_bridge.wrap
+  def jax_lowering(operands):
+    # Ignore the input path and let JAX determine the path.
+    return jnp.einsum(equation, *operands, optimize="optimal")
+
+  return jax_lowering(lctx, tuple(tensors))
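
The key behavioral difference in the new lowering is the optimize="optimal" argument: jnp.einsum (via opt_einsum) searches for the cheapest contraction order instead of honoring the caller-supplied path. A minimal standalone sketch of the same call outside the lowering machinery, reusing the shapes from the new test case further below (illustrative only):

import jax.numpy as jnp
import numpy as np

lhs = np.ones((1, 4, 2, 16), dtype=np.float32)   # "abcd"
rhs = np.ones((1, 4, 32, 16), dtype=np.float32)  # "abed"

# optimize="optimal" lets JAX/opt_einsum choose the contraction path;
# the torch-level `path` argument is deliberately ignored by the lowering.
out = jnp.einsum("abcd,abed->abce", lhs, rhs, optimize="optimal")
print(out.shape)  # (1, 4, 2, 32)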
8 changes: 7 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/decomp.py
@@ -46,7 +46,13 @@ def decompositions():
 
   torch._decomp.remove_decompositions(
       decompositions,
-      [torch.ops.aten.roll],
+      [
+          torch.ops.aten.roll,
+          # Torch's default einsum decomposition is less efficient, and less
+          # optimizable by the converter, than JAX's implementation. Disable
+          # the decomposition so einsum lowers through the JAX bridge instead.
+          torch.ops.aten.einsum.default,
+      ],
   )
 
   # Override _safe_softmax decompositions with regular softmax.
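
To see what removing the decomposition buys, here is a sketch of the intended effect under an assumed torch.export flow (the module, shapes, and table-building calls are illustrative, and decomposition-table APIs vary across torch versions): with aten.einsum.default excluded from the table, the op survives as a single node instead of being rewritten into permute/bmm/reshape, which is what lets the JAX-bridged lowering above handle it.

import torch

class EinsumModule(torch.nn.Module):
  def forward(self, x, y):
    return torch.einsum("abcd,abed->abce", x, y)

ep = torch.export.export(
    EinsumModule(), (torch.randn(1, 4, 2, 16), torch.randn(1, 4, 32, 16))
)

# Mirror the remove_decompositions() call above on a core ATen table.
table = torch._decomp.core_aten_decompositions()
torch._decomp.remove_decompositions(table, [torch.ops.aten.einsum.default])
ep = ep.run_decompositions(table)

# einsum should remain a single call_function node in the graph.
print(any(n.op == "call_function" and n.target == torch.ops.aten.einsum.default
          for n in ep.graph.nodes))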
1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
@@ -229,6 +229,7 @@ def _run_export_and_compare(
     ("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
     ("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
     ("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}),
+    ("aten_einsum", torch.ops.aten.einsum, ("abcd,abed->abce", [rnd(torch.float32, (1, 4, 2, 16)), rnd(torch.float32, (1, 4, 32, 16))]), dict()),
     ("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()),
     ("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()),
     ("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
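
A quick sanity check on the new test case's shapes (not part of the change itself): in "abcd,abed->abce", d = 16 is contracted away, a and b are batch dimensions, and c = 2 and e = 32 carry through, so the expected output shape is (1, 4, 2, 32).

import torch

x = torch.randn(1, 4, 2, 16)   # "abcd"
y = torch.randn(1, 4, 32, 16)  # "abed"

# a, b are batch dims; d is contracted; c and e are kept.
out = torch.einsum("abcd,abed->abce", x, y)
assert out.shape == (1, 4, 2, 32)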
