Skip to content

Commit

Permalink
Implement multi-signature calibration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682034001
  • Loading branch information
v-dziuba authored and copybara-github committed Oct 25, 2024
1 parent 1d6688a commit f4cae14
Show file tree
Hide file tree
Showing 15 changed files with 285 additions and 60 deletions.
52 changes: 28 additions & 24 deletions ai_edge_quantizer/calibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from collections.abc import Callable, Iterable
import copy
from typing import Any, Optional, Union
from typing import Any, Union

from absl import logging
import numpy as np
Expand Down Expand Up @@ -62,9 +62,8 @@ def __init__(
# TODO(b/330740605)- Collect multiple QSVs in one run to save compute.
def calibrate(
self,
calibration_dataset: Iterable[_SignatureInput],
calibration_dataset: dict[str, Iterable[_SignatureInput]],
model_recipe_manager: recipe_manager.RecipeManager,
signature_key: Optional[str] = None,
cache_output: bool = False,
qsv_update_func: Callable[
[qtyping.QSV, qtyping.QSV],
Expand All @@ -87,13 +86,10 @@ def calibrate(
6. Start another round of calibration.
Args:
calibration_dataset: A list of input data for calibration for the given
model signature.
calibration_dataset: A dictionary of input data for calibration for the
given model signature.
model_recipe_manager: A RecipeManager object that contains the
quantization recipe.
signature_key: The signature key to be used for invoking the models. If
the model doesn't have a signature key (or only has one ), this can be
set to None.
cache_output: Whether to cache the output of the model during the
calibration process. This is useful if there are dependencies between
signatures/models (e.g., decode requires encode output).
Expand All @@ -111,23 +107,31 @@ def calibrate(
)

# TODO: b/329322226 - Enable parrallel calibration.
for data in calibration_dataset:
# Initialize tensor names that are updated in this round of calibration.
updated_tensor_names = set()

# Step1: run tfl interpreter to get tensor content.
signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
self._tfl_interpreter, data, signature_key
)
if cache_output:
self._cached_output.append(signature_output)
self._tensor_content_map = (
tfl_interpreter_utils.get_tensor_name_to_content_map(
self._tfl_interpreter
)
for signature_key, dataset in calibration_dataset.items():
# Step0: get subgraph index.
subgraph_idx = tfl_interpreter_utils.get_signature_main_subgraph_index(
self._tfl_interpreter, signature_key
)
# Step2: go through each op to update quantization statistic values.
for subgraph in self._flatbuffer_model.subgraphs:

for data in dataset:
# Initialize tensor names that are updated in this round of calibration.
updated_tensor_names = set()

# Step1: run tfl interpreter on subgraph to get tensor content.
signature_output = tfl_interpreter_utils.invoke_interpreter_signature(
self._tfl_interpreter, data, signature_key
)
if cache_output:
self._cached_output.append(signature_output)
self._tensor_content_map.update(
tfl_interpreter_utils.get_tensor_name_to_content_map(
self._tfl_interpreter, subgraph_idx
)
)

# Step2: go through each op in subgraph to update quantization
# statistic values.
subgraph = self._flatbuffer_model.subgraphs[subgraph_idx]
graph_info = qtyping.GraphInfo(
subgraph.tensors, self._flatbuffer_model.buffers
)
Expand Down
64 changes: 62 additions & 2 deletions ai_edge_quantizer/calibrator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

"""Tests for calibrator."""

from collections.abc import Generator
import os
from typing import Any, Dict

import numpy as np

Expand All @@ -24,6 +26,7 @@
from ai_edge_quantizer import qtyping
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils

_ComputePrecision = qtyping.ComputePrecision
_AlgorithmName = recipe_manager.AlgorithmName
Expand All @@ -33,6 +36,8 @@

TEST_MIN_VAL, TEST_MAX_VAL = -1, 1

_RNG = np.random.default_rng(66)


def _representative_dataset_gen(size=(1, 8), num_samples=10):
for _ in range(num_samples):
Expand All @@ -44,6 +49,14 @@ def _representative_dataset_gen(size=(1, 8), num_samples=10):
yield {"input_1": vals}


def _get_calibration_data(dataset_gen: Generator[Dict[str, Any], Any, None]):
calibration_samples = [sample for sample in dataset_gen]
calibration_data = {
tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples,
}
return calibration_data


def _add_default_int8xint8_integer_recipe(recipe_manager_object):
recipe_manager_object.add_quantization_config(
regex=".*",
Expand All @@ -69,7 +82,8 @@ def setUp(self):
)
self._calibrator = calibrator.Calibrator(self._test_model_path)
self._recipe_manager = recipe_manager.RecipeManager()
self._representative_dataset = _representative_dataset_gen()
dataset_gen = _representative_dataset_gen()
self._representative_dataset = _get_calibration_data(dataset_gen)

def test_calibrator_state_manipulation(self):
# load/get qsvs
Expand Down Expand Up @@ -204,8 +218,9 @@ def test_calibrate_unsupported_ops_success(self):
)
test_calibrator = calibrator.Calibrator(test_model_path)
_add_default_int8xint8_integer_recipe(self._recipe_manager)
dataset_gen = _representative_dataset_gen(size=(3, 4, 4, 1))
test_calibrator.calibrate(
_representative_dataset_gen(size=(3, 4, 4, 1)),
_get_calibration_data(dataset_gen),
self._recipe_manager,
cache_output=True,
)
Expand All @@ -231,5 +246,50 @@ def test_check_is_float_model_raises_error_when_model_is_quantized(self):
_ = calibrator.Calibrator(test_model_path)


class CalibratorToyGemma2Test(googletest.TestCase):

def setUp(self):
super().setUp()
np.random.seed(0)

self._test_model_path = os.path.join(
TEST_DATA_PREFIX_PATH,
"tests/models/toy_model_with_kv_cache_multi_signature.tflite",
)

self._toy_gemma2_calibration_dataset = {
"signature_1": [{
"cache_0": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32),
"cache_1": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32),
"positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
np.int32
),
"tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
np.int32
),
}],
"signature_2": [{
"cache_0": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32),
"cache_1": _RNG.random(size=(1, 100, 4, 4)).astype(np.float32),
"positions": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
np.int32
),
"tokens": _RNG.integers(low=0, high=10, size=(1, 100)).astype(
np.int32
),
}],
}

