diff --git a/ai_edge_quantizer/calibrator.py b/ai_edge_quantizer/calibrator.py index abec283..16e7b61 100644 --- a/ai_edge_quantizer/calibrator.py +++ b/ai_edge_quantizer/calibrator.py @@ -17,7 +17,7 @@ from collections.abc import Callable, Iterable import copy -from typing import Any, Optional, Union +from typing import Any, Union from absl import logging import numpy as np @@ -62,9 +62,8 @@ def __init__( # TODO(b/330740605)- Collect multiple QSVs in one run to save compute. def calibrate( self, - calibration_dataset: Iterable[_SignatureInput], + calibration_dataset: dict[str, Iterable[_SignatureInput]], model_recipe_manager: recipe_manager.RecipeManager, - signature_key: Optional[str] = None, cache_output: bool = False, qsv_update_func: Callable[ [qtyping.QSV, qtyping.QSV], @@ -87,13 +86,10 @@ def calibrate( 6. Start another round of calibration. Args: - calibration_dataset: A list of input data for calibration for the given - model signature. + calibration_dataset: A dictionary of input data for calibration for the + given model signature. model_recipe_manager: A RecipeManager object that contains the quantization recipe. - signature_key: The signature key to be used for invoking the models. If - the model doesn't have a signature key (or only has one ), this can be - set to None. cache_output: Whether to cache the output of the model during the calibration process. This is useful if there are dependencies between signatures/models (e.g., decode requires encode output). @@ -111,23 +107,31 @@ def calibrate( ) # TODO: b/329322226 - Enable parrallel calibration. - for data in calibration_dataset: - # Initialize tensor names that are updated in this round of calibration. - updated_tensor_names = set() - - # Step1: run tfl interpreter to get tensor content. 
- signature_output = tfl_interpreter_utils.invoke_interpreter_signature( - self._tfl_interpreter, data, signature_key - ) - if cache_output: - self._cached_output.append(signature_output) - self._tensor_content_map = ( - tfl_interpreter_utils.get_tensor_name_to_content_map( - self._tfl_interpreter - ) + for signature_key, dataset in calibration_dataset.items(): + # Step0: get subgraph index. + subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index( + self._tfl_interpreter, signature_key ) - # Step2: go through each op to update quantization statistic values. - for subgraph in self._flatbuffer_model.subgraphs: + + for data in dataset: + # Initialize tensor names that are updated in this round of calibration. + updated_tensor_names = set() + + # Step1: run tfl interpreter on subgraph to get tensor content. + signature_output = tfl_interpreter_utils.invoke_interpreter_signature( + self._tfl_interpreter, data, signature_key + ) + if cache_output: + self._cached_output.append(signature_output) + self._tensor_content_map.update( + tfl_interpreter_utils.get_tensor_name_to_content_map( + self._tfl_interpreter, subgraph_idx + ) + ) + + # Step2: go through each op in subgraph to update quantization + # statistic values. 
+ subgraph = self._flatbuffer_model.subgraphs[subgraph_idx] graph_info = qtyping.GraphInfo( subgraph.tensors, self._flatbuffer_model.buffers ) diff --git a/ai_edge_quantizer/calibrator_test.py b/ai_edge_quantizer/calibrator_test.py index e54fe83..8a8d43f 100644 --- a/ai_edge_quantizer/calibrator_test.py +++ b/ai_edge_quantizer/calibrator_test.py @@ -15,7 +15,9 @@ """Tests for calibrator.""" +from collections.abc import Generator import os +from typing import Any, Dict import numpy as np @@ -24,6 +26,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import recipe_manager from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _ComputePrecision = qtyping.ComputePrecision _AlgorithmName = recipe_manager.AlgorithmName @@ -33,6 +36,8 @@ TEST_MIN_VAL, TEST_MAX_VAL = -1, 1 +_RNG = np.random.default_rng(66) + def _representative_dataset_gen(size=(1, 8), num_samples=10): for _ in range(num_samples): @@ -44,6 +49,14 @@ def _representative_dataset_gen(size=(1, 8), num_samples=10): yield {"input_1": vals} +def _get_calibration_data(dataset_gen: Generator[Dict[str, Any], Any, None]): + calibration_samples = [sample for sample in dataset_gen] + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data + + def _add_default_int8xint8_integer_recipe(recipe_manager_object): recipe_manager_object.add_quantization_config( regex=".*", @@ -69,7 +82,8 @@ def setUp(self): ) self._calibrator = calibrator.Calibrator(self._test_model_path) self._recipe_manager = recipe_manager.RecipeManager() - self._representative_dataset = _representative_dataset_gen() + dataset_gen = _representative_dataset_gen() + self._representative_dataset = _get_calibration_data(dataset_gen) def test_calibrator_state_manipulation(self): # load/get qsvs @@ -204,8 +218,9 @@ def test_calibrate_unsupported_ops_success(self): ) test_calibrator = calibrator.Calibrator(test_model_path) 
_add_default_int8xint8_integer_recipe(self._recipe_manager) + dataset_gen = _representative_dataset_gen(size=(3, 4, 4, 1)) test_calibrator.calibrate( - _representative_dataset_gen(size=(3, 4, 4, 1)), + _get_calibration_data(dataset_gen), self._recipe_manager, cache_output=True, ) @@ -231,5 +246,50 @@ def test_check_is_float_model_raises_error_when_model_is_quantized(self): _ = calibrator.Calibrator(test_model_path) +class CalibratorToyGemma2Test(googletest.TestCase): + + def setUp(self): + super().setUp() + np.random.seed(0) + + self._test_model_path = os.path.join( + TEST_DATA_PREFIX_PATH, + "tests/models/toy_model_with_kv_cache_multi_signature.tflite", + ) + + self._toy_gemma2_calibration_dataset = { + "signature_1": [{ + "cache_0": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + "cache_1": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + }], + "signature_2": [{ + "cache_0": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + "cache_1": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + "positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + "tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + }], + } + + def test_toy_gemma2_calibration_success(self): + calib = calibrator.Calibrator(self._test_model_path) + recipe_mngr = recipe_manager.RecipeManager() + _add_default_int8xint8_integer_recipe(recipe_mngr) + calib.calibrate( + self._toy_gemma2_calibration_dataset, + model_recipe_manager=recipe_mngr, + ) + self.assertLen(calib.get_model_qsvs(), 260) + + if __name__ == "__main__": googletest.main() diff --git a/ai_edge_quantizer/params_generator_test.py b/ai_edge_quantizer/params_generator_test.py index 7d1351d..69eb56b 100644 --- a/ai_edge_quantizer/params_generator_test.py +++ b/ai_edge_quantizer/params_generator_test.py @@ -15,7 
+15,9 @@ """Tests for params_generator.""" +from collections.abc import Generator import os +from typing import Any, Dict from absl.testing import parameterized import numpy as np @@ -27,6 +29,7 @@ from ai_edge_quantizer import recipe_manager from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _ComputePrecision = qtyping.ComputePrecision @@ -51,6 +54,14 @@ def _int_transpose_model_representative_dataset_gen(num_samples=5): return data +def _get_calibration_data(dataset_gen: Generator[Dict[str, Any], Any, None]): + calibration_samples = [sample for sample in dataset_gen] + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data + + class ParamsGeneratorTest(parameterized.TestCase): def setUp(self): @@ -444,9 +455,10 @@ def test_generate_config_int8xint8_single_fc( # Calibrate then quantize model_calibrator = calibrator.Calibrator(single_fc_model_path) - model_calibrator.calibrate( - _single_fc_model_representative_dataset_gen(), self._recipe_manager + calibration_data = _get_calibration_data( + _single_fc_model_representative_dataset_gen() ) + model_calibrator.calibrate(calibration_data, self._recipe_manager) model_qsvs = model_calibrator.get_model_qsvs() quant_params = params_generator_single_fc.generate_quantization_parameters( self._recipe_manager, @@ -905,9 +917,10 @@ def test_quantize_integer_input_output(self): # Calibrate then quantize. 
model_calibrator = calibrator.Calibrator(model_path) - model_calibrator.calibrate( - _int_transpose_model_representative_dataset_gen(), self._recipe_manager + calibration_data = _get_calibration_data( + _int_transpose_model_representative_dataset_gen() ) + model_calibrator.calibrate(calibration_data, self._recipe_manager) model_qsvs = model_calibrator.get_model_qsvs() quant_params = pg.generate_quantization_parameters( self._recipe_manager, diff --git a/ai_edge_quantizer/quantizer.py b/ai_edge_quantizer/quantizer.py index e95604f..66506dc 100644 --- a/ai_edge_quantizer/quantizer.py +++ b/ai_edge_quantizer/quantizer.py @@ -213,16 +213,13 @@ def need_calibration(self) -> bool: def calibrate( self, - calibration_data: Iterable[_SignatureInput], - signature_key: Optional[str] = None, + calibration_data: dict[str, Iterable[_SignatureInput]], previous_calibration_result: Optional[_CalibrationResult] = None, ) -> _CalibrationResult: """Calibrates the float model (required by static range quantization). Args: calibration_data: Calibration data for a model signature. - signature_key: The signature key to be used for invoking the models. If - the model doesn't have a signature key, this can be set to None. previous_calibration_result: Previous calibration result to be loaded. The calibration process will be resumed from the previous result. 
@@ -235,7 +232,7 @@ def calibrate( calib = calibrator.Calibrator(self.float_model) if previous_calibration_result is not None: calib.load_model_qsvs(previous_calibration_result) - calib.calibrate(calibration_data, self._recipe_manager, signature_key) + calib.calibrate(calibration_data, self._recipe_manager) return calib.get_model_qsvs() def quantize( diff --git a/ai_edge_quantizer/quantizer_test.py b/ai_edge_quantizer/quantizer_test.py index bf4a746..0d6edcf 100644 --- a/ai_edge_quantizer/quantizer_test.py +++ b/ai_edge_quantizer/quantizer_test.py @@ -23,6 +23,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _ComputePrecision = qtyping.ComputePrecision _TFLOpName = qtyping.TFLOperationName @@ -31,15 +32,22 @@ _AlgorithmName = quantizer.AlgorithmName TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('') +_MULTI_SIGNATURE_CALIBRATION_DATASET = { + 'add': [{'x': np.array([2.0]).astype(np.float32)}], + 'multiply': [{'x': np.array([1.0]).astype(np.float32)}], +} _RNG = np.random.default_rng(66) def _get_calibration_data(num_samples: int = 16): - calibration_data = [] + calibration_samples = [] for _ in range(num_samples): - calibration_data.append( + calibration_samples.append( {'conv2d_input': _RNG.uniform(size=(1, 28, 28, 1)).astype(np.float32)} ) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } return calibration_data @@ -342,13 +350,7 @@ def setUp(self): @parameterized.named_parameters( ('default_random_data', None), - ( - 'specific_data', - { - 'add': [{'x': np.array([2.0]).astype(np.float32)}], - 'multiply': [{'x': np.array([1.0]).astype(np.float32)}], - }, - ), + ('specific_data', _MULTI_SIGNATURE_CALIBRATION_DATASET), ) def test_validate_multiple_signatures_succeeds(self, test_data): self._quantizer.quantize(self._calibration_result) @@ -398,6 +400,115 @@ def 
test_validate_multiply_signature_succeeds(self): self.assertIn('Mul/y', mul_result.constant_tensors) self.assertEmpty(mul_result.intermediate_tensors) + def test_validate_quantize_after_calibration_succeeds(self): + calib_result = self._quantizer.calibrate( + _MULTI_SIGNATURE_CALIBRATION_DATASET + ) + self._quantizer.quantize(calib_result) + validation_result = self._quantizer.validate( + _MULTI_SIGNATURE_CALIBRATION_DATASET + ) + available_signatures = validation_result.available_signature_keys() + self.assertLen(available_signatures, 2) + + def test_recipe_conflict(self): + recipe = [ + dict({ + 'regex': '.*', + 'operation': 'ADD', + 'algorithm_key': 'min_max_uniform_quantize', + 'op_config': { + 'activation_tensor_config': { + 'num_bits': 8, + 'symmetric': False, + 'granularity': 'TENSORWISE', + 'dtype': 'INT', + 'block_size': 0, + }, + 'weight_tensor_config': { + 'num_bits': 8, + 'symmetric': True, + 'granularity': 'CHANNELWISE', + 'dtype': 'INT', + 'block_size': 0, + }, + 'compute_precision': 'INTEGER', + 'explicit_dequantize': False, + 'skip_checks': False, + }, + }) + ] + + qt = quantizer.Quantizer(self._test_model_path, recipe) + calib_result = qt.calibrate(_MULTI_SIGNATURE_CALIBRATION_DATASET) + + error_message = ( + "The tensors b'Add/y' and b'Mul/y' do not have the same quantization" + ) + with self.assertRaisesWithPredicateMatch( + RuntimeError, lambda err: error_message in str(err) + ): + qt.quantize(calib_result) + + +class QuantizerToyGemma2Test(parameterized.TestCase): + + def setUp(self): + super().setUp() + self._tmp_save_path = self.create_tempdir().full_path + self._test_model_path = os.path.join( + TEST_DATA_PREFIX_PATH, + 'tests/models/toy_model_with_kv_cache_multi_signature.tflite', + ) + + self._toy_gemma2_calibration_dataset = { + 'signature_1': [{ + 'cache_0': _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + 'cache_1': _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + 'positions': _RNG.integers(low=0, high=10, size=(1, 
100)).astype( + np.int32 + ), + 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + }], + 'signature_2': [{ + 'cache_0': _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + 'cache_1': _RNG.random(size=(1, 100, 4, 4)).astype(np.float32), + 'positions': _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + 'tokens': _RNG.integers(low=0, high=10, size=(1, 100)).astype( + np.int32 + ), + }], + } + + self._test_recipe_path = os.path.join( + TEST_DATA_PREFIX_PATH, + 'recipes/default_a8w8_recipe.json', + ) + with open(self._test_recipe_path) as json_file: + self._test_recipe = json.load(json_file) + + self._quantizer = quantizer.Quantizer( + self._test_model_path, self._test_recipe_path + ) + + self._quantizer.update_quantization_recipe( + regex='StatefulPartitionedCall', + operation_name=qtyping.TFLOperationName.FULLY_CONNECTED, + algorithm_key=_AlgorithmName.NO_QUANTIZE, + ) + + def test_toy_gemma2_quantization_succeeds(self): + calib_result = self._quantizer.calibrate( + self._toy_gemma2_calibration_dataset + ) + self.assertIsNotNone(calib_result) + self._quantizer.quantize(calib_result) + self.assertIsNotNone(self._quantizer._result.quantized_model) + if __name__ == '__main__': googletest.main() diff --git a/ai_edge_quantizer/tests/end_to_end_tests/add_test.py b/ai_edge_quantizer/tests/end_to_end_tests/add_test.py index 276a4f0..66625e9 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/add_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/add_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import _OpExecutionMode = qtyping.OpExecutionMode @@ -43,11 +44,15 @@ def _get_dummy_data(num_inputs, num_samples): def _get_calibration_data(num_inputs, num_samples: int = 128): - 
return _get_dummy_data(num_inputs, num_samples) + calibration_samples = _get_dummy_data(num_inputs, num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_inputs, num_samples: int = 8): - return {'serving_default': _get_dummy_data(num_inputs, num_samples)} + return _get_calibration_data(num_inputs, num_samples) class AddTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/depthwise_conv2d_test.py b/ai_edge_quantizer/tests/end_to_end_tests/depthwise_conv2d_test.py index 7fe7a90..69cac30 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/depthwise_conv2d_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/depthwise_conv2d_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _ComputePrecision = qtyping.ComputePrecision _OpName = qtyping.TFLOperationName @@ -41,11 +42,15 @@ def _get_dummy_data(num_samples): def _get_calibration_data(num_samples: int = 128): - return _get_dummy_data(num_samples) + calibration_samples = _get_dummy_data(num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_samples: int = 8): - return {'serving_default': _get_dummy_data(num_samples)} + return _get_calibration_data(num_samples) class DepthwiseConv2dTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/gelu_test.py b/ai_edge_quantizer/tests/end_to_end_tests/gelu_test.py index dcbd813..559d798 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/gelu_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/gelu_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from 
ai_edge_quantizer.utils import tfl_interpreter_utils _OpExecutionMode = qtyping.OpExecutionMode _OpName = qtyping.TFLOperationName @@ -39,11 +40,15 @@ def _get_dummy_data(num_samples): def _get_calibration_data(num_samples: int = 128): - return _get_dummy_data(num_samples) + calibration_samples = _get_dummy_data(num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_samples: int = 8): - return {'serving_default': _get_dummy_data(num_samples)} + return _get_calibration_data(num_samples) class GeluTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/mul_test.py b/ai_edge_quantizer/tests/end_to_end_tests/mul_test.py index 16c1ebf..b3bdcb1 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/mul_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/mul_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import _ComputePrecision = qtyping.ComputePrecision @@ -43,11 +44,15 @@ def _get_dummy_data(num_inputs, num_samples): def _get_calibration_data(num_inputs, num_samples: int = 512): - return _get_dummy_data(num_inputs, num_samples) + calibration_samples = _get_dummy_data(num_inputs, num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_inputs, num_samples: int = 8): - return {'serving_default': _get_dummy_data(num_inputs, num_samples)} + return _get_calibration_data(num_inputs, num_samples) class MulTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/sub_test.py b/ai_edge_quantizer/tests/end_to_end_tests/sub_test.py index 355a766..b3f348b 100644 --- 
a/ai_edge_quantizer/tests/end_to_end_tests/sub_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/sub_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils from tensorflow.python.platform import gfile # pylint: disable=g-direct-tensorflow-import _OpExecutionMode = qtyping.OpExecutionMode @@ -43,11 +44,15 @@ def _get_dummy_data(num_inputs, num_samples): def _get_calibration_data(num_inputs, num_samples: int = 512): - return _get_dummy_data(num_inputs, num_samples) + calibration_samples = _get_dummy_data(num_inputs, num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_inputs, num_samples: int = 8): - return {'serving_default': _get_dummy_data(num_inputs, num_samples)} + return _get_calibration_data(num_inputs, num_samples) class SubTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/tanh_test.py b/ai_edge_quantizer/tests/end_to_end_tests/tanh_test.py index a4841ee..d3fa1e6 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/tanh_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/tanh_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _OpExecutionMode = qtyping.OpExecutionMode _OpName = qtyping.TFLOperationName @@ -39,11 +40,15 @@ def _get_dummy_data(num_samples): def _get_calibration_data(num_samples: int = 128): - return _get_dummy_data(num_samples) + calibration_samples = _get_dummy_data(num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_samples: int = 8): - return {'serving_default': 
_get_dummy_data(num_samples)} + return _get_calibration_data(num_samples) class TanhTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/end_to_end_tests/transpose_test.py b/ai_edge_quantizer/tests/end_to_end_tests/transpose_test.py index cd8f204..0545463 100644 --- a/ai_edge_quantizer/tests/end_to_end_tests/transpose_test.py +++ b/ai_edge_quantizer/tests/end_to_end_tests/transpose_test.py @@ -25,6 +25,7 @@ from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils from ai_edge_quantizer.utils import tfl_flatbuffer_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _OpExecutionMode = qtyping.OpExecutionMode _OpName = qtyping.TFLOperationName @@ -46,13 +47,17 @@ def _get_dummy_data( def _get_calibration_data( num_samples: int = 128, dtype: np.dtype = np.float32 ) -> list[dict[str, Any]]: - return _get_dummy_data(num_samples, dtype) + calibration_samples = _get_dummy_data(num_samples, dtype) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data( num_samples: int = 8, dtype: np.dtype = np.float32 ) -> list[dict[str, Any]]: - return {'serving_default': _get_dummy_data(num_samples, dtype)} + return _get_calibration_data(num_samples, dtype) class FloatTransposeTest(parameterized.TestCase): diff --git a/ai_edge_quantizer/tests/mnist_test.py b/ai_edge_quantizer/tests/mnist_test.py index c3b447e..cf65e6d 100644 --- a/ai_edge_quantizer/tests/mnist_test.py +++ b/ai_edge_quantizer/tests/mnist_test.py @@ -22,6 +22,7 @@ from ai_edge_quantizer import qtyping from ai_edge_quantizer import quantizer from ai_edge_quantizer.utils import test_utils +from ai_edge_quantizer.utils import tfl_interpreter_utils _ComputePrecision = qtyping.ComputePrecision _OpName = qtyping.TFLOperationName @@ -41,7 +42,11 @@ def _get_dummy_data(num_samples): def _get_calibration_data(num_samples: int = 256): - return _get_dummy_data(num_samples) + calibration_samples 
= _get_dummy_data(num_samples) + calibration_data = { + tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples, + } + return calibration_data def _get_test_data(num_samples: int = 8): diff --git a/ai_edge_quantizer/tests/models/toy_model_with_kv_cache_multi_signature.tflite b/ai_edge_quantizer/tests/models/toy_model_with_kv_cache_multi_signature.tflite new file mode 100644 index 0000000..1ddd756 Binary files /dev/null and b/ai_edge_quantizer/tests/models/toy_model_with_kv_cache_multi_signature.tflite differ diff --git a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py index 13d0733..9c5d17c 100644 --- a/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py +++ b/ai_edge_quantizer/utils/tfl_flatbuffer_utils.py @@ -198,7 +198,7 @@ def buffer_to_tensors(flatbuffer_model: Any) -> dict[int, list[Any]]: the buffer """ buffer_to_tensor_map = {} - for subgraph in flatbuffer_model.subgraphs: + for subgraph in flatbuffer_model.subgraphs: for op in subgraph.operators: for tensor in parse_op_tensors(op, subgraph.tensors): if tensor.buffer not in buffer_to_tensor_map: