diff --git a/mediapipe/model_maker/python/core/utils/BUILD b/mediapipe/model_maker/python/core/utils/BUILD index 4b8ff1bf97..3d50d9aaaf 100644 --- a/mediapipe/model_maker/python/core/utils/BUILD +++ b/mediapipe/model_maker/python/core/utils/BUILD @@ -44,6 +44,7 @@ py_library( deps = [ ":quantization", "//mediapipe/model_maker/python/core/data:dataset", + "@model_maker_pip_deps_ai_edge_litert//:pkg", "@model_maker_pip_deps_numpy//:pkg", "@model_maker_pip_deps_tensorflow//:pkg", ], @@ -172,6 +173,7 @@ py_test( ":quantization", ":test_util", "@model_maker_pip_deps_absl_py//:pkg", + "@model_maker_pip_deps_ai_edge_litert//:pkg", "@model_maker_pip_deps_tensorflow//:pkg", ], ) diff --git a/mediapipe/model_maker/python/core/utils/model_util.py b/mediapipe/model_maker/python/core/utils/model_util.py index 32b509797f..0a13095e6a 100644 --- a/mediapipe/model_maker/python/core/utils/model_util.py +++ b/mediapipe/model_maker/python/core/utils/model_util.py @@ -28,6 +28,7 @@ from mediapipe.model_maker.python.core.data import dataset from mediapipe.model_maker.python.core.utils import quantization +from ai_edge_litert import interpreter as tfl_interpreter DEFAULT_SCALE, DEFAULT_ZERO_POINT = 0, 0 ESTIMITED_STEPS_PER_EPOCH = 1000 @@ -273,7 +274,7 @@ def __init__(self, tflite_model: bytearray): Args: tflite_model: A valid flatbuffer representing the TFLite model. """ - self.interpreter = tf.lite.Interpreter(model_content=tflite_model) + self.interpreter = tfl_interpreter.Interpreter(model_content=tflite_model) self.interpreter.allocate_tensors() self.input_details = self.interpreter.get_input_details() self.output_details = self.interpreter.get_output_details() diff --git a/mediapipe/model_maker/python/core/utils/quantization_test.py b/mediapipe/model_maker/python/core/utils/quantization_test.py index 57523d4056..0164d39bf8 100644 --- a/mediapipe/model_maker/python/core/utils/quantization_test.py +++ b/mediapipe/model_maker/python/core/utils/quantization_test.py @@ -17,6 +17,7 @@ from mediapipe.model_maker.python.core.utils import quantization from mediapipe.model_maker.python.core.utils import test_util +from ai_edge_litert import interpreter as tfl_interpreter class QuantizationTest(tf.test.TestCase, parameterized.TestCase): @@ -59,7 +60,7 @@ def test_set_converter_with_quantization_from_int8_config(self): self.assertEqual(config.supported_ops, [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]) tflite_model = converter.convert() - interpreter = tf.lite.Interpreter(model_content=tflite_model) + interpreter = tfl_interpreter.Interpreter(model_content=tflite_model) self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.uint8) self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.uint8) @@ -82,7 +83,7 @@ def test_set_converter_with_quantization_from_float16_config(self): converter = config.set_converter_with_quantization(converter=converter) self.assertEqual(config.supported_types, [tf.float16]) tflite_model = converter.convert() - interpreter = tf.lite.Interpreter(model_content=tflite_model) + interpreter = tfl_interpreter.Interpreter(model_content=tflite_model) # The input and output are expected to be set to float32 by default. self.assertEqual(interpreter.get_input_details()[0]['dtype'], tf.float32) self.assertEqual(interpreter.get_output_details()[0]['dtype'], tf.float32) diff --git a/mediapipe/model_maker/requirements.txt b/mediapipe/model_maker/requirements.txt index 3ce977b671..404e58a5de 100644 --- a/mediapipe/model_maker/requirements.txt +++ b/mediapipe/model_maker/requirements.txt @@ -1,4 +1,5 @@ absl-py +ai-edge-litert mediapipe>=0.10.0 numpy<2 opencv-python diff --git a/mediapipe/model_maker/requirements_bazel.txt b/mediapipe/model_maker/requirements_bazel.txt index fd6c421cf8..afc6d5b626 100644 --- a/mediapipe/model_maker/requirements_bazel.txt +++ b/mediapipe/model_maker/requirements_bazel.txt @@ -1,4 +1,5 @@ absl-py +ai-edge-litert numpy<2 opencv-python setuptools==70.3.0 # needed due to https://github.com/pypa/setuptools/issues/4487 diff --git a/mediapipe/model_maker/requirements_lock.txt b/mediapipe/model_maker/requirements_lock.txt index ae285cf8fb..326f71b2e0 100644 --- a/mediapipe/model_maker/requirements_lock.txt +++ b/mediapipe/model_maker/requirements_lock.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.9 +# This file is autogenerated by pip-compile with Python 3.11 # by the following command: # # pip-compile --allow-unsafe --output-file=mediapipe/opensource_only/model_maker_requirements_lock.txt mediapipe/opensource_only/model_maker_requirements_bazel.txt @@ -15,6 +15,8 @@ absl-py==1.4.0 # tensorflow-metadata # tensorflow-model-optimization # tf-slim +ai-edge-litert==1.0.1 + # via -r mediapipe/opensource_only/model_maker_requirements_bazel.txt array-record==0.5.1 # via tensorflow-datasets astunparse==1.6.3 @@ -75,7 +77,9 @@ google-auth-oauthlib==1.2.1 google-pasta==0.2.0 # via tensorflow googleapis-common-protos==1.65.0 - # via google-api-core + # via + # google-api-core + # tensorflow-metadata grpcio==1.66.2 # via # tensorboard @@ -91,12 +95,8 @@ idna==3.10 # via requests immutabledict==4.2.0 # via tf-models-official -importlib-metadata==8.5.0 - # via markdown importlib-resources==6.4.5 - # via - # etils - # matplotlib + # via etils joblib==1.4.2 # via scikit-learn kaggle==1.6.17 @@ -122,6 +122,7 @@ ml-dtypes==0.3.2 numpy==1.26.4 # via # -r mediapipe/opensource_only/model_maker_requirements_bazel.txt + # ai-edge-litert # contourpy # etils # h5py @@ -168,7 +169,7 @@ promise==2.3 # via tensorflow-datasets proto-plus==1.24.0 # via google-api-core -protobuf==3.20.3 +protobuf==4.25.5 # via # google-api-core # googleapis-common-protos @@ -333,10 +334,7 @@ wrapt==1.14.1 # tensorflow # tensorflow-datasets zipp==3.20.2 - # via - # etils - # importlib-metadata - # importlib-resources + # via etils # The following packages are considered to be unsafe in a requirements file: setuptools==70.3.0