From 7fe576e4303a4c99632add4dce4a57da9ad498f8 Mon Sep 17 00:00:00 2001
From: Maria Lyubimtseva
Date: Thu, 7 Nov 2024 13:00:51 -0800
Subject: [PATCH] Ignore min_weight_elements during policy check

PiperOrigin-RevId: 694220477
---
 .../naive_min_max_quantize.py            |  6 +++
 .../conv2d_test.py                       |  9 +++-
 .../conv2d_transpose_test.py             |  7 ++-
 .../depthwise_conv2d_test.py             |  7 ++-
 .../fully_connected_test.py              |  7 ++-
 .../naive_min_max_quantize_test.py       | 22 ++++++++
 .../utils/min_max_quantize_utils.py      | 11 +++-
 .../utils/min_max_quantize_utils_test.py | 13 +++++
 ai_edge_quantizer/recipe.py              | 28 ++++++++++
 ai_edge_quantizer/recipe_test.py         | 53 +++++++++++++------
 10 files changed, 141 insertions(+), 22 deletions(-)

diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
index 24dcd59..5104424 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize.py
@@ -56,6 +56,12 @@ def check_op_quantization_config(
         " only), please set algorithm key as 'float_casting'."
     )
 
+  if op_quant_config.min_weight_elements < 0:
+    raise ValueError(
+        f"min_weight_elements must be non-negative for op: {op_name} with"
+        f" config: {op_quant_config}."
+    )
+
   if op_quant_config.compute_precision in [
       _ComputePrecision.INTEGER,
       _ComputePrecision.FLOAT,
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py
index e18e7b0..ef4928f 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_test.py
@@ -174,12 +174,17 @@ def test_materialize_srq_conv2d_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
-  def test_materialize_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32_(
+  def test_materialize_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
   ):
     self._test_materialize_fn_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py
index ee38940..885c8ff 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/conv2d_transpose_test.py
@@ -179,10 +179,15 @@ def test_materialize_srq_conv2d_transpose_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def test_materialize_conv2d_transpose_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py
index 4f0e90b..e94898e 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/depthwise_conv2d_test.py
@@ -166,10 +166,15 @@ def test_materialize_srq_depthwise_conv2d_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def _test_materialize_depthwise_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py
index 52ecf5d..921eb5b 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_op_tests/fully_connected_test.py
@@ -133,10 +133,15 @@ def test_materialize_fully_connected_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def test_materialize_fully_connected_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
diff --git a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
index 227a7e6..20bd2ce 100644
--- a/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
+++ b/ai_edge_quantizer/algorithms/uniform_quantize/naive_min_max_quantize_test.py
@@ -19,6 +19,7 @@
 import numpy as np
 from tensorflow.python.platform import googletest
+from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.utils import test_utils
 
@@ -157,6 +158,27 @@ def test_min_max_calibrate(self):
     self.assertNotIn("arith.constant1", op_qsvs)
     self.assertNotIn("arith.constant2", op_qsvs)
 
+  def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
+      self,
+  ):
+    op_quant_config = qtyping.OpQuantizationConfig(
+        weight_tensor_config=_TensorQuantConfig(
+            num_bits=8,
+            granularity=qtyping.QuantGranularity.CHANNELWISE,
+        ),
+        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
+        min_weight_elements=-1,
+    )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError,
+        lambda err: "min_weight_elements must be non-negative" in str(err),
+    ):
+      naive_min_max_quantize.check_op_quantization_config(
+          _TFLOpName.FULLY_CONNECTED,
+          op_quant_config,
+          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+      )
+
 
 if __name__ == "__main__":
   googletest.main()
diff --git a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py b/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
index 25f39b7..4345f28 100644
--- a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
+++ b/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
@@ -16,6 +16,7 @@
 """Utils for min/max based quantization."""
 
 from collections.abc import Sequence
+import dataclasses
 import enum
 from typing import Any, Optional
 import numpy as np
@@ -103,7 +104,15 @@ def check_if_valid_op_config(
         f"No policy was specified for op: {op_name} with config:"
         f" {op_quant_config}."
     )
-  elif op_quant_config not in config_check_policy[op_name]:
+  # The config_check_policy contains all possible valid configs, except for
+  # variations in the min_weight_elements field (it's set to 0 for all of them).
+  # min_weight_elements has to be ignored during policy check here because it
+  # can be any non-negative integer, which means we can't list all possible
+  # values in the policy.
+  elif (
+      dataclasses.replace(op_quant_config, min_weight_elements=0)
+      not in config_check_policy[op_name]
+  ):
     error_msg = (
         f"Quantization config for op: {op_name} with config:"
         f" {op_quant_config} was not found in the policy."
diff --git a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py b/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
index f400102..a88f65f 100644
--- a/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
+++ b/ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
@@ -211,6 +211,19 @@ def test_check_drq_config_asymmetric_weights_raise_error(self, op_name):
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
+  def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
+    op_quant_config = _OpQuantConfig(
+        weight_tensor_config=_TensorQuantConfig(
+            num_bits=8,
+            granularity=qtyping.QuantGranularity.CHANNELWISE,
+        ),
+        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+        min_weight_elements=100,
+    )
+    min_max_quantize_utils.check_if_valid_op_config(
+        _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+    )
+
   @parameterized.product(
       op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
       act_num_bits=(8, 16),
diff --git a/ai_edge_quantizer/recipe.py b/ai_edge_quantizer/recipe.py
index 69df3f8..f979d83 100644
--- a/ai_edge_quantizer/recipe.py
+++ b/ai_edge_quantizer/recipe.py
@@ -37,3 +37,31 @@ def dynamic_wi8_afp32():
           },
       })
   ]
+
+
+def dynamic_legacy_wi8_afp32():
+  """Returns a dynamic quantization legacy recipe with int8 weights and float32 activation.
+
+  The difference between this and dynamic_wi8_afp32 is that this recipe sets
+  min_weight_elements to 1024 to match the old quantizer behavior.
+  """
+  return [
+      dict({
+          'regex': '.*',
+          'operation': '*',
+          'algorithm_key': 'min_max_uniform_quantize',
+          'op_config': {
+              'weight_tensor_config': {
+                  'num_bits': 8,
+                  'symmetric': True,
+                  'granularity': 'CHANNELWISE',
+                  'dtype': 'INT',
+                  'block_size': 0,
+              },
+              'compute_precision': 'INTEGER',
+              'explicit_dequantize': False,
+              'skip_checks': False,
+              'min_weight_elements': 1024,
+          },
+      })
+  ]
diff --git a/ai_edge_quantizer/recipe_test.py b/ai_edge_quantizer/recipe_test.py
index 77e6b7a..4cd8617 100644
--- a/ai_edge_quantizer/recipe_test.py
+++ b/ai_edge_quantizer/recipe_test.py
@@ -15,15 +15,18 @@
 
 import os
 
+from absl.testing import parameterized
+
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import quantizer
 from ai_edge_quantizer import recipe
 from ai_edge_quantizer.utils import test_utils
 
+
 _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
 
 
-class RecipeTest(googletest.TestCase):
+class RecipeTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
@@ -31,33 +34,51 @@ def setUp(self):
         _TEST_DATA_PREFIX_PATH,
         'tests/models/single_conv2d_transpose_bias.tflite',
     )
-    self._test_json_recipe_path = os.path.join(
-        _TEST_DATA_PREFIX_PATH,
-        'recipes/dynamic_wi8_afp32_recipe.json',
-    )
 
-  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
+  def _quantize_with_recipe_func(self, recipe_func):
     qt = quantizer.Quantizer(self._test_model_path)
-    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())
+    qt.load_quantization_recipe(recipe_func())
     self.assertIsNone(qt._result.quantized_model)
     quant_result = qt.quantize()
     self.assertIsNotNone(quant_result.quantized_model)
+    return quant_result
+
+  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
     self.assertLess(
         len(quant_result.quantized_model),
         os.path.getsize(self._test_model_path),
     )
 
-  def test_dynamic_wi8_afp32_func_and_json_matches(self):
-    # Quantize with dynamic_wi8_afp32() from recipe module.
-    qt_func = quantizer.Quantizer(self._test_model_path)
-    qt_func.load_quantization_recipe(recipe.dynamic_wi8_afp32())
-    self.assertIsNone(qt_func._result.quantized_model)
-    quant_result_from_func = qt_func.quantize()
-    self.assertIsNotNone(quant_result_from_func.quantized_model)
+  def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_legacy_wi8_afp32
+    )
+    self.assertLen(
+        quant_result.quantized_model,
+        os.path.getsize(self._test_model_path),
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='dynamic_wi8_afp32',
+          recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_wi8_afp32,
+      ),
+      dict(
+          testcase_name='dynamic_legacy_wi8_afp32',
+          recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_legacy_wi8_afp32,
+      ),
+  )
+  def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
+    # Quantize with recipe from function in recipe module.
+    quant_result_from_func = self._quantize_with_recipe_func(recipe_func)
 
-    # Quantize with dynamic_wi8_afp32_recipe.json.
+    # Quantize with recipe from json file.
     qt_json = quantizer.Quantizer(self._test_model_path)
-    qt_json.load_quantization_recipe(self._test_json_recipe_path)
+    json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
+    qt_json.load_quantization_recipe(json_recipe_path)
     quant_result_from_json = qt_json.quantize()
     self.assertIsNotNone(quant_result_from_json.quantized_model)
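
Usage note (illustrative sketch, not part of the patch): mirroring the calls exercised in recipe_test.py above, the new legacy recipe would be applied to a model roughly as follows; the model path is a placeholder.

    # Sketch only, assuming a float TFLite model at a placeholder path.
    from ai_edge_quantizer import quantizer
    from ai_edge_quantizer import recipe

    qt = quantizer.Quantizer('path/to/model.tflite')
    # min_weight_elements is 1024 in this recipe, so weight tensors with fewer
    # than 1024 elements are left in float, matching the old quantizer behavior.
    qt.load_quantization_recipe(recipe.dynamic_legacy_wi8_afp32())
    result = qt.quantize()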