From 9d387eca5f07155f540ad6572e33772ef21399e6 Mon Sep 17 00:00:00 2001
From: Chun-nien Chan
Date: Fri, 20 Dec 2024 10:28:08 -0800
Subject: [PATCH] fix rand and randn lowering

Lower aten.rand and aten.randn to odml.random_uniform and
odml.random_standard_normal StableHLO composite ops carrying seed/seed2
attributes and a constant-tensor decomposition, instead of routing both
ops through the torch_xla2 jax bridge. Also select the StableHLO
serialization target version via the WEEK_4 compatibility requirement
on newer StableHLO APIs.

PiperOrigin-RevId: 708361782
---
 ai_edge_torch/odml_torch/export.py            |   7 +-
 .../odml_torch/jax_bridge/__init__.py         |   5 +-
 .../odml_torch/lowerings/__init__.py          |   1 +
 .../odml_torch/lowerings/_jax_lowerings.py    |   3 +-
 ai_edge_torch/odml_torch/lowerings/_rand.py   | 142 ++++++++++++++++++
 5 files changed, 154 insertions(+), 4 deletions(-)
 create mode 100644 ai_edge_torch/odml_torch/lowerings/_rand.py

diff --git a/ai_edge_torch/odml_torch/export.py b/ai_edge_torch/odml_torch/export.py
index dffaebea..684a4cb1 100644
--- a/ai_edge_torch/odml_torch/export.py
+++ b/ai_edge_torch/odml_torch/export.py
@@ -198,7 +198,12 @@ def module_bytecode_vhlo(self) -> bytes:
     # build, which may not have the same StableHLO version as what used in
     # TFLite converter. Therefore we always serialize MLIR module in VHLO.
     # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
-    target_version = stablehlo.get_minimum_version()
+    if stablehlo.get_api_version() < 9:
+      target_version = stablehlo.get_minimum_version()
+    else:
+      target_version = stablehlo.get_version_from_compatibility_requirement(
+          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+      )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
     )
diff --git a/ai_edge_torch/odml_torch/jax_bridge/__init__.py b/ai_edge_torch/odml_torch/jax_bridge/__init__.py
index f04872bd..ef7ad189 100644
--- a/ai_edge_torch/odml_torch/jax_bridge/__init__.py
+++ b/ai_edge_torch/odml_torch/jax_bridge/__init__.py
@@ -12,4 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-from ai_edge_torch.odml_torch.jax_bridge._wrap import wrap
+from ai_edge_torch.odml_torch.jax_bridge import _wrap
+from ai_edge_torch.odml_torch.jax_bridge import utils
+
+wrap = _wrap.wrap
diff --git a/ai_edge_torch/odml_torch/lowerings/__init__.py b/ai_edge_torch/odml_torch/lowerings/__init__.py
index 0d232d52..dcf48591 100644
--- a/ai_edge_torch/odml_torch/lowerings/__init__.py
+++ b/ai_edge_torch/odml_torch/lowerings/__init__.py
@@ -18,6 +18,7 @@
 from . import _jax_lowerings
 from . import _layer_norm
 from . import _quantized_decomposed
+from . import _rand
 from . import context
 from . import registry
 from . import utils
diff --git a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
index f2d1ca4e..cd4d27dd 100644
--- a/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
+++ b/ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py
@@ -26,6 +26,7 @@
 LoweringContext = context.LoweringContext
 
+
 @functools.cache
 def _log_usage(op):
   logging.warning("Use jax lowering: %s", str(op))
 
@@ -184,8 +185,6 @@ def lower_by_torch_xla2(op):
 lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
 lower_by_torch_xla2(torch.ops.aten.pow)
 lower_by_torch_xla2(torch.ops.aten.prod)
-lower_by_torch_xla2(torch.ops.aten.rand)
-lower_by_torch_xla2(torch.ops.aten.randn)
 lower_by_torch_xla2(torch.ops.aten.reciprocal)
 lower_by_torch_xla2(torch.ops.aten.reflection_pad1d)
 lower_by_torch_xla2(torch.ops.aten.relu)
diff --git a/ai_edge_torch/odml_torch/lowerings/_rand.py b/ai_edge_torch/odml_torch/lowerings/_rand.py
new file mode 100644
index 00000000..aa28742c
--- /dev/null
+++ b/ai_edge_torch/odml_torch/lowerings/_rand.py
@@ -0,0 +1,142 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import uuid
+
+from ai_edge_torch.odml_torch import export_utils
+from ai_edge_torch.odml_torch.lowerings import context
+from ai_edge_torch.odml_torch.lowerings import registry
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import func
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import numpy as np
+import torch
+import torch.utils._pytree as pytree
+
+LoweringContext = context.LoweringContext
+lower = registry.lower
+
+
+def _random_lowering(
+    lctx: LoweringContext,
+    size: list[int],
+    generator,
+    dtype: torch.dtype,
+    rand_tensor,
+    composite_name: str,
+):
+  if dtype is None:
+    dtype = torch.float32
+
+  rand_tensor = rand_tensor.type(dtype)
+  data = rand_tensor.detach().numpy()
+
+  shape, _ = pytree.tree_flatten(size)
+  elty = export_utils.torch_dtype_to_ir_element_type(dtype)
+
+  decomp_name = f"{composite_name}.impl_{uuid.uuid4().hex[:8]}"
+
+  with ir.InsertionPoint(lctx.ir_module.body):
+
+    @func.FuncOp.from_py_func(
+        ir.RankedTensorType.get(
+            [len(shape)],
+            ir.IntegerType.get_signless(32),
+        ),
+        name=decomp_name,
+    )
+    def _rand_impl(_):
+      return [stablehlo.constant(ir.DenseElementsAttr.get(data))]
+
+  seed, seed2 = (
+      torch.randint(
+          torch.iinfo(torch.int64).min,
+          torch.iinfo(torch.int64).max,
+          (2,),
+          dtype=torch.int64,
+          generator=generator,
+      )
+      .detach()
+      .numpy()
+  )
+
+  shape_ = stablehlo.constant(
+      ir.DenseElementsAttr.get(np.array(shape, dtype=np.int32))
+  )
+  return stablehlo.CompositeOp(
+      result=[ir.RankedTensorType.get(shape, elty)],
+      inputs=[shape_],
+      name=composite_name,
+      composite_attributes=ir.DictAttr.get({
+          "seed": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed),
+          "seed2": ir.IntegerAttr.get(ir.IntegerType.get_signless(64), seed2),
+      }),
+      decomposition=decomp_name,
+  ).results[0]
+
+
+# Schema:
+# - aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#     Device? device=None, bool? pin_memory=None) -> Tensor
+# - aten::rand.generator(SymInt[] size, *, Generator? generator,
+#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#     bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.rand)
+def _aten_rand(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.rand.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_uniform",
+  )
+
+
+# Schema:
+# - aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None,
+#     Device? device=None, bool? pin_memory=None) -> Tensor
+# - aten::randn.generator(SymInt[] size, *, Generator? generator,
+#     ScalarType? dtype=None, Layout? layout=None, Device? device=None,
+#     bool? pin_memory=None) -> Tensor
+@registry.lower(torch.ops.aten.randn)
+def _aten_randn(
+    lctx: LoweringContext,
+    size,
+    generator=None,
+    dtype=None,
+    layout=torch.strided,
+    device=None,
+    pin_memory=False,
+):
+  return _random_lowering(
+      lctx,
+      size,
+      generator,
+      dtype,
+      rand_tensor=torch.ops.aten.randn.generator(
+          size, generator=generator, dtype=dtype
+      ),
+      composite_name="odml.random_standard_normal",
+  )
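
For reference, a minimal sketch of a conversion that exercises the new
lowering. It assumes the documented ai_edge_torch.convert entry point; the
AddNoise module name and the output path are made up for illustration.

import torch

import ai_edge_torch


class AddNoise(torch.nn.Module):
  """Toy module whose exported graph contains torch.ops.aten.randn."""

  def forward(self, x):
    # Captured by torch.export as aten.randn; with this patch it lowers to a
    # stablehlo.composite "odml.random_standard_normal" whose seed/seed2
    # attributes come from _rand.py and whose decomposition function returns
    # a constant tensor baked in at conversion time.
    return x + torch.randn(4, dtype=torch.float32)


edge_model = ai_edge_torch.convert(AddNoise().eval(), (torch.ones(4),))
edge_model.export("/tmp/add_noise.tflite")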