fix rand and randn lowering
PiperOrigin-RevId: 708361782
chunnienc authored and copybara-github committed Dec 20, 2024
Commit 9d387ec (1 parent: dc45276)
Showing 5 changed files with 154 additions and 4 deletions.
ai_edge_torch/odml_torch/export.py (6 additions, 1 deletion)
@@ -198,7 +198,12 @@ def module_bytecode_vhlo(self) -> bytes:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
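For readers unfamiliar with the versioned-serialization API this touches, here is the new branch in isolation. A minimal sketch, not part of the commit: the function names come straight from the diff above, while the wrapper function and the usage line are hypothetical, and the import path is assumed to match the one used elsewhere in this codebase.

# Sketch only: picks the VHLO target version the way the diff above does.
from jax._src.lib.mlir.dialects import hlo as stablehlo

def pick_vhlo_target_version() -> str:
  if stablehlo.get_api_version() < 9:
    # Older bindings can only report the minimum supported version.
    return stablehlo.get_minimum_version()
  # Newer bindings can request a version with a 4-week forward-compatibility
  # window, keeping the serialized VHLO readable by the TFLite converter.
  return stablehlo.get_version_from_compatibility_requirement(
      stablehlo.StablehloCompatibilityRequirement.WEEK_4
  )

print(pick_vhlo_target_version())  # e.g. a version string like "1.0.0"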
ai_edge_torch/odml_torch/jax_bridge/__init__.py (4 additions, 1 deletion)
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
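This keeps the package's public `wrap` symbol stable while also exposing the `_wrap` and `utils` submodules. A small sanity check of what the re-export preserves (sketch, not part of the commit):

# Both import styles resolve to the same callable after this change.
from ai_edge_torch.odml_torch import jax_bridge
from ai_edge_torch.odml_torch.jax_bridge import wrap

assert wrap is jax_bridge._wrap.wrap
assert jax_bridge.wrap is wrap
assert hasattr(jax_bridge, "utils")  # the submodule is importable too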
ai_edge_torch/odml_torch/lowerings/__init__.py (1 addition)
@@ -18,6 +18,7 @@
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils
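This one-line import is load-bearing: the `@registry.lower(...)` decorators in `_rand.py` run at import time, so importing the module is what registers the new lowerings. A stripped-down sketch of that registration-by-import pattern (names invented for illustration; this is not the real registry API):

# Registration-by-import, in miniature.
_LOWERINGS = {}

def lower(op):
  def register(fn):
    _LOWERINGS[op] = fn  # executes when the defining module is imported
    return fn
  return register

@lower("aten.rand")
def _aten_rand_example(*args, **kwargs):
  ...

# Importing the module that contains the decorated function is what
# populates _LOWERINGS; dropping the import silently loses the lowering.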
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py (1 addition, 2 deletions)
@@ -26,6 +26,7 @@
 
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
@@ -184,8 +185,6 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
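The two deleted registrations match exactly the ops that gain dedicated lowerings in `_rand.py` below, so `aten.rand` and `aten.randn` stop routing through the torch_xla2 jax bridge, which is the path this commit replaces.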
ai_edge_torch/odml_torch/lowerings/_rand.py (new file, 142 additions)
@@ -0,0 +1,142 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import uuid

from ai_edge_torch.odml_torch import export_utils
from ai_edge_torch.odml_torch.lowerings import context
from ai_edge_torch.odml_torch.lowerings import registry
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func
from jax._src.lib.mlir.dialects import hlo as stablehlo
import numpy as np
import torch
import torch.utils._pytree as pytree

LoweringContext = context.LoweringContext
lower = registry.lower


def _random_lowering(
    lctx: LoweringContext,
    size: list[int],
    generator,
    dtype: torch.dtype,
    rand_tensor,
    composite_name: str,
):
  if dtype is None:
    dtype = torch.float32

  rand_tensor = rand_tensor.type(dtype)
  data = rand_tensor.detach().numpy()

  shape, _ = pytree.tree_flatten(size)
  elty = export_utils.torch_dtype_to_ir_element_type(dtype)

  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"

  with ir.InsertionPoint(lctx.ir_module.body):

    @func.FuncOp.from_py_func(
        ir.RankedTensorType.get(
            [len(shape)],
            ir.IntegerType.get_signless(32),
        ),
        name=decomp_name,
    )
    def _rand_impl(_):
      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]

  seed, seed2 = (
      torch.randint(
          torch.iinfo(torch.int64).min,
          torch.iinfo(torch.int64).max,
          (2,),
          dtype=torch.int64,
          generator=generator,
      )
      .detach()
      .numpy()
  )

  shape_ = stablehlo.constant(
      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
  )
  return stablehlo.CompositeOp(
      result=[ir.RankedTensorType.get(shape, elty)],
      inputs=[shape_],
      name=composite_name,
      composite_attributes=ir.DictAttr.get({
          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
      }),
      decomposition=decomp_name,
  ).results[0]


# Schema:
# - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
#     Device? device=None, bool? pin_memory=None) -> Tensor
# - aten::rand.generator(SymInt[] size, *, Generator? generator,
#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#     bool? pin_memory=None) -> Tensor
@registry.lower(torch.ops.aten.rand)
def _aten_rand(
    lctx: LoweringContext,
    size,
    generator=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    pin_memory=False,
):
  return _random_lowering(
      lctx,
      size,
      generator,
      dtype,
      rand_tensor=torch.ops.aten.rand.generator(
          size, generator=generator, dtype=dtype
      ),
      composite_name="odml.random_uniform",
  )


# Schema:
# - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
#     Device? device=None, bool? pin_memory=None) -> Tensor
# - aten::randn.generator(SymInt[] size, *, Generator? generator,
#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#     bool? pin_memory=None) -> Tensor
@registry.lower(torch.ops.aten.randn)
def _aten_randn(
    lctx: LoweringContext,
    size,
    generator=None,
    dtype=None,
    layout=torch.strided,
    device=None,
    pin_memory=False,
):
  return _random_lowering(
      lctx,
      size,
      generator,
      dtype,
      rand_tensor=torch.ops.aten.randn.generator(
          size, generator=generator, dtype=dtype
      ),
      composite_name="odml.random_standard_normal",
  )
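Net effect: `torch.rand` and `torch.randn` now lower to `odml.random_uniform` and `odml.random_standard_normal` composite ops whose attributes carry the two seeds and whose decomposition returns a constant sample. An end-to-end sketch, assuming the standard ai-edge-torch conversion entry points and this commit applied; the model and file path are illustrative:

# Sketch (not from the commit): exporting a model that calls torch.rand.
import ai_edge_torch
import torch

class AddNoise(torch.nn.Module):

  def forward(self, x):
    # With this commit, aten.rand lowers to an odml.random_uniform
    # composite instead of going through the removed torch_xla2 path.
    return x + torch.rand(8, 8)

sample_inputs = (torch.randn(8, 8),)
edge_model = ai_edge_torch.convert(AddNoise().eval(), sample_inputs)
edge_model.export("/tmp/add_noise.tflite")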
