Ignore min_weight_elements during policy check
PiperOrigin-RevId: 694220477
marialyu authored and copybara-github committed Nov 8, 2024
1 parent 0abada8 commit 7fe576e
Showing 10 changed files with 141 additions and 22 deletions.
@@ -56,6 +56,12 @@ def check_op_quantization_config(
         " only), please set algorithm key as 'float_casting'."
     )
 
+  if op_quant_config.min_weight_elements < 0:
+    raise ValueError(
+        f"min_weight_elements must be non-negative for op: {op_name} with"
+        f" config: {op_quant_config}."
+    )
+
   if op_quant_config.compute_precision in [
       _ComputePrecision.INTEGER,
       _ComputePrecision.FLOAT,
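For context, min_weight_elements is the element-count threshold below which a weight tensor is kept in float rather than quantized; the tests below exercise thresholds of 0 and 1, and the legacy recipe further down uses 1024. A minimal sketch of that gating (the helper should_quantize_weight is hypothetical, not part of the quantizer's API, and whether the comparison is strict or non-strict is an assumption here):

```python
import numpy as np


def should_quantize_weight(weight: np.ndarray, min_weight_elements: int) -> bool:
  """Hypothetical helper: quantize only weights with enough elements."""
  if min_weight_elements < 0:
    raise ValueError("min_weight_elements must be non-negative.")
  # Assumption: tensors at or above the threshold are quantized.
  return weight.size >= min_weight_elements


# A 3x3x8x16 conv kernel (1152 elements) passes the legacy threshold of 1024;
# a small 1x8 weight does not and would stay in float32.
assert should_quantize_weight(np.zeros((3, 3, 8, 16), np.float32), 1024)
assert not should_quantize_weight(np.zeros((1, 8), np.float32), 1024)
```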
@@ -174,12 +174,17 @@ def test_materialize_srq_conv2d_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
-  def test_materialize_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32_(
+  def test_materialize_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
   ):
     self._test_materialize_fn_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
@@ -179,10 +179,15 @@ def test_materialize_srq_conv2d_transpose_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def test_materialize_conv2d_transpose_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
@@ -166,10 +166,15 @@ def test_materialize_srq_depthwise_conv2d_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def _test_materialize_depthwise_conv2d_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
@@ -133,10 +133,15 @@ def test_materialize_fully_connected_succeeds(
           expect_weights_quantized=False,
       ),
       dict(
-          testcase_name="weights_are_quantized",
+          testcase_name="weights_are_quantized_for_min_weight_elements_0",
           min_weight_elements=0,
           expect_weights_quantized=True,
       ),
+      dict(
+          testcase_name="weights_are_quantized_for_min_weight_elements_1",
+          min_weight_elements=1,
+          expect_weights_quantized=True,
+      ),
   )
   def test_materialize_fully_connected_quantizes_weights_larger_than_min_weight_elements_for_w8_afp32(
       self, min_weight_elements, expect_weights_quantized
@@ -19,6 +19,7 @@
 import numpy as np
 
 from tensorflow.python.platform import googletest
+from ai_edge_quantizer import default_policy
 from ai_edge_quantizer import qtyping
 from ai_edge_quantizer.algorithms.uniform_quantize import naive_min_max_quantize
 from ai_edge_quantizer.utils import test_utils
@@ -157,6 +158,27 @@ def test_min_max_calibrate(self):
     self.assertNotIn("arith.constant1", op_qsvs)
     self.assertNotIn("arith.constant2", op_qsvs)
 
+  def test_check_op_quantization_config_with_negative_min_weight_elements_raises_error(
+      self,
+  ):
+    op_quant_config = qtyping.OpQuantizationConfig(
+        weight_tensor_config=_TensorQuantConfig(
+            num_bits=8,
+            granularity=qtyping.QuantGranularity.CHANNELWISE,
+        ),
+        compute_precision=qtyping.ComputePrecision.INTEGER,  # DRQ.
+        min_weight_elements=-1,
+    )
+    with self.assertRaisesWithPredicateMatch(
+        ValueError,
+        lambda err: "min_weight_elements must be non-negative" in str(err),
+    ):
+      naive_min_max_quantize.check_op_quantization_config(
+          _TFLOpName.FULLY_CONNECTED,
+          op_quant_config,
+          default_policy.DEFAULT_CONFIG_CHECK_POLICY,
+      )
+
 
 if __name__ == "__main__":
   googletest.main()
11 changes: 10 additions & 1 deletion ai_edge_quantizer/algorithms/utils/min_max_quantize_utils.py
@@ -16,6 +16,7 @@
 """Utils for min/max based quantization."""
 
 from collections.abc import Sequence
+import dataclasses
 import enum
 from typing import Any, Optional
 import numpy as np
@@ -103,7 +104,15 @@ def check_if_valid_op_config(
         f"No policy was specified for op: {op_name} with config:"
         f" {op_quant_config}."
     )
-  elif op_quant_config not in config_check_policy[op_name]:
+  # The config_check_policy contains all possible valid configs, except for
+  # variations in the min_weight_elements field (it's set to 0 for all of
+  # them). min_weight_elements has to be ignored during policy check here
+  # because it can be any non-negative integer, which means we can't list
+  # all possible values in the policy.
+  elif (
+      dataclasses.replace(op_quant_config, min_weight_elements=0)
+      not in config_check_policy[op_name]
+  ):
     error_msg = (
         f"Quantization config for op: {op_name} with config:"
         f" {op_quant_config} was not found in the policy."
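To see why the replace-then-look-up pattern above works, here is a minimal self-contained sketch. The simplified OpQuantizationConfig stand-in below is illustrative only; the real class lives in qtyping and has more fields:

```python
import dataclasses


@dataclasses.dataclass(frozen=True)  # frozen => hashable, usable in sets.
class OpQuantizationConfig:
  """Simplified stand-in for qtyping.OpQuantizationConfig."""

  compute_precision: str = "INTEGER"
  min_weight_elements: int = 0


# Policy entries are all stored with min_weight_elements=0.
policy = {OpQuantizationConfig("INTEGER", min_weight_elements=0)}

user_config = OpQuantizationConfig("INTEGER", min_weight_elements=1024)

# dataclasses.replace returns a copy with the named field overridden, so the
# membership test ignores whichever threshold the user picked.
assert user_config not in policy
assert dataclasses.replace(user_config, min_weight_elements=0) in policy
```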
13 changes: 13 additions & 0 deletions ai_edge_quantizer/algorithms/utils/min_max_quantize_utils_test.py
@@ -211,6 +211,19 @@ def test_check_drq_config_asymmetric_weights_raise_error(self, op_name):
         op_name, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
     )
 
+  def test_check_drq_config_with_non_default_min_weight_elements_succeeds(self):
+    op_quant_config = _OpQuantConfig(
+        weight_tensor_config=_TensorQuantConfig(
+            num_bits=8,
+            granularity=qtyping.QuantGranularity.CHANNELWISE,
+        ),
+        compute_precision=_ComputePrecision.INTEGER,  # DRQ.
+        min_weight_elements=100,
+    )
+    min_max_quantize_utils.check_if_valid_op_config(
+        _TFLOpName.CONV_2D, op_quant_config, _DEFAULT_CONFIG_CHECK_POLICY
+    )
+
   @parameterized.product(
       op_name=(_TFLOpName.FULLY_CONNECTED, _TFLOpName.CONV_2D),
       act_num_bits=(8, 16),
28 changes: 28 additions & 0 deletions ai_edge_quantizer/recipe.py
@@ -37,3 +37,31 @@ def dynamic_wi8_afp32():
           },
       })
   ]
+
+
+def dynamic_legacy_wi8_afp32():
+  """Returns a dynamic quantization legacy recipe with int8 weights and float32 activation.
+
+  The difference between this and dynamic_wi8_afp32 is that this recipe sets
+  min_weight_elements to 1024 to match the old quantizer behavior.
+  """
+  return [
+      dict({
+          'regex': '.*',
+          'operation': '*',
+          'algorithm_key': 'min_max_uniform_quantize',
+          'op_config': {
+              'weight_tensor_config': {
+                  'num_bits': 8,
+                  'symmetric': True,
+                  'granularity': 'CHANNELWISE',
+                  'dtype': 'INT',
+                  'block_size': 0,
+              },
+              'compute_precision': 'INTEGER',
+              'explicit_dequantize': False,
+              'skip_checks': False,
+              'min_weight_elements': 1024,
+          },
+      })
+  ]
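A usage sketch for the new recipe, mirroring the pattern in recipe_test.py below (the model path here is a placeholder):

```python
from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe

# Placeholder path to a float TFLite model.
qt = quantizer.Quantizer('model.tflite')
qt.load_quantization_recipe(recipe.dynamic_legacy_wi8_afp32())
result = qt.quantize()
# Weight tensors with fewer than 1024 elements stay in float32, matching the
# old quantizer's behavior; larger ones get symmetric int8 channelwise weights.
```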
53 changes: 37 additions & 16 deletions ai_edge_quantizer/recipe_test.py
@@ -15,49 +15,70 @@
 
 import os
 
+from absl.testing import parameterized
+
 from tensorflow.python.platform import googletest
 from ai_edge_quantizer import quantizer
 from ai_edge_quantizer import recipe
 from ai_edge_quantizer.utils import test_utils
 
 
 _TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')
 
 
-class RecipeTest(googletest.TestCase):
+class RecipeTest(parameterized.TestCase):
 
   def setUp(self):
     super().setUp()
     self._test_model_path = os.path.join(
         _TEST_DATA_PREFIX_PATH,
         'tests/models/single_conv2d_transpose_bias.tflite',
     )
-    self._test_json_recipe_path = os.path.join(
-        _TEST_DATA_PREFIX_PATH,
-        'recipes/dynamic_wi8_afp32_recipe.json',
-    )
 
-  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
+  def _quantize_with_recipe_func(self, recipe_func):
     qt = quantizer.Quantizer(self._test_model_path)
-    qt.load_quantization_recipe(recipe.dynamic_wi8_afp32())
+    qt.load_quantization_recipe(recipe_func())
     self.assertIsNone(qt._result.quantized_model)
     quant_result = qt.quantize()
     self.assertIsNotNone(quant_result.quantized_model)
+    return quant_result
+
+  def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
     self.assertLess(
         len(quant_result.quantized_model),
         os.path.getsize(self._test_model_path),
     )
 
-  def test_dynamic_wi8_afp32_func_and_json_matches(self):
-    # Quantize with dynamic_wi8_afp32() from recipe module.
-    qt_func = quantizer.Quantizer(self._test_model_path)
-    qt_func.load_quantization_recipe(recipe.dynamic_wi8_afp32())
-    self.assertIsNone(qt_func._result.quantized_model)
-    quant_result_from_func = qt_func.quantize()
-    self.assertIsNotNone(quant_result_from_func.quantized_model)
+  def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
+    quant_result = self._quantize_with_recipe_func(
+        recipe.dynamic_legacy_wi8_afp32
+    )
+    self.assertLen(
+        quant_result.quantized_model,
+        os.path.getsize(self._test_model_path),
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name='dynamic_wi8_afp32',
+          recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_wi8_afp32,
+      ),
+      dict(
+          testcase_name='dynamic_legacy_wi8_afp32',
+          recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
+          recipe_func=recipe.dynamic_legacy_wi8_afp32,
+      ),
+  )
+  def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
+    # Quantize with recipe from function in recipe module.
+    quant_result_from_func = self._quantize_with_recipe_func(recipe_func)
 
-    # Quantize with dynamic_wi8_afp32_recipe.json.
+    # Quantize with recipe from json file.
     qt_json = quantizer.Quantizer(self._test_model_path)
-    qt_json.load_quantization_recipe(self._test_json_recipe_path)
+    json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
+    qt_json.load_quantization_recipe(json_recipe_path)
     quant_result_from_json = qt_json.quantize()
     self.assertIsNotNone(quant_result_from_json.quantized_model)