def test_toy_gemma2_calibration_success(self):
calib = calibrator.Calibrator(self._test_model_path)
recipe_mngr = recipe_manager.RecipeManager()
_add_default_int8xint8_integer_recipe(recipe_mngr)
calib.calibrate(
self._toy_gemma2_calibration_dataset,
model_recipe_manager=recipe_mngr,
)
self.assertLen(calib.get_model_qsvs(), 260)


if __name__ == "__main__":
googletest.main()
21 changes: 17 additions & 4 deletions ai_edge_quantizer/params_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

"""Tests for params_generator."""

from collections.abc import Generator
import os
from typing import Any, Dict

from absl.testing import parameterized
import numpy as np
Expand All @@ -27,6 +29,7 @@
from ai_edge_quantizer import recipe_manager
from ai_edge_quantizer.utils import test_utils
from ai_edge_quantizer.utils import tfl_flatbuffer_utils
from ai_edge_quantizer.utils import tfl_interpreter_utils


_ComputePrecision = qtyping.ComputePrecision
Expand All @@ -51,6 +54,14 @@ def _int_transpose_model_representative_dataset_gen(num_samples=5):
return data


def _get_calibration_data(dataset_gen: Generator[Dict[str, Any], Any, None]):
calibration_samples = [sample for sample in dataset_gen]
calibration_data = {
tfl_interpreter_utils.DEFAULT_SIGNATURE_KEY: calibration_samples,
}
return calibration_data


class ParamsGeneratorTest(parameterized.TestCase):

def setUp(self):
Expand Down Expand Up @@ -444,9 +455,10 @@ def test_generate_config_int8xint8_single_fc(

# Calibrate then quantize
model_calibrator = calibrator.Calibrator(single_fc_model_path)
model_calibrator.calibrate(
_single_fc_model_representative_dataset_gen(), self._recipe_manager
calibration_data = _get_calibration_data(
_single_fc_model_representative_dataset_gen()
)
model_calibrator.calibrate(calibration_data, self._recipe_manager)
model_qsvs = model_calibrator.get_model_qsvs()
quant_params = params_generator_single_fc.generate_quantization_parameters(
self._recipe_manager,
Expand Down Expand Up @@ -905,9 +917,10 @@ def test_quantize_integer_input_output(self):

# Calibrate then quantize.
model_calibrator = calibrator.Calibrator(model_path)
model_calibrator.calibrate(
_int_transpose_model_representative_dataset_gen(), self._recipe_manager
calibration_data = _get_calibration_data(
_int_transpose_model_representative_dataset_gen()
)
model_calibrator.calibrate(calibration_data, self._recipe_manager)
model_qsvs = model_calibrator.get_model_qsvs()
quant_params = pg.generate_quantization_parameters(
self._recipe_manager,
Expand Down
7 changes: 2 additions & 5 deletions ai_edge_quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,13 @@ def need_calibration(self) -> bool:

def calibrate(
self,
calibration_data: Iterable[_SignatureInput],
signature_key: Optional[str] = None,
calibration_data: dict[str, Iterable[_SignatureInput]],
previous_calibration_result: Optional[_CalibrationResult] = None,
) -> _CalibrationResult:
"""Calibrates the float model (required by static range quantization).
Args:
calibration_data: Calibration data for a model signature.
signature_key: The signature key to be used for invoking the models. If
the model doesn't have a signature key, this can be set to None.
previous_calibration_result: Previous calibration result to be loaded. The
calibration process will be resumed from the previous result.
Expand All @@ -235,7 +232,7 @@ def calibrate(
calib = calibrator.Calibrator(self.float_model)
if previous_calibration_result is not None:
calib.load_model_qsvs(previous_calibration_result)
calib.calibrate(calibration_data, self._recipe_manager, signature_key)
calib.calibrate(calibration_data, self._recipe_manager)
return calib.get_model_qsvs()

def quantize(
Expand Down
Loading

0 comments on commit f4cae14

Please sign in to comment.