diff --git a/ai_edge_torch/generative/README.md b/ai_edge_torch/generative/README.md
index eaeece2a..6b0ca7a4 100644
--- a/ai_edge_torch/generative/README.md
+++ b/ai_edge_torch/generative/README.md
@@ -34,7 +34,7 @@ Quantization can be done via the API exposed in [quantize](quantize/). To apply
 `quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

 ```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
 edge_model = ai_edge_torch.convert(
     model, (tokens, input_pos), quant_config=quant_config
 )
diff --git a/ai_edge_torch/generative/examples/README.md b/ai_edge_torch/generative/examples/README.md
index 87491208..6a85f8ca 100644
--- a/ai_edge_torch/generative/examples/README.md
+++ b/ai_edge_torch/generative/examples/README.md
@@ -72,7 +72,7 @@ To apply quantization, we need to create a configuration that fully expresses ho
 `quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

 ```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
 edge_model = ai_edge_torch.convert(
     model, (tokens, input_pos), quant_config=quant_config
 )
diff --git a/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py b/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
index 1a2c4925..741f50b8 100644
--- a/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/gemma/convert_to_tflite.py
@@ -50,7 +50,7 @@ def convert_gemma_to_tflite(
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
diff --git a/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py b/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
index f6387554..ed5e650d 100644
--- a/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/phi2/convert_to_tflite.py
@@ -48,7 +48,7 @@ def convert_phi2_to_tflite(
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
diff --git a/ai_edge_torch/generative/examples/t5/convert_to_tflite.py b/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
index 3b49a7c3..ef1047f3 100644
--- a/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -101,7 +101,7 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
   # Pad with `-inf` for any tokens indices that aren't desired.
   pad_mask = torch.zeros([seq_len], dtype=torch.float32)
   hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)
-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+  quant_config = quant_recipes.full_int8_dynamic_recipe()

   edge_model = (
       ai_edge_torch.signature(
diff --git a/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb b/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb
index 93036cd9..54fdb9c0 100644
--- a/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb
+++ b/ai_edge_torch/generative/examples/t5/t5_conversion_colab.ipynb
@@ -145,7 +145,7 @@
     "  # Pad with `-inf` for any tokens indices that aren't desired.\n",
     "  pad_mask = torch.zeros([seq_len], dtype=torch.float32)\n",
     "  hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)\n",
-    "  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()\n",
+    "  quant_config = quant_recipes.full_int8_dynamic_recipe()\n",
     "\n",
     "  edge_model = ai_edge_torch.signature(\n",
     "      'encode',\n",
diff --git a/ai_edge_torch/generative/examples/test_models/toy_model.py b/ai_edge_torch/generative/examples/test_models/toy_model.py
index 6e1be6c2..51980b97 100644
--- a/ai_edge_torch/generative/examples/test_models/toy_model.py
+++ b/ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -70,7 +70,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
     return self.lm_head(x)


-def define_and_run() -> None:
+def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
   )
@@ -91,8 +91,11 @@ def define_and_run() -> None:
       pre_ff_norm_config=norm_config,
       final_norm_config=norm_config,
   )
+  return config
+

-  model = ToySingleLayerModel(config)
+def define_and_run() -> None:
+  model = ToySingleLayerModel(get_model_config())
   idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
   input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
   print('running an inference')
diff --git a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
index 21f1ae20..9c8c1ffc 100644
--- a/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
+++ b/ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py
@@ -50,7 +50,7 @@ def convert_tiny_llama_to_tflite(
   decode_token = torch.tensor([[0]], dtype=torch.long)
   decode_input_pos = torch.tensor([0], dtype=torch.int64)

-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   edge_model = (
       ai_edge_torch.signature(
           'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
diff --git a/ai_edge_torch/generative/quantize/README.md b/ai_edge_torch/generative/quantize/README.md
index 10336de5..7d0bb8ec 100644
--- a/ai_edge_torch/generative/quantize/README.md
+++ b/ai_edge_torch/generative/quantize/README.md
@@ -7,7 +7,7 @@ To apply quantization, we need to create a configuration that fully expresses ho
 `quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

 ```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
 edge_model = ai_edge_torch.convert(
     model, (tokens, input_pos), quant_config=quant_config
 )
diff --git a/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py b/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
index 86c78c96..c9b6587d 100644
--- a/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
+++ b/ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py
@@ -26,12 +26,10 @@
 _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

 _DEFAULT_REGEX_STR = '.*'
-_ATTENTION_IDX_REGEX_STR = (
-    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
-)
-_FEEDFORWARD_IDX_REGEX_STR = (
-    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
-)
+_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
+_IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
+_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
+_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
 _EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
 _ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'

@@ -82,27 +80,20 @@ def _set_quant_config(
     layer_recipe: quant_recipe.LayerQuantRecipe,
     regex: str,
 ):
-  support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
-  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
-    support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
-  for op_name in support_op_list:
-    rm.add_quantization_config(
-        regex=regex,
-        operation_name=op_name,
-        op_config=_OpQuantConfig(
-            weight_tensor_config=_TensorQuantConfig(
-                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
-                symmetric=True,
-                channel_wise=_get_channelwise_from_granularity(
-                    layer_recipe.granularity
-                ),
-                dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-            ),
-            execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
-        ),
-        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
-        override_algorithm=True,
-    )
+  rm.add_quantization_config(
+      regex=regex,
+      operation_name=_OpName.ALL_SUPPORTED,
+      op_config=_OpQuantConfig(
+          weight_tensor_config=_TensorQuantConfig(
+              num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+              symmetric=True,
+              channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
+              dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+          ),
+          execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+      ),
+      algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+  )


 def translate_to_ai_edge_recipe(
@@ -119,23 +110,31 @@ def translate_to_ai_edge_recipe(
   if recipe.attention is not None:
     if isinstance(recipe.attention, dict):
       for idx, layer in recipe.attention.items():
-        _set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
+        _set_quant_config(
+            rm,
+            layer,
+            f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_ATTENTION_REGEX_STR}',
+        )
     else:
       _set_quant_config(
           rm,
           recipe.attention,
-          _ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+          f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_ATTENTION_REGEX_STR}',
       )

   if recipe.feedforward is not None:
     if isinstance(recipe.feedforward, dict):
       for idx, layer in recipe.feedforward.items():
-        _set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
+        _set_quant_config(
+            rm,
+            layer,
+            f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_FEEDFORWARD_REGEX_STR}',
+        )
     else:
       _set_quant_config(
           rm,
           recipe.feedforward,
-          _FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+          f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_FEEDFORWARD_REGEX_STR}',
       )

   return rm.get_quantization_recipe()
@@ -144,21 +143,6 @@ def translate_to_ai_edge_recipe(
 def quantize_model(
     model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
 ) -> bytearray:
-  # TODO(b/336599483): Remove tempfile and use bytearray instead
-  tmp_model_path = '/tmp/tmp.tflite'
-  tmp_recipe_path = '/tmp/recipe.json'
-  with open(tmp_model_path, 'wb') as fp:
-    fp.write(model)
-  with open(tmp_recipe_path, 'w') as rp:
-    rp.write(json.dumps(recipe))
-
-  qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
+  qt = quantizer.Quantizer(bytearray(model), recipe)
   result = qt.quantize()
-
-  # TODO(b/336599483): Remove tempfile and use bytearray instead
-  import os
-
-  os.remove(tmp_model_path)
-  os.remove(tmp_recipe_path)
-
   return result.quantized_model
diff --git a/ai_edge_torch/generative/quantize/example.py b/ai_edge_torch/generative/quantize/example.py
index 24ca0a8d..28944b78 100644
--- a/ai_edge_torch/generative/quantize/example.py
+++ b/ai_edge_torch/generative/quantize/example.py
@@ -31,7 +31,7 @@ def main():
   input_pos = torch.arange(0, 10)

   # Create a quantization recipe to be applied to the model
-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+  quant_config = quant_recipes.full_int8_dynamic_recipe()
   print(quant_config)

   # Convert with quantization
diff --git a/ai_edge_torch/generative/quantize/quant_recipes.py b/ai_edge_torch/generative/quantize/quant_recipes.py
index e0baeb86..156847eb 100644
--- a/ai_edge_torch/generative/quantize/quant_recipes.py
+++ b/ai_edge_torch/generative/quantize/quant_recipes.py
@@ -21,7 +21,7 @@

 Typical usage example:

-  quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+  quant_config = quant_recipes.full_int8_dynamic_recipe()
   edge_model = ai_edge_torch.convert(
       model, (tokens, input_pos), quant_config=quant_config
   )
@@ -32,7 +32,7 @@
 from ai_edge_torch.quantize import quant_config


-def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
+def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
   return quant_config.QuantConfig(
       generative_recipe=quant_recipe.GenerativeQuantRecipe(
           default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
diff --git a/ai_edge_torch/generative/test/test_quantize.py b/ai_edge_torch/generative/test/test_quantize.py
index 8a97b874..66c19b28 100644
--- a/ai_edge_torch/generative/test/test_quantize.py
+++ b/ai_edge_torch/generative/test/test_quantize.py
@@ -19,7 +19,7 @@
 import torch

 import ai_edge_torch
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache  # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model  # NOQA
 from ai_edge_torch.generative.quantize import quant_recipe
 from ai_edge_torch.generative.quantize import quant_recipe_utils
 from ai_edge_torch.generative.quantize import quant_recipes
@@ -93,35 +93,33 @@ def test_verify_valid_recipes(
 class TestQuantizeConvert(unittest.TestCase):
   """Test conversion with quantization."""

-  def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
+  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+            attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
         )
     )

-  def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
+  def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
     return quant_config.QuantConfig(
         generative_recipe=quant_recipe.GenerativeQuantRecipe(
-            feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+            feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
        )
     )

   @parameterized.expand(
       [
-          (quant_recipes.full_fp16_recipe(), 0.75),
-          (quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
-          (_attention_1_int8_dynamic_recipe(), 0.95),
-          (_feedforward_0_int8_dynamic_recipe(), 0.87),
+          (quant_recipes.full_fp16_recipe(), 0.65),
+          (quant_recipes.full_int8_dynamic_recipe(), 0.47),
+          (_attention_int8_dynamic_recipe(), 0.89),
+          (_feedforward_int8_dynamic_recipe(), 0.72),
       ]
   )
   def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
-    self.skipTest("b/346896669")
-    config = toy_model_with_kv_cache.get_model_config()
-    pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-    idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-        [10], dtype=torch.int64
-    )
+    config = toy_model.get_model_config()
+    pytorch_model = toy_model.ToySingleLayerModel(config)
+    idx = torch.unsqueeze(torch.arange(0, 100), 0)
+    input_pos = torch.arange(0, 100)

     quantized_model = ai_edge_torch.convert(
         pytorch_model, (idx, input_pos), quant_config=quant_config
diff --git a/requirements.txt b/requirements.txt
index 8b8974bd..f9d775e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-ai-edge-quantizer-nightly==0.0.1.dev202406070009
+ai-edge-quantizer-nightly==0.0.1.dev202407012233
 scipy
 numpy
 tabulate
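For reference, a minimal end-to-end sketch of the renamed recipe as it is exercised by this patch. It is assembled from the README snippet, `quantize/example.py`, the updated `test_quantize.py`, and the `toy_model.py` refactor above; the output path passed to `export` is illustrative only.

```
import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model
from ai_edge_torch.generative.quantize import quant_recipes

# Build the single-layer toy model from its config; get_model_config() is the
# helper split out of define_and_run() in this patch.
config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModel(config)

# Trace inputs, mirroring the updated test_quantize.py.
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)

# full_int8_dynamic_recipe() is the new name for full_linear_int8_dynamic_recipe().
quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_config
)
edge_model.export('/tmp/toy_model_int8.tflite')  # illustrative output path
```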