From bb02016c2b2bda542c0ed4ce7de10640c17659fd Mon Sep 17 00:00:00 2001 From: Chun-nien Chan Date: Thu, 12 Dec 2024 14:58:22 -0800 Subject: [PATCH] enable odml-torch as default PiperOrigin-RevId: 705640829 --- ai_edge_torch/__init__.py | 2 +- ai_edge_torch/_config.py | 52 ++++++++++++++++++ ai_edge_torch/_convert/test/test_convert.py | 3 +- ai_edge_torch/config.py | 27 ---------- .../test/test_remove_sdpa_zero_mask_pass.py | 4 +- .../generative/test/test_model_conversion.py | 17 +++--- .../test/test_model_conversion_large.py | 53 +++++++++---------- .../test/test_stablehlo_composite_builder.py | 7 ++- ai_edge_torch/lowertools/_shim.py | 6 ++- ai_edge_torch/lowertools/test_utils.py | 6 ++- ai_edge_torch/odml_torch/lowerings/_basic.py | 8 +-- .../odml_torch/lowerings/_convolution.py | 4 +- .../odml_torch/lowerings/_layer_norm.py | 13 ++++- .../lowerings/_quantized_decomposed.py | 18 +++---- odmltorch-requirements.txt | 16 ------ requirements.txt | 3 +- run_tests.sh | 3 +- setup.py | 11 ++-- test/test_quantize.py | 3 +- 19 files changed, 143 insertions(+), 113 deletions(-) create mode 100644 ai_edge_torch/_config.py delete mode 100644 ai_edge_torch/config.py delete mode 100644 odmltorch-requirements.txt diff --git a/ai_edge_torch/__init__.py b/ai_edge_torch/__init__.py index 8c6493ba..2531b27e 100644 --- a/ai_edge_torch/__init__.py +++ b/ai_edge_torch/__init__.py @@ -13,13 +13,13 @@ # limitations under the License. # ============================================================================== +from ai_edge_torch._config import config from ai_edge_torch._convert.converter import convert from ai_edge_torch._convert.converter import signature from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io from ai_edge_torch.model import Model from ai_edge_torch.version import __version__ - def load(path: str) -> Model: """Imports an ai_edge_torch model from disk. diff --git a/ai_edge_torch/_config.py b/ai_edge_torch/_config.py new file mode 100644 index 00000000..f8ff99b1 --- /dev/null +++ b/ai_edge_torch/_config.py @@ -0,0 +1,52 @@ +# Copyright 2024 The AI Edge Torch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Provides a configuration for the ai-edge-torch.""" + +import functools +import logging +import os + +__all__ = ["config"] + + +class _Config: + """ai-edge-torch global configs.""" + + @property + @functools.cache # pylint: disable=method-cache-max-size-none + def use_torch_xla(self) -> bool: + """True if using torch_xla to lower torch ops to StableHLO. + + To use torch_xla as the lowering backend, set environment variable + `USE_TORCH_XLA` to "true". 
+ """ + var = os.environ.get("USE_TORCH_XLA", "false") + var = var.lower().strip() + if var in ("y", "yes", "t", "true", "on", "1"): + return True + elif var in ("n", "no", "f", "false", "off", "0"): + return False + else: + logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var) + return False + + @property + def in_oss(self) -> bool: + """True if the code is not running in google internal environment.""" + return True + + +config = _Config() diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py index 6573645d..171ddb99 100644 --- a/ai_edge_torch/_convert/test/test_convert.py +++ b/ai_edge_torch/_convert/test/test_convert.py @@ -19,7 +19,6 @@ from typing import Tuple import ai_edge_torch -from ai_edge_torch import config from ai_edge_torch._convert import conversion_utils from ai_edge_torch.quantize import pt2e_quantizer from ai_edge_torch.testing import model_coverage @@ -292,7 +291,7 @@ def test_convert_conv_transpose_batch_norm(self): self.assertTrue(result) @googletest.skipIf( - not config.Config.use_torch_xla, + not ai_edge_torch.config.use_torch_xla, reason="Shape polymorphism is not yet support with odml_torch.", ) def test_convert_model_with_dynamic_batch(self): diff --git a/ai_edge_torch/config.py b/ai_edge_torch/config.py deleted file mode 100644 index 404773ef..00000000 --- a/ai_edge_torch/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2024 The AI Edge Torch Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Provides a configuration for the AI Edge Torch library.""" - -import dataclasses -import os - - -@dataclasses.dataclass -class Config: - use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in ( - "1", - "true", - ) diff --git a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py index de4539f1..969760e7 100644 --- a/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py +++ b/ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py @@ -15,7 +15,7 @@ import re from typing import Callable, Union -from ai_edge_torch import config +import ai_edge_torch from ai_edge_torch import fx_pass_base from ai_edge_torch import lowertools from ai_edge_torch.generative.fx_passes import CanonicalizePass @@ -112,7 +112,7 @@ def get_model_config() -> unet_cfg.AttentionBlock2DConfig: (torch.rand(1, 512, 64, 64),), ) - if config.Config.use_torch_xla: + if ai_edge_torch.config.use_torch_xla: self.assertTrue( re.search( 'stablehlo\.composite "odml\.scaled_dot_product_attention" %\d+,' diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py index a21be58b..43b47326 100644 --- a/ai_edge_torch/generative/test/test_model_conversion.py +++ b/ai_edge_torch/generative/test/test_model_conversion.py @@ -16,7 +16,6 @@ """Testing model conversion for a few gen-ai models.""" import ai_edge_torch -from ai_edge_torch import config as ai_edge_config from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache from ai_edge_torch.generative.examples.tiny_llama import tiny_llama from ai_edge_torch.generative.layers import kv_cache @@ -83,22 +82,22 @@ def _test_model_with_kv_cache(self, enable_hlfb: bool): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_toy_model_with_kv_cache(self): self._test_model_with_kv_cache(enable_hlfb=False) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_toy_model_with_kv_cache_with_hlfb(self): self._test_model_with_kv_cache(enable_hlfb=True) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_toy_model_has_dus_op(self): """Tests that the model has the dynamic update slice op.""" @@ -179,8 +178,8 @@ def _test_multisig_model(self, config, pytorch_model, atol, rtol): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_tiny_llama_multisig(self): config = tiny_llama.get_fake_model_config() diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py index 8ff2c265..94ca776e 100644 --- a/ai_edge_torch/generative/test/test_model_conversion_large.py +++ b/ai_edge_torch/generative/test/test_model_conversion_large.py @@ -16,7 +16,6 @@ """Testing 
model conversion for a few gen-ai models.""" import ai_edge_torch -from ai_edge_torch import config as ai_edge_config from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m from ai_edge_torch.generative.examples.gemma import gemma1 from ai_edge_torch.generative.examples.gemma import gemma2 @@ -91,8 +90,8 @@ def _test_model(self, config, model, signature_name, atol, rtol): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_gemma1(self): config = gemma1.get_fake_model_config() @@ -100,8 +99,8 @@ def test_gemma1(self): self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_gemma2(self): config = gemma2.get_fake_model_config() @@ -109,8 +108,8 @@ def test_gemma2(self): self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_llama(self): config = llama.get_fake_model_config() @@ -118,8 +117,8 @@ def test_llama(self): self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_phi2(self): config = phi2.get_fake_model_config() @@ -128,8 +127,8 @@ def test_phi2(self): self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_phi3(self): config = phi3.get_fake_model_config() @@ -137,8 +136,8 @@ def test_phi3(self): self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_smollm(self): config = smollm.get_fake_model_config() @@ -146,8 +145,8 @@ def test_smollm(self): self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_openelm(self): config = openelm.get_fake_model_config() @@ -155,8 +154,8 @@ def test_openelm(self): self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_qwen(self): config = qwen.get_fake_model_config() @@ -164,8 +163,8 @@ def test_qwen(self): self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - 
reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_amd_llama_135m(self): config = amd_llama_135m.get_fake_model_config() @@ -173,8 +172,8 @@ def test_amd_llama_135m(self): self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def disabled_test_paligemma(self): config = paligemma.get_fake_model_config() @@ -222,8 +221,8 @@ def disabled_test_paligemma(self): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_stable_diffusion_clip(self): config = sd_clip.get_fake_model_config() @@ -254,8 +253,8 @@ def test_stable_diffusion_clip(self): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_stable_diffusion_diffusion(self): config = sd_diffusion.get_fake_model_config(2) @@ -296,8 +295,8 @@ def test_stable_diffusion_diffusion(self): ) @googletest.skipIf( - ai_edge_config.Config.use_torch_xla, - reason="tests with custom ops are not supported on oss", + ai_edge_torch.config.in_oss, + reason="tests with custom ops are not supported in oss", ) def test_stable_diffusion_decoder(self): config = sd_decoder.get_fake_model_config() diff --git a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py b/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py index 176feb16..b808b828 100644 --- a/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +++ b/ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py @@ -16,7 +16,8 @@ import math -from ai_edge_torch import config +import ai_edge_torch +from ai_edge_torch import hlfb from ai_edge_torch import lowertools from ai_edge_torch.hlfb import StableHLOCompositeBuilder import torch @@ -29,9 +30,11 @@ def _export_stablehlo_mlir(model, args): ep = torch.export.export(model, args) return lowertools.exported_program_to_mlir_text(ep) +StableHLOCompositeBuilder = hlfb.StableHLOCompositeBuilder + @googletest.skipIf( - not config.Config.use_torch_xla, + not ai_edge_torch.config.use_torch_xla, reason="The odml_torch counter part is in odml_torch.", ) class TestStableHLOCompositeBuilder(googletest.TestCase): diff --git a/ai_edge_torch/lowertools/_shim.py b/ai_edge_torch/lowertools/_shim.py index e9480c05..0eb07a82 100644 --- a/ai_edge_torch/lowertools/_shim.py +++ b/ai_edge_torch/lowertools/_shim.py @@ -15,13 +15,15 @@ from typing import Any, Optional -from ai_edge_torch import config +from ai_edge_torch import _config from ai_edge_torch._convert import signature from ai_edge_torch.quantize import quant_config as qcfg import torch +config = _config.config + # isort: off -if config.Config.use_torch_xla: +if config.use_torch_xla: from ai_edge_torch.lowertools import torch_xla_utils as utils from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder diff --git a/ai_edge_torch/lowertools/test_utils.py b/ai_edge_torch/lowertools/test_utils.py index 4c5c94ed..137d1ee1 100644 --- 
a/ai_edge_torch/lowertools/test_utils.py +++ b/ai_edge_torch/lowertools/test_utils.py @@ -15,9 +15,11 @@ import re from typing import Optional -from ai_edge_torch import config +from ai_edge_torch import _config from absl.testing import absltest as googletest +config = _config.config + def _extract_backend_configs(mlir): mlir = mlir.replace("\\22", '"') @@ -38,7 +40,7 @@ def assert_string_count( if odml_torch_attr_counter is None: odml_torch_attr_counter = {} - if config.Config.use_torch_xla: + if config.use_torch_xla: for key in torch_xla_pattern_counter: test_case.assertEqual( mlir.count(key), diff --git a/ai_edge_torch/odml_torch/lowerings/_basic.py b/ai_edge_torch/odml_torch/lowerings/_basic.py index 808de8b4..264f367b 100644 --- a/ai_edge_torch/odml_torch/lowerings/_basic.py +++ b/ai_edge_torch/odml_torch/lowerings/_basic.py @@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1): interior_padding if i == dim else 0 for i in range(rank) ], ) - pred = np.ones(self.type.shape, dtype=np.bool_) - pred[*[ + + slices = [ slice(start, end, step) if i == dim else slice(None, None, None) for i in range(rank) - ]] = False + ] + pred = np.ones(self.type.shape, dtype=np.bool_) + pred[np.index_exp[tuple(slices)]] = False pred = stablehlo.constant( ir.DenseElementsAttr.get( np.packbits(pred, bitorder="little"), diff --git a/ai_edge_torch/odml_torch/lowerings/_convolution.py b/ai_edge_torch/odml_torch/lowerings/_convolution.py index 68fa09db..bcc4a53b 100644 --- a/ai_edge_torch/odml_torch/lowerings/_convolution.py +++ b/ai_edge_torch/odml_torch/lowerings/_convolution.py @@ -232,7 +232,9 @@ def _aten_convolution( if bias is not None: # broadcast [C] to [NCHW] - broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1]) + broadcasted_bias = stablehlo.broadcast_in_dim( + output_type, bias, ir.DenseI64ArrayAttr.get([1]) + ) res = stablehlo.add( lhs=res, rhs=broadcasted_bias, diff --git a/ai_edge_torch/odml_torch/lowerings/_layer_norm.py b/ai_edge_torch/odml_torch/lowerings/_layer_norm.py index 0672c7b9..e2fa9f17 100644 --- a/ai_edge_torch/odml_torch/lowerings/_layer_norm.py +++ b/ai_edge_torch/odml_torch/lowerings/_layer_norm.py @@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import utils from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo as stablehlo +import numpy as np import torch @@ -66,12 +67,20 @@ def _aten_native_layer_norm( normalized_rank = len(normalized_shape) if weight is not None: weight = stablehlo.broadcast_in_dim( - data_type, weight, list(range(data_rank - normalized_rank, data_rank)) + data_type, + weight, + ir.DenseI64ArrayAttr.get( + list(range(data_rank - normalized_rank, data_rank)) + ), ) output = stablehlo.multiply(weight, output) if bias is not None: bias = stablehlo.broadcast_in_dim( - data_type, bias, list(range(data_rank - normalized_rank, data_rank)) + data_type, + bias, + ir.DenseI64ArrayAttr.get( + list(range(data_rank - normalized_rank, data_rank)) + ), ) output = stablehlo.add(bias, output) diff --git a/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py b/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py index 608bf8ae..de88fa58 100644 --- a/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py +++ b/ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================== """Lowerings for PT2E torch.ops.quantized_decomposed ops.""" -from typing import Union, cast +from typing import Optional, Union, cast from ai_edge_torch.odml_torch.lowerings import context from ai_edge_torch.odml_torch.lowerings import utils @@ -30,15 +30,15 @@ def _uniform_quantized_type( - stored_type: str | ir.Type, - expressed_type: str | ir.Type, + stored_type: Union[str, ir.Type], + expressed_type: Union[str, ir.Type], *, - scale=float | list[float] | tuple[float], - zero_point=float | list[float] | tuple[float], - storage_type_min: int | None = None, - storage_type_max: int | None = None, - channel_axis: int | None = None, - channel_axis_size: int | None = None, + scale=Union[float, list[float], tuple[float]], + zero_point=Union[float, list[float], tuple[float]], + storage_type_min: Optional[int] = None, + storage_type_max: Optional[int] = None, + channel_axis: Optional[int] = None, + channel_axis_size: Optional[int] = None, ): """Polyfill for quant.UniformQuantizedType.""" if storage_type_min and storage_type_max: diff --git a/odmltorch-requirements.txt b/odmltorch-requirements.txt deleted file mode 100644 index a9492f17..00000000 --- a/odmltorch-requirements.txt +++ /dev/null @@ -1,16 +0,0 @@ --f https://download.pytorch.org/whl/torch/ -torch==2.4.0+cpu --f https://download.pytorch.org/whl/torchvision/ -torchvision==0.19.0+cpu --f https://download.pytorch.org/whl/torchaudio/ -torchaudio==2.4.0+cpu ---pre -tf-nightly>=2.18.0.dev20240905 -torch_xla2[odml]>=0.0.1.dev20240801 -ai-edge-litert-nightly -ai-edge-quantizer-nightly -jax[cpu] -scipy -numpy -tabulate -safetensors \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 85a97766..cece68a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,11 +4,12 @@ torch==2.4.0+cpu torchvision==0.19.0+cpu -f https://download.pytorch.org/whl/torchaudio/ torchaudio==2.4.0+cpu -torch_xla==2.4.0 --pre tf-nightly>=2.19.0.dev20241201 ai-edge-litert-nightly ai-edge-quantizer-nightly +jax[cpu] +torch-xla2[odml]>=0.0.1.dev20241201 scipy numpy tabulate diff --git a/run_tests.sh b/run_tests.sh index 75f45329..dcd6aac4 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -16,5 +16,4 @@ SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) # TODO(b/362799258) Setup CIs to test odml-torch path and remove test ignore -PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH python -m pytest $SCRIPT_DIR -n auto \ - --ignore=$SCRIPT_DIR/ai_edge_torch/odml_torch +PYTHONPATH=$SCRIPT_DIR:$PYTHONPATH python -m pytest $SCRIPT_DIR -n auto diff --git a/setup.py b/setup.py index 2f3c1dd0..1c67fa89 100644 --- a/setup.py +++ b/setup.py @@ -17,9 +17,10 @@ import pathlib import re -from setuptools import find_packages -from setuptools import setup +import setuptools +setup = setuptools.setup +find_packages = setuptools.find_packages here = pathlib.Path(__file__).parent.resolve() # Get the long description from the README file @@ -86,9 +87,13 @@ "safetensors", "tabulate", "torch>=2.4.0", - "torch_xla>=2.4.0", "tf-nightly>=2.19.0.dev20241201", "ai-edge-litert-nightly", "ai-edge-quantizer-nightly", + "jax", + "torch-xla2[odml]>=0.0.1.dev20241201", ], + extras_require={ + "torch-xla": ["torch_xla>=2.4.0"], + }, ) diff --git a/test/test_quantize.py b/test/test_quantize.py index 4c3a3d86..99db733a 100644 --- a/test/test_quantize.py +++ b/test/test_quantize.py @@ -18,7 +18,6 @@ import tempfile import ai_edge_torch -from ai_edge_torch import config from 
ai_edge_torch.quantize import pt2e_quantizer from ai_edge_torch.quantize import quant_config import torch @@ -39,7 +38,7 @@ def setUp(self): torch.manual_seed(0) @googletest.skipIf( - not config.Config.use_torch_xla, + not ai_edge_torch.config.use_torch_xla, reason="Only working with torch_xla at the moment.", ) def test_quantizer_arg(self):
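
Usage note (illustrative, not part of the patch): with this change odml-torch becomes the default lowering backend, and the old `ai_edge_torch.config.Config` dataclass is replaced by the `config` object defined in the new `ai_edge_torch/_config.py`. The sketch below shows how that flag is expected to be consumed; the environment-variable name and the `ai_edge_torch.config.use_torch_xla` attribute come straight from the patch, while the surrounding script is hypothetical.

    import os

    # Opt back into the torch_xla lowering path. The variable must be set
    # before ai_edge_torch is imported, because lowertools/_shim.py reads
    # config.use_torch_xla at import time and the property is cached with
    # functools.cache. The torch_xla path also needs the torch_xla package,
    # which this patch moves from a hard requirement to an optional extra.
    os.environ["USE_TORCH_XLA"] = "true"

    import ai_edge_torch

    if ai_edge_torch.config.use_torch_xla:
        print("lowering with torch_xla")
    else:
        print("lowering with odml-torch (the new default)")

Assuming the published distribution name is ai-edge-torch, the torch_xla dependency can still be pulled in through the extra added in setup.py, e.g. `pip install "ai-edge-torch[torch-xla]"`.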