enable odml-torch as default
PiperOrigin-RevId: 705640829
chunnienc authored and copybara-github committed Dec 13, 2024
1 parent f142f4a commit bb02016
Showing 19 changed files with 143 additions and 113 deletions.
2 changes: 1 addition & 1 deletion ai_edge_torch/__init__.py
@@ -13,13 +13,13 @@
# limitations under the License.
# ==============================================================================

from ai_edge_torch._config import config
from ai_edge_torch._convert.converter import convert
from ai_edge_torch._convert.converter import signature
from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
from ai_edge_torch.model import Model
from ai_edge_torch.version import __version__


def load(path: str) -> Model:
"""Imports an ai_edge_torch model from disk.
52 changes: 52 additions & 0 deletions ai_edge_torch/_config.py
@@ -0,0 +1,52 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Provides a configuration for the ai-edge-torch."""

import functools
import logging
import os

__all__ = ["config"]


class _Config:
  """ai-edge-torch global configs."""

  @property
  @functools.cache  # pylint: disable=method-cache-max-size-none
  def use_torch_xla(self) -> bool:
    """True if using torch_xla to lower torch ops to StableHLO.

    To use torch_xla as the lowering backend, set environment variable
    `USE_TORCH_XLA` to "true".
    """
    var = os.environ.get("USE_TORCH_XLA", "false")
    var = var.lower().strip()
    if var in ("y", "yes", "t", "true", "on", "1"):
      return True
    elif var in ("n", "no", "f", "false", "off", "0"):
      return False
    else:
      logging.warning("Ignoring invalid USE_TORCH_XLA value: %s.", var)
      return False

  @property
  def in_oss(self) -> bool:
    """True if the code is not running in a Google-internal environment."""
    return True


config = _Config()
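
For orientation, a minimal usage sketch (not part of this commit; it assumes the package-level `config` export added to `ai_edge_torch/__init__.py` above). Because `use_torch_xla` is memoized with `functools.cache` and the lowering shim reads it at import time, the environment variable must be set before the package is imported:

```python
# Hedged usage sketch, not part of this commit.
import os

# Opt back in to the torch_xla backend; odml_torch is now the default.
# Set this before importing ai_edge_torch: the flag is read once and cached.
os.environ["USE_TORCH_XLA"] = "true"

import ai_edge_torch

if ai_edge_torch.config.use_torch_xla:
  print("lowering with torch_xla")
else:
  print("lowering with odml_torch (the new default)")
```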
3 changes: 1 addition & 2 deletions ai_edge_torch/_convert/test/test_convert.py
@@ -19,7 +19,6 @@
from typing import Tuple

import ai_edge_torch
from ai_edge_torch import config
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch.quantize import pt2e_quantizer
from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ def test_convert_conv_transpose_batch_norm(self):
self.assertTrue(result)

@googletest.skipIf(
not config.Config.use_torch_xla,
not ai_edge_torch.config.use_torch_xla,
reason="Shape polymorphism is not yet support with odml_torch.",
)
def test_convert_model_with_dynamic_batch(self):
27 changes: 0 additions & 27 deletions ai_edge_torch/config.py

This file was deleted.

@@ -15,7 +15,7 @@
import re
from typing import Callable, Union

from ai_edge_torch import config
import ai_edge_torch
from ai_edge_torch import fx_pass_base
from ai_edge_torch import lowertools
from ai_edge_torch.generative.fx_passes import CanonicalizePass
@@ -112,7 +112,7 @@ def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
(torch.rand(1, 512, 64, 64),),
)

if config.Config.use_torch_xla:
if ai_edge_torch.config.use_torch_xla:
self.assertTrue(
re.search(
'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+,'
17 changes: 8 additions & 9 deletions ai_edge_torch/generative/test/test_model_conversion.py
@@ -16,7 +16,6 @@
"""Testing model conversion for a few gen-ai models."""

import ai_edge_torch
from ai_edge_torch import config as ai_edge_config
from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
from ai_edge_torch.generative.layers import kv_cache
@@ -83,22 +82,22 @@ def _test_model_with_kv_cache(self, enable_hlfb: bool):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_toy_model_with_kv_cache(self):
self._test_model_with_kv_cache(enable_hlfb=False)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_toy_model_with_kv_cache_with_hlfb(self):
self._test_model_with_kv_cache(enable_hlfb=True)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_toy_model_has_dus_op(self):
"""Tests that the model has the dynamic update slice op."""
@@ -179,8 +178,8 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_tiny_llama_multisig(self):
config = tiny_llama.get_fake_model_config()
53 changes: 26 additions & 27 deletions ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -16,7 +16,6 @@
"""Testing model conversion for a few gen-ai models."""

import ai_edge_torch
from ai_edge_torch import config as ai_edge_config
from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
from ai_edge_torch.generative.examples.gemma import gemma1
from ai_edge_torch.generative.examples.gemma import gemma2
@@ -91,35 +90,35 @@ def _test_model(self, config, model, signature_name, atol, rtol):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_gemma1(self):
config = gemma1.get_fake_model_config()
pytorch_model = gemma1.Gemma1(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_gemma2(self):
config = gemma2.get_fake_model_config()
pytorch_model = gemma2.Gemma2(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_llama(self):
config = llama.get_fake_model_config()
pytorch_model = llama.Llama(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_phi2(self):
config = phi2.get_fake_model_config()
@@ -128,53 +127,53 @@ def test_phi2(self):
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_phi3(self):
config = phi3.get_fake_model_config()
pytorch_model = phi3.Phi3_5Mini(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_smollm(self):
config = smollm.get_fake_model_config()
pytorch_model = smollm.SmolLM(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_openelm(self):
config = openelm.get_fake_model_config()
pytorch_model = openelm.OpenELM(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_qwen(self):
config = qwen.get_fake_model_config()
pytorch_model = qwen.Qwen(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_amd_llama_135m(self):
config = amd_llama_135m.get_fake_model_config()
pytorch_model = amd_llama_135m.AmdLlama(config).eval()
self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def disabled_test_paligemma(self):
config = paligemma.get_fake_model_config()
@@ -222,8 +221,8 @@ def disabled_test_paligemma(self):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_stable_diffusion_clip(self):
config = sd_clip.get_fake_model_config()
@@ -254,8 +253,8 @@ def test_stable_diffusion_clip(self):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_stable_diffusion_diffusion(self):
config = sd_diffusion.get_fake_model_config(2)
@@ -296,8 +295,8 @@ def test_stable_diffusion_diffusion(self):
)

@googletest.skipIf(
ai_edge_config.Config.use_torch_xla,
reason="tests with custom ops are not supported on oss",
ai_edge_torch.config.in_oss,
reason="tests with custom ops are not supported in oss",
)
def test_stable_diffusion_decoder(self):
config = sd_decoder.get_fake_model_config()
7 changes: 5 additions & 2 deletions ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py
@@ -16,7 +16,8 @@

import math

from ai_edge_torch import config
import ai_edge_torch
from ai_edge_torch import hlfb
from ai_edge_torch import lowertools
from ai_edge_torch.hlfb import StableHLOCompositeBuilder
import torch
@@ -29,9 +30,11 @@ def _export_stablehlo_mlir(model, args):
ep = torch.export.export(model, args)
return lowertools.exported_program_to_mlir_text(ep)

StableHLOCompositeBuilder = hlfb.StableHLOCompositeBuilder


@googletest.skipIf(
not config.Config.use_torch_xla,
not ai_edge_torch.config.use_torch_xla,
reason="The odml_torch counter part is in odml_torch.",
)
class TestStableHLOCompositeBuilder(googletest.TestCase):
6 changes: 4 additions & 2 deletions ai_edge_torch/lowertools/_shim.py
@@ -15,13 +15,15 @@

from typing import Any, Optional

from ai_edge_torch import config
from ai_edge_torch import _config
from ai_edge_torch._convert import signature
from ai_edge_torch.quantize import quant_config as qcfg
import torch

config = _config.config

# isort: off
if config.Config.use_torch_xla:
if config.use_torch_xla:
from ai_edge_torch.lowertools import torch_xla_utils as utils
from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
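The shim resolves the backend once, when the module is first imported. A hedged sketch of the selection pattern follows; the `else` branch is elided in the hunk above, so the `odml_torch_utils` module name is an assumption:

```python
# Hedged sketch of the import-time backend selection in _shim.py.
# The else branch is elided in the diff above; `odml_torch_utils` is
# an assumed name for the odml_torch counterpart.
from ai_edge_torch import _config

config = _config.config

if config.use_torch_xla:
  from ai_edge_torch.lowertools import torch_xla_utils as utils
else:
  from ai_edge_torch.lowertools import odml_torch_utils as utils  # assumed
```

Because selection happens at import, flipping `USE_TORCH_XLA` after `ai_edge_torch.lowertools` has been imported has no effect.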
6 changes: 4 additions & 2 deletions ai_edge_torch/lowertools/test_utils.py
@@ -15,9 +15,11 @@

import re
from typing import Optional
from ai_edge_torch import config
from ai_edge_torch import _config
from absl.testing import absltest as googletest

config = _config.config


def _extract_backend_configs(mlir):
mlir = mlir.replace("\\22", '"')
@@ -38,7 +40,7 @@ def assert_string_count(
if odml_torch_attr_counter is None:
odml_torch_attr_counter = {}

if config.Config.use_torch_xla:
if config.use_torch_xla:
for key in torch_xla_pattern_counter:
test_case.assertEqual(
mlir.count(key),
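For orientation, a hedged sketch of a call site for this helper. The parameter names are inferred from the function body above and the pattern strings are placeholders, so treat the exact signature as an assumption:

```python
# Hedged sketch; parameter names inferred from the body above, and the
# pattern strings are placeholders rather than real expectations.
from ai_edge_torch.lowertools import test_utils

test_utils.assert_string_count(
    test_case=self,
    mlir=mlir_text,
    torch_xla_pattern_counter={'stablehlo.composite "odml.example"': 1},
    odml_torch_attr_counter={"odml.example": 1},
)
```

The helper checks the torch_xla counters when `config.use_torch_xla` is set and the odml_torch counters otherwise, so one test body covers both lowering backends.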