
Commit

Migrate ai-edge-torch to use ai-edge-litert instead of TFLite for python.

PiperOrigin-RevId: 676497735
pak-laura authored and copybara-github committed Sep 19, 2024
1 parent 180ee2a commit 9c14805
Showing 6 changed files with 18 additions and 9 deletions.
10 changes: 7 additions & 3 deletions ai_edge_torch/_convert/test/test_convert.py
@@ -23,12 +23,12 @@
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch.testing import model_coverage
 import numpy as np
-import tensorflow as tf
 import torch
 from torch import nn
 import torchvision

 from absl.testing import absltest as googletest
+from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


 @dataclasses.dataclass
@@ -466,7 +466,9 @@ def forward(self, x, y, z):
     np.testing.assert_almost_equal(edge_output["y_data_2_0"], args[1])
     np.testing.assert_almost_equal(edge_output["y_data_2_1"], args[2])

-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     output_details = runner.get_output_details()
     self.assertIn("x", output_details.keys())
@@ -477,7 +479,9 @@ def forward(self, x, y, z):
   def _compare_tflite_torch_args_kwargs(self, model, args, kwargs, flat_inputs):
     model.eval()
     edge_model = ai_edge_torch.convert(model, args, kwargs)
-    interpreter = tf.lite.Interpreter(model_content=edge_model._tflite_model)
+    interpreter = tfl_interpreter.Interpreter(
+        model_content=edge_model._tflite_model
+    )
     runner = interpreter.get_signature_runner("serving_default")
     input_details = runner.get_input_details()
     self.assertEqual(input_details.keys(), flat_inputs.keys())
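To make the new test pattern concrete, here is a minimal sketch (not part of the commit) of driving a converted model through the LiteRT interpreter and its signature runner, mirroring the calls in test_convert.py above. The SimpleAdd module and input shape are hypothetical; the Interpreter, get_signature_runner, and get_input_details calls are the ones used in the diff.

import ai_edge_torch
import numpy as np
import torch
from ai_edge_litert import interpreter as tfl_interpreter


class SimpleAdd(torch.nn.Module):  # hypothetical toy module for illustration
  def forward(self, x):
    return x + 1.0


edge_model = ai_edge_torch.convert(SimpleAdd().eval(), (torch.randn(1, 4),))

# Build the LiteRT interpreter directly from the serialized flatbuffer, as the
# updated tests do, then look up the default signature runner.
interpreter = tfl_interpreter.Interpreter(model_content=edge_model._tflite_model)
runner = interpreter.get_signature_runner("serving_default")

# Signature runners are keyed by input name; feed numpy arrays, get a dict back.
input_name = next(iter(runner.get_input_details()))
outputs = runner(**{input_name: np.ones((1, 4), dtype=np.float32)})
print(list(outputs.keys()))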
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/test/test_model_conversion.py
@@ -25,7 +25,7 @@
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
2 changes: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 import torch

 from absl.testing import absltest as googletest
-from tensorflow.lite.python import interpreter
+from ai_edge_litert import interpreter


 class TestModelConversion(googletest.TestCase):
11 changes: 7 additions & 4 deletions ai_edge_torch/model.py
@@ -27,6 +27,8 @@
 import numpy.typing as npt
 import tensorflow as tf

+from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import
+
 DEFAULT_SIGNATURE_NAME = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY


@@ -65,7 +67,7 @@ def __init__(self, tflite_model):
       tflite_model: A TFlite serialized object.
     """
     self._tflite_model = tflite_model
-    self._interpreter_builder = lambda: tf.lite.Interpreter(
+    self._interpreter_builder = lambda: tfl_interpreter.Interpreter(
         model_content=self._tflite_model,
         experimental_default_delegate_latest_features=True,
     )
@@ -75,12 +77,13 @@ def tflite_model(self) -> bytes:
     return self._tflite_model

   def set_interpreter_builder(
-      self, builder: Callable[[], tf.lite.Interpreter]
+      self, builder: Callable[[], tfl_interpreter.Interpreter]
   ) -> None:
     """Sets a custom interpreter builder.

     Args:
-      builder: A function that returns a `tf.lite.Interpreter` or its subclass.
+      builder: A function that returns a `tfl_interpreter.Interpreter` or its
+        subclass.
     """
     self._interpreter_builder = builder

@@ -166,7 +169,7 @@ def load(path: str) -> TfLiteModel | None:

   # Check if this is indeed a tflite model:
   try:
-    interpreter = tf.lite.Interpreter(model_content=model_content)
+    interpreter = tfl_interpreter.Interpreter(model_content=model_content)
     interpreter.get_signature_list()
   except:
     return None
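For context on the builder hook touched above, a rough sketch of overriding how TfLiteModel constructs its interpreter via set_interpreter_builder. It assumes the object returned by ai_edge_torch.convert is the TfLiteModel from model.py and that the edge model is callable for inference; TinyRelu and its input shape are made up, while the constructor flag shown is the one the default builder passes in the diff.

import ai_edge_torch
import torch
from ai_edge_litert import interpreter as tfl_interpreter


class TinyRelu(torch.nn.Module):  # hypothetical module for illustration
  def forward(self, x):
    return torch.nn.functional.relu(x)


edge_model = ai_edge_torch.convert(TinyRelu().eval(), (torch.randn(1, 8),))
model_bytes = edge_model._tflite_model  # the same buffer the default builder wraps

# Swap in a custom interpreter builder; the keyword below mirrors the default
# builder in model.py rather than a new option introduced by this commit.
edge_model.set_interpreter_builder(
    lambda: tfl_interpreter.Interpreter(
        model_content=model_bytes,
        experimental_default_delegate_latest_features=True,
    )
)
output = edge_model(torch.randn(1, 8))  # inference now uses the custom-built interpreter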
1 change: 1 addition & 0 deletions odmltorch-requirements.txt
@@ -7,6 +7,7 @@ torchaudio==2.4.0+cpu
 --pre
 tf-nightly>=2.18.0.dev20240905
 torch_xla2[odml]>=0.0.1.dev20240801
+ai-edge-litert-nightly
 ai-edge-quantizer-nightly
 jax[cpu]
 scipy
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,6 +7,7 @@ torchaudio==2.4.0+cpu
 torch_xla==2.4.0
 --pre
 tf-nightly>=2.18.0.dev20240905
+ai-edge-litert-nightly
 ai-edge-quantizer-nightly
 scipy
 numpy
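Both requirements files gain ai-edge-litert-nightly alongside ai-edge-quantizer-nightly. Like the other nightly packages listed here, it ships pre-release versions, so the existing --pre line in these files is what lets pip resolve it during a plain install, for example:

pip install -r requirements.txt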
