Skip to content

Commit

Permalink
Move tensorflow lite python calls to ai-edge-litert.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 686633199
  • Loading branch information
pak-laura authored and copybara-github committed Oct 16, 2024
1 parent 927a0b9 commit 27c57a0
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 15 deletions.
2 changes: 2 additions & 0 deletions mediapipe/model_maker/python/core/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ py_library(
deps = [
":quantization",
"//mediapipe/model_maker/python/core/data:dataset",
"@model_maker_pip_deps_ai_edge_litert//:pkg",
"@model_maker_pip_deps_numpy//:pkg",
"@model_maker_pip_deps_tensorflow//:pkg",
],
Expand Down Expand Up @@ -172,6 +173,7 @@ py_test(
":quantization",
":test_util",
"@model_maker_pip_deps_absl_py//:pkg",
"@model_maker_pip_deps_ai_edge_litert//:pkg",
"@model_maker_pip_deps_tensorflow//:pkg",
],
)
3 changes: 2 additions & 1 deletion mediapipe/model_maker/python/core/utils/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from mediapipe.model_maker.python.core.data import dataset
from mediapipe.model_maker.python.core.utils import quantization
from ai_edge_litert import interpreter as tfl_interpreter

DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0
ESTIMITED_STEPS_PER_EPOCH = 1000
Expand Down Expand Up @@ -273,7 +274,7 @@ def __init__(self, tflite_model: bytearray):
Args:
tflite_model: A valid flatbuffer representing the TFLite model.
"""
self.interpreter = tf.lite.Interpreter(model_content=tflite_model)
self.interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
self.interpreter.allocate_tensors()
self.input_details = self.interpreter.get_input_details()
self.output_details = self.interpreter.get_output_details()
Expand Down
5 changes: 3 additions & 2 deletions mediapipe/model_maker/python/core/utils/quantization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from mediapipe.model_maker.python.core.utils import quantization
from mediapipe.model_maker.python.core.utils import test_util
from ai_edge_litert import interpreter as tfl_interpreter


class QuantizationTest(tf.test.TestCase, parameterized.TestCase):
Expand Down Expand Up @@ -59,7 +60,7 @@ def test_set_converter_with_quantization_from_int8_config(self):
self.assertEqual(config.supported_ops,
[tf.lite.OpsSet.TFLITE_BUILTINS_INT8])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8)

Expand All @@ -82,7 +83,7 @@ def test_set_converter_with_quantization_from_float16_config(self):
converter = config.set_converter_with_quantization(converter=converter)
self.assertEqual(config.supported_types, [tf.float16])
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter = tfl_interpreter.Interpreter(model_content=tflite_model)
# The input and output are expected to be set to float32 by default.
self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32)
self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32)
Expand Down
1 change: 1 addition & 0 deletions mediapipe/model_maker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
absl-py
ai-edge-litert
mediapipe>=0.10.0
numpy<2
opencv-python
Expand Down
1 change: 1 addition & 0 deletions mediapipe/model_maker/requirements_bazel.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
absl-py
ai-edge-litert
numpy<2
opencv-python
setuptools==70.3.0 # needed due to https://github.com/pypa/setuptools/issues/4487
Expand Down
22 changes: 10 additions & 12 deletions mediapipe/model_maker/requirements_lock.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# This file is autogenerated by pip-compile with Python 3.11
# by the following command:
#
# pip-compile --allow-unsafe --output-file=mediapipe/opensource_only/model_maker_requirements_lock.txt mediapipe/opensource_only/model_maker_requirements_bazel.txt
Expand All @@ -15,6 +15,8 @@ absl-py==1.4.0
# tensorflow-metadata
# tensorflow-model-optimization
# tf-slim
ai-edge-litert==1.0.1
# via -r mediapipe/opensource_only/model_maker_requirements_bazel.txt
array-record==0.5.1
# via tensorflow-datasets
astunparse==1.6.3
Expand Down Expand Up @@ -75,7 +77,9 @@ google-auth-oauthlib==1.2.1
google-pasta==0.2.0
# via tensorflow
googleapis-common-protos==1.65.0
# via google-api-core
# via
# google-api-core
# tensorflow-metadata
grpcio==1.66.2
# via
# tensorboard
Expand All @@ -91,12 +95,8 @@ idna==3.10
# via requests
immutabledict==4.2.0
# via tf-models-official
importlib-metadata==8.5.0
# via markdown
importlib-resources==6.4.5
# via
# etils
# matplotlib
# via etils
joblib==1.4.2
# via scikit-learn
kaggle==1.6.17
Expand All @@ -122,6 +122,7 @@ ml-dtypes==0.3.2
numpy==1.26.4
# via
# -r mediapipe/opensource_only/model_maker_requirements_bazel.txt
# ai-edge-litert
# contourpy
# etils
# h5py
Expand Down Expand Up @@ -168,7 +169,7 @@ promise==2.3
# via tensorflow-datasets
proto-plus==1.24.0
# via google-api-core
protobuf==3.20.3
protobuf==4.25.5
# via
# google-api-core
# googleapis-common-protos
Expand Down Expand Up @@ -333,10 +334,7 @@ wrapt==1.14.1
# tensorflow
# tensorflow-datasets
zipp==3.20.2
# via
# etils
# importlib-metadata
# importlib-resources
# via etils

# The following packages are considered to be unsafe in a requirements file:
setuptools==70.3.0
Expand Down

0 comments on commit 27c57a0

Please sign in to comment.