Skip to content

Commit

Permalink
Initial odml-torch sync
Browse files Browse the repository at this point in the history
Usage:
```
pip install -r odmltorch-requirements.txt
pip install . --no-deps
export USE_TORCH_XLA=0
```
PiperOrigin-RevId: 668645151
  • Loading branch information
chunnienc authored and copybara-github committed Aug 28, 2024
1 parent 846e2e9 commit 0a829b6
Show file tree
Hide file tree
Showing 53 changed files with 3,385 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ai_edge_torch._convert import fx_passes
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_to_stablehlo_with_composite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ai_edge_torch._convert import fx_passes
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_to_stablehlo_with_composite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ai_edge_torch._convert import fx_passes
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_to_stablehlo_with_composite(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch.utils._pytree as pytree
import torchvision

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def export_with_pass(
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/_convert/test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import torch
import torchvision

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


@dataclasses.dataclass
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/_convert/test/test_convert_composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import parameterized
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _func_to_torch_module(func: Callable[..., torch.Tensor]):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/_convert/test/test_convert_multisig.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from torch import nn

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class FullyConnectedModel(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/_convert/test/test_to_channel_last_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import ai_edge_torch
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class Identity(torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/debug/test/test_culprit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ai_edge_torch.debug import find_culprits
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest

_test_culprit_lib = torch.library.Library("test_culprit", "DEF")

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/debug/test/test_search_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ai_edge_torch.debug import _search_model
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestSearchModel(googletest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/e2e_tests/test_multisig.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
import torchvision

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestConvertMultiSignature(googletest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import ai_edge_torch.generative.layers.unet.model_config as unet_cfg
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_to_stablehlo(func: Union[torch.nn.Module, Callable], export_args):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/test/test_experimental_ekv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import ai_edge_torch.generative.layers.model_config as cfg
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestExternalKVLayers(googletest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/test/test_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import safetensors.torch
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestLoader(googletest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/test/test_model_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import numpy as np
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestModelConversion(googletest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/test/test_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from parameterized import parameterized
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


class TestVerifyRecipes(googletest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/hlfb/test/test_mark_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
import torch

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_stablehlo_mlir(model, args=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import torch
import torch.nn.functional as F

from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _export_stablehlo_mlir(model, args):
Expand Down
6 changes: 5 additions & 1 deletion ai_edge_torch/lowertools/odml_torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,17 @@ def inner(*args):
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
call_args = _extract_call_args(bundle, args, tf_state_dict)
# HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
# build, which may not have the same StableHLO version as what used in
# TFLite converter. Therefore we always serialize MLIR module in VHLO.
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
call_module_return = tfxla.call_module(
tuple(call_args),
version=5,
Tout=t_outs, # dtype information
Sout=s_outs, # Shape information
function_list=[],
module=bundle.module_bytecode,
module=bundle.module_bytecode_vhlo,
)
spec = exported_program.call_spec.out_spec

Expand Down
2 changes: 1 addition & 1 deletion ai_edge_torch/lowertools/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import re
from typing import Optional
from ai_edge_torch import config
from tensorflow.python.platform import googletest
from absl.testing import absltest as googletest


def _extract_backend_configs(mlir):
Expand Down
20 changes: 20 additions & 0 deletions ai_edge_torch/odml_torch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from . import composite
from . import debuginfo
from . import export
from . import export_utils
from . import lowerings
from . import passes
61 changes: 61 additions & 0 deletions ai_edge_torch/odml_torch/_torch_future.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrappers for latest torch APIs/utilities to maintain backward compatibility with older torch releases."""

import torch
from torch.fx import _pytree as fx_pytree


def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
"""Transform args, kwargs of __call__ to args for graph_module.
self.graph_module takes stuff from state dict as inputs.
The invariant is for ep: ExportedProgram is
ep(args, kwargs) ==
ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
"""
if hasattr(ep, "_graph_module_flat_inputs"):
return ep._graph_module_flat_inputs(args, kwargs)

if args is None:
args = tuple()
if kwargs is None:
kwargs = {}

flat_args = args
if (in_spec := ep.call_spec.in_spec) is not None:
if (
in_spec.type == tuple
and len(in_spec.children_specs) == 2
and in_spec.children_specs[0].type == tuple
and in_spec.children_specs[1].type == dict
):
# NOTE: this is the case where in_spec is for both args and kwargs
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
else:
flat_args = fx_pytree.tree_flatten_spec(args, in_spec)

param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)

if hasattr(ep.graph_signature, "lifted_tensor_constants"):
ordered_tensor_constants = tuple(
ep.tensor_constants[name]
for name in ep.graph_signature.lifted_tensor_constants
)
else:
ordered_tensor_constants = tuple()

return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
19 changes: 19 additions & 0 deletions ai_edge_torch/odml_torch/_torch_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Torch library for registering ODML Torch custom ops."""

import torch

ODML_TORCH_LIB = torch.library.Library("odml_torch", "DEF")
16 changes: 16 additions & 0 deletions ai_edge_torch/odml_torch/composite/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from .mark_tensor import mark_tensor_op
from .stablehlo_composite_builder import StableHLOCompositeBuilder
Loading

0 comments on commit 0a829b6

Please sign in to comment.