Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use jax bridged einsum lowering #416

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
import logging

from ai_edge_torch.odml_torch import jax_bridge
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
import jax.numpy as jnp
from jax._src.lib.mlir import ir
import torch
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops

from . import registry

LoweringContext = context.LoweringContext

@functools.cache
def _log_usage(op):
Expand Down Expand Up @@ -258,3 +261,26 @@ def _aten_copy(self, *args, **kwargs):
@lower_by_jax(torch.ops.aten.copy, ir_input_names=["src"])
def _aten_copy(self, src, **kwargs):
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)


# Schema:
# - aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None)
# -> Tensor
# Torch Reference:
# - https://pytorch.org/docs/stable/generated/torch.einsum.html
# - https://github.com/pytorch/pytorch/blob/1b3f8b75896720e88362cbec7db32abc52afa83e/aten/src/ATen/native/Linear.cpp#L255
@registry.lower(torch.ops.aten.einsum.default)
def _aten_einsum_default(
lctx: LoweringContext,
equation: str,
tensors: list[ir.Value],
path=None,
):
_log_usage(torch.ops.aten.einsum.default)

@jax_bridge.wrap
def jax_lowering(operands):
# Ignore the input path and let JAX determine the path.
return jnp.einsum(equation, *operands, optimize="optimal")

return jax_lowering(lctx, tuple(tensors))
8 changes: 7 additions & 1 deletion ai_edge_torch/odml_torch/lowerings/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def decompositions():

torch._decomp.remove_decompositions(
decompositions,
[torch.ops.aten.roll],
[
torch.ops.aten.roll,
# Torch's default einsum impl/decompositions is less efficient and
# optimized through converter than JAX's impl. Disable einsum
# decomposition to use JAX bridge for a more efficient lowering.
torch.ops.aten.einsum.default,
],
)

# Override _safe_softmax decompositions with regular softmax.
Expand Down
1 change: 1 addition & 0 deletions ai_edge_torch/odml_torch/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def _run_export_and_compare(
("aten_div_Tensor_0", torch.ops.aten.div.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
("aten_div_Tensor_mode_trunc_0", torch.ops.aten.div.Tensor_mode, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_div_Tensor_mode_trunc_1", torch.ops.aten.div.Tensor_mode, (rnd(torch.int32, (10, 10)), rnd(torch.int32, (10, 10)),), {"rounding_mode": "trunc"}),
("aten_einsum", torch.ops.aten.einsum, ("abcd,abed->abce", [rnd(torch.float32, (1, 4, 2, 16)), rnd(torch.float32, (1, 4, 32, 16))]), dict()),
("aten_embedding_0", torch.ops.aten.embedding, (rnd(torch.float32, (10, 10)), rnd(torch.int64, (10,)),), dict()),
("aten_eq_Scalar_2", torch.ops.aten.eq.Scalar, (rnd(torch.float32, (10, 10)), 1,), dict()),
("aten_eq_Tensor_0", torch.ops.aten.eq.Tensor, (rnd(torch.float32, (10, 10)), rnd(torch.float32, (10, 10)),), dict()),
Expand Down
Loading