diff --git a/ai_edge_torch/_convert/test/test_convert.py b/ai_edge_torch/_convert/test/test_convert.py
index 2ec95a08..6f577621 100644
--- a/ai_edge_torch/_convert/test/test_convert.py
+++ b/ai_edge_torch/_convert/test/test_convert.py
@@ -23,12 +23,12 @@
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision
 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
 
 
 @dataclasses.dataclass
@@ -466,7 +466,9 @@ def forward(self, x, y, z):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])
 
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ def forward(self, x, y, z):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
diff --git a/ai_edge_torch/generative/test/test_model_conversion.py b/ai_edge_torch/generative/test/test_model_conversion.py
index b857cf7c..114a1bf0 100644
--- a/ai_edge_torch/generative/test/test_model_conversion.py
+++ b/ai_edge_torch/generative/test/test_model_conversion.py
@@ -25,7 +25,7 @@
 import torch
 
 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
diff --git a/ai_edge_torch/generative/test/test_model_conversion_large.py b/ai_edge_torch/generative/test/test_model_conversion_large.py
index 7bd73e29..cffa402e 100644
--- a/ai_edge_torch/generative/test/test_model_conversion_large.py
+++ b/ai_edge_torch/generative/test/test_model_conversion_large.py
@@ -28,7 +28,7 @@
 import torch
 
 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter
 
 
 class TestModelConversion(googletest.TestCase):
diff --git a/ai_edge_torch/model.py b/ai_edge_torch/model.py
index a8f43327..33f949a2 100644
--- a/ai_edge_torch/model.py
+++ b/ai_edge_torch/model.py
@@ -27,6 +27,8 @@
 import numpy.typing as npt
 import tensorflow as tf
 
+from ai_edge_litert import interpreter as tfl_interpreter  # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
 
 
@@ -65,7 +67,7 @@ def __init__(self, tflite_model):
      tflite_model: A TFlite serialized object.
    """
    self._tflite_model = tflite_model
-    self._interpreter_builder = lambda: tf.lite.Interpreter(
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
        model_content=self._tflite_model,
        experimental_default_delegate_latest_features=True,
    )
@@ -75,12 +77,13 @@ def tflite_model(self) -> bytes:
    return self._tflite_model
 
  def set_interpreter_builder(
-      self, builder: Callable[[], tf.lite.Interpreter]
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
  ) -> None:
    """Sets a custom interpreter builder.
 
    Args:
-      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
    """
    self._interpreter_builder = builder
 
@@ -166,7 +169,7 @@ def load(path: str) -> TfLiteModel | None:
 
  # Check if this is indeed a tflite model:
  try:
-    interpreter = tf.lite.Interpreter(model_content=model_content)
+    interpreter = tfl_interpreter.Interpreter(model_content=model_content)
    interpreter.get_signature_list()
  except:
    return None
diff --git a/odmltorch-requirements.txt b/odmltorch-requirements.txt
index 2855ce03..a9492f17 100644
--- a/odmltorch-requirements.txt
+++ b/odmltorch-requirements.txt
@@ -7,6 +7,7 @@ torchaudio==2.4.0+cpu
 --pre
 tf-nightly>=2.18.0.dev20240905
 torch_xla2[odml]>=0.0.1.dev20240801
+ai-edge-litert-nightly
 ai-edge-quantizer-nightly
 jax[cpu]
 scipy
diff --git a/requirements.txt b/requirements.txt
index 4a73db07..fa544557 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ torchaudio==2.4.0+cpu
 torch_xla==2.4.0
 --pre
 tf-nightly>=2.18.0.dev20240905
+ai-edge-litert-nightly
 ai-edge-quantizer-nightly
 scipy
 numpy