Skip to content

Commit

Permalink
Add dynamic_legacy_wi8_afp32 recipe to recipe module
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694578492
  • Loading branch information
marialyu authored and copybara-github committed Nov 8, 2024
1 parent 0abada8 commit 4cb52d2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 16 deletions.
28 changes: 28 additions & 0 deletions ai_edge_quantizer/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,31 @@ def dynamic_wi8_afp32():
},
})
]


def dynamic_legacy_wi8_afp32():
  """Returns the legacy dynamic recipe: int8 weights, float32 activations.

  The difference between this and dynamic_wi8_afp32 is that this recipe sets
  min_weight_elements to 1024 to match the old quantizer behavior.

  Returns:
    A newly built list holding a single recipe entry (a plain dict). A fresh
    object is created on every call, so callers may mutate the result freely.
  """
  weight_tensor_config = {
      'num_bits': 8,
      'symmetric': True,
      'granularity': 'CHANNELWISE',
      'dtype': 'INT',
      'block_size': 0,
  }
  op_config = {
      'weight_tensor_config': weight_tensor_config,
      'compute_precision': 'INTEGER',
      'explicit_dequantize': False,
      'skip_checks': False,
      # The one setting that makes this recipe "legacy": it mirrors the old
      # quantizer's threshold.
      'min_weight_elements': 1024,
  }
  # NOTE(review): regex '.*' with operation '*' presumably targets every op in
  # the model - confirm against the recipe schema.
  return [{
      'regex': '.*',
      'operation': '*',
      'algorithm_key': 'min_max_uniform_quantize',
      'op_config': op_config,
  }]
53 changes: 37 additions & 16 deletions ai_edge_quantizer/recipe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,49 +15,70 @@

import os

from absl.testing import parameterized

from tensorflow.python.platform import googletest
from ai_edge_quantizer import quantizer
from ai_edge_quantizer import recipe
from ai_edge_quantizer.utils import test_utils


_TEST_DATA_PREFIX_PATH = test_utils.get_path_to_datafile('')


class RecipeTest(googletest.TestCase):
class RecipeTest(parameterized.TestCase):

def setUp(self):
  """Resolves the test model and JSON recipe paths under the test data dir."""
  super().setUp()
  model_rel_path = 'tests/models/single_conv2d_transpose_bias.tflite'
  recipe_rel_path = 'recipes/dynamic_wi8_afp32_recipe.json'
  # Both fixtures are addressed relative to the shared test data prefix.
  self._test_model_path = os.path.join(_TEST_DATA_PREFIX_PATH, model_rel_path)
  self._test_json_recipe_path = os.path.join(
      _TEST_DATA_PREFIX_PATH, recipe_rel_path
  )

def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
def _quantize_with_recipe_func(self, recipe_func):
  """Quantizes the test model with the recipe produced by `recipe_func`.

  Only the recipe returned by `recipe_func` is loaded; no hard-coded recipe is
  loaded beforehand, so the result reflects the recipe under test alone.

  Args:
    recipe_func: Zero-argument callable whose return value is passed to
      `Quantizer.load_quantization_recipe`.

  Returns:
    The object returned by `Quantizer.quantize()`; its `quantized_model` has
    been asserted to be non-None.
  """
  qt = quantizer.Quantizer(self._test_model_path)
  qt.load_quantization_recipe(recipe_func())
  # Loading a recipe alone must not produce a quantized model.
  self.assertIsNone(qt._result.quantized_model)
  quant_result = qt.quantize()
  self.assertIsNotNone(quant_result.quantized_model)
  return quant_result

def test_quantization_from_dynamic_wi8_afp32_func_succeeds(self):
  """Checks that the dynamic_wi8_afp32 recipe yields a smaller model."""
  result = self._quantize_with_recipe_func(recipe.dynamic_wi8_afp32)
  quantized_size = len(result.quantized_model)
  # Compared against the on-disk size of the source model file.
  source_size = os.path.getsize(self._test_model_path)
  self.assertLess(quantized_size, source_size)

def test_dynamic_wi8_afp32_func_and_json_matches(self):
# Quantize with dynamic_wi8_afp32() from recipe module.
qt_func = quantizer.Quantizer(self._test_model_path)
qt_func.load_quantization_recipe(recipe.dynamic_wi8_afp32())
self.assertIsNone(qt_func._result.quantized_model)
quant_result_from_func = qt_func.quantize()
self.assertIsNotNone(quant_result_from_func.quantized_model)
def test_quantization_from_dynamic_legacy_wi8_afp32_func_succeeds(self):
  """Checks that the legacy recipe leaves the model size unchanged.

  The quantized model must have exactly as many bytes as the source model file.
  """
  # NOTE(review): presumably the test model's weights fall below the recipe's
  # min_weight_elements threshold, so nothing gets quantized - confirm.
  result = self._quantize_with_recipe_func(recipe.dynamic_legacy_wi8_afp32)
  source_size = os.path.getsize(self._test_model_path)
  self.assertLen(result.quantized_model, source_size)

@parameterized.named_parameters(
dict(
testcase_name='dynamic_wi8_afp32',
recipe_json_path='recipes/dynamic_wi8_afp32_recipe.json',
recipe_func=recipe.dynamic_wi8_afp32,
),
dict(
testcase_name='dynamic_legacy_wi8_afp32',
recipe_json_path='recipes/dynamic_legacy_wi8_afp32_recipe.json',
recipe_func=recipe.dynamic_legacy_wi8_afp32,
),
)
def test_recipe_func_and_json_matches(self, recipe_json_path, recipe_func):
# Quantize with recipe from function in recipe module.
quant_result_from_func = self._quantize_with_recipe_func(recipe_func)

# Quantize with dynamic_wi8_afp32_recipe.json.
# Quantize with recipe from json file.
qt_json = quantizer.Quantizer(self._test_model_path)
qt_json.load_quantization_recipe(self._test_json_recipe_path)
json_recipe_path = os.path.join(_TEST_DATA_PREFIX_PATH, recipe_json_path)
qt_json.load_quantization_recipe(json_recipe_path)
quant_result_from_json = qt_json.quantize()
self.assertIsNotNone(quant_result_from_json.quantized_model)

Expand Down

0 comments on commit 4cb52d2

Please sign in to comment.