Update quantizer to new nightly (#72)
* Update quantizer to new nightly

* Read model from bytearray instead of temporary files
* Use ALL_SUPPORTED keyword

BUG=b/346896669

* Rename full_linear to full

* Update to new nightly

---------

Co-authored-by: Haoliang Zhang <[email protected]>
paulinesho and haozha111 authored Jul 3, 2024
1 parent 6e37d95 commit 68752fe
Showing 14 changed files with 61 additions and 76 deletions.
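The user-visible change in this commit is the recipe rename (`full_linear_int8_dynamic_recipe` becomes `full_int8_dynamic_recipe`), applied across the READMEs, examples, and tests below. A minimal sketch of the new spelling, mirroring `generative/examples/quantize/example.py`:

```
from ai_edge_torch.generative.quantize import quant_recipes

# Pre-commit name:  quant_recipes.full_linear_int8_dynamic_recipe()
# Post-commit name: quant_recipes.full_int8_dynamic_recipe()
quant_config = quant_recipes.full_int8_dynamic_recipe()
print(quant_config)
```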
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/README.md
@@ -34,7 +34,7 @@ Quantization can be done via the API exposed in [quantize](quantize/). To apply
`quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
model, (tokens, input_pos), quant_config=quant_config
)
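The README snippet stops at conversion; serializing the result is the usual next step. A sketch continuing the snippet above (the `export` call is the standard ai_edge_torch API; the output path is arbitrary):

```
# Continuing the README snippet: write the quantized model out as a
# .tflite flatbuffer (path chosen arbitrarily for illustration).
edge_model.export('/tmp/model_int8.tflite')
```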
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/README.md
@@ -72,7 +72,7 @@ To apply quantization, we need to create a configuration that fully expresses how
`quantize/quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
model, (tokens, input_pos), quant_config=quant_config
)
@@ -50,7 +50,7 @@ def convert_gemma_to_tflite(
decode_token = torch.tensor([[0]], dtype=torch.long)
decode_input_pos = torch.tensor([0], dtype=torch.int64)

-quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
edge_model = (
ai_edge_torch.signature(
'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
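The hunk above is truncated mid-expression; reconstructed from the visible context, the full multi-signature conversion in these scripts looks roughly like this (a sketch, not a verbatim quote of the file):

```
# Sketch of the truncated expression: one edge model carrying separate
# 'prefill' and 'decode' signatures, quantized when requested.
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
edge_model = (
    ai_edge_torch.signature(
        'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
    )
    .signature('decode', pytorch_model, (decode_token, decode_input_pos))
    .convert(quant_config=quant_config)
)
```

The same pattern applies to the phi2, t5, and tiny_llama conversion scripts changed below.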
@@ -48,7 +48,7 @@ def convert_phi2_to_tflite(
decode_token = torch.tensor([[0]], dtype=torch.long)
decode_input_pos = torch.tensor([0], dtype=torch.int64)

-quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
edge_model = (
ai_edge_torch.signature(
'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/examples/t5/convert_to_tflite.py
@@ -101,7 +101,7 @@ def convert_t5_to_tflite_multisig(checkpoint_path: str):
# Pad with `-inf` for any tokens indices that aren't desired.
pad_mask = torch.zeros([seq_len], dtype=torch.float32)
hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()

edge_model = (
ai_edge_torch.signature(
@@ -145,7 +145,7 @@
" # Pad with `-inf` for any tokens indices that aren't desired.\n",
" pad_mask = torch.zeros([seq_len], dtype=torch.float32)\n",
" hidden_states = torch.zeros((1, 512, 768), dtype=torch.float32)\n",
-" quant_config = quant_recipes.full_linear_int8_dynamic_recipe()\n",
+" quant_config = quant_recipes.full_int8_dynamic_recipe()\n",
"\n",
" edge_model = ai_edge_torch.signature(\n",
" 'encode',\n",
7 changes: 5 additions & 2 deletions ai_edge_torch/generative/examples/test_models/toy_model.py
@@ -70,7 +70,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
return self.lm_head(x)


-def define_and_run() -> None:
+def get_model_config() -> cfg.ModelConfig:
attn_config = cfg.AttentionConfig(
num_heads=32, num_query_groups=4, rotary_percentage=1.0, enable_kv_cache=False
)
@@ -91,8 +91,11 @@ def define_and_run() -> None:
pre_ff_norm_config=norm_config,
final_norm_config=norm_config,
)
+  return config


-  model = ToySingleLayerModel(config)
+def define_and_run() -> None:
+  model = ToySingleLayerModel(get_model_config())
idx = torch.unsqueeze(torch.arange(0, KV_CACHE_MAX_LEN), 0)
input_pos = torch.arange(0, KV_CACHE_MAX_LEN)
print('running an inference')
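Splitting `get_model_config()` out of `define_and_run()` lets other modules build the same toy model without running inference; a sketch of the new call site, mirroring the `test_quantize.py` change later in this commit:

```
import torch

from ai_edge_torch.generative.examples.test_models import toy_model

# Build the toy model from the extracted config, as the updated test does.
pytorch_model = toy_model.ToySingleLayerModel(toy_model.get_model_config())
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)
logits = pytorch_model(idx, input_pos)
```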
@@ -50,7 +50,7 @@ def convert_tiny_llama_to_tflite(
decode_token = torch.tensor([[0]], dtype=torch.long)
decode_input_pos = torch.tensor([0], dtype=torch.int64)

-quant_config = quant_recipes.full_linear_int8_dynamic_recipe() if quantize else None
+quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
edge_model = (
ai_edge_torch.signature(
'prefill', pytorch_model, (prefill_tokens, prefill_input_pos)
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/quantize/README.md
@@ -7,7 +7,7 @@ To apply quantization, we need to create a configuration that fully expresses how
`quant_recipes.py` contains a list of recipes that are known to be well-supported during runtime. For the average user, this is a good starting point to select the quantization scheme that is best suited for your deployment needs. After identifying the target recipe, the model can be quantized as follows. This example is extracted from `generative/examples/quantize/example.py`.

```
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
model, (tokens, input_pos), quant_config=quant_config
)
@@ -26,12 +26,10 @@
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

_DEFAULT_REGEX_STR = '.*'
-_ATTENTION_IDX_REGEX_STR = (
-    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
-)
-_FEEDFORWARD_IDX_REGEX_STR = (
-    'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
-)
+_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
+_IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
+_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'
+_FEEDFORWARD_REGEX_STR = 'ai_edge_torch.generative.layers.feed_forward'
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'

@@ -82,27 +80,20 @@ def _set_quant_config(
layer_recipe: quant_recipe.LayerQuantRecipe,
regex: str,
):
-  support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
-  if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
-    support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
-  for op_name in support_op_list:
-    rm.add_quantization_config(
-        regex=regex,
-        operation_name=op_name,
-        op_config=_OpQuantConfig(
-            weight_tensor_config=_TensorQuantConfig(
-                num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
-                symmetric=True,
-                channel_wise=_get_channelwise_from_granularity(
-                    layer_recipe.granularity
-                ),
-                dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
-            ),
-            execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
-        ),
-        algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
-        override_algorithm=True,
-    )
+  rm.add_quantization_config(
+      regex=regex,
+      operation_name=_OpName.ALL_SUPPORTED,
+      op_config=_OpQuantConfig(
+          weight_tensor_config=_TensorQuantConfig(
+              num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
+              symmetric=True,
+              channel_wise=_get_channelwise_from_granularity(layer_recipe.granularity),
+              dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
+          ),
+          execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+      ),
+      algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
+  )


@@ -119,23 +110,31 @@ def translate_to_ai_edge_recipe(
  if recipe.attention is not None:
    if isinstance(recipe.attention, dict):
      for idx, layer in recipe.attention.items():
-        _set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
+        _set_quant_config(
+            rm,
+            layer,
+            f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_ATTENTION_REGEX_STR}',
+        )
    else:
      _set_quant_config(
          rm,
          recipe.attention,
-          _ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+          f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_ATTENTION_REGEX_STR}',
      )

  if recipe.feedforward is not None:
    if isinstance(recipe.feedforward, dict):
      for idx, layer in recipe.feedforward.items():
-        _set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
+        _set_quant_config(
+            rm,
+            layer,
+            f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(idx)}/{_FEEDFORWARD_REGEX_STR}',
+        )
    else:
      _set_quant_config(
          rm,
          recipe.feedforward,
-          _FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
+          f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_FEEDFORWARD_REGEX_STR}',
      )

return rm.get_quantization_recipe()
@@ -144,21 +143,6 @@ def translate_to_ai_edge_recipe(
def quantize_model(
model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
) -> bytearray:
-  # TODO(b/336599483): Remove tempfile and use bytearray instead
-  tmp_model_path = '/tmp/tmp.tflite'
-  tmp_recipe_path = '/tmp/recipe.json'
-  with open(tmp_model_path, 'wb') as fp:
-    fp.write(model)
-  with open(tmp_recipe_path, 'w') as rp:
-    rp.write(json.dumps(recipe))
-
-  qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
+  qt = quantizer.Quantizer(bytearray(model), recipe)
  result = qt.quantize()
-
-  # TODO(b/336599483): Remove tempfile and use bytearray instead
-  import os
-
-  os.remove(tmp_model_path)
-  os.remove(tmp_recipe_path)
-
return result.quantized_model
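With the split constants, scoping regexes are composed at the call site, and `quantize_model` now feeds the flatbuffer and recipe to the quantizer in memory instead of round-tripping through `/tmp`. A sketch of the patterns the new code produces (constant values copied from the diff above):

```
# Scoping regexes composed by translate_to_ai_edge_recipe.
_IDX_TRANSFORMER_BLOCKS_REGEX_STR = 'transformer_blocks\[{}\]'
_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR = 'transformer_block'
_ATTENTION_REGEX_STR = 'ai_edge_torch.generative.layers.attention'

# Indexed (dict) recipe, e.g. attention in transformer block 3 only:
per_block = f'{_IDX_TRANSFORMER_BLOCKS_REGEX_STR.format(3)}/{_ATTENTION_REGEX_STR}'
# -> 'transformer_blocks\[3\]/ai_edge_torch.generative.layers.attention'

# Uniform recipe: one pattern covering the attention layer of every block.
uniform = f'{_SINGULAR_TRANSFORMER_BLOCK_REGEX_STR}/{_ATTENTION_REGEX_STR}'
# -> 'transformer_block/ai_edge_torch.generative.layers.attention'
```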
2 changes: 1 addition & 1 deletion ai_edge_torch/generative/quantize/example.py
@@ -31,7 +31,7 @@ def main():
input_pos = torch.arange(0, 10)

# Create a quantization recipe to be applied to the model
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
print(quant_config)

# Convert with quantization
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/quantize/quant_recipes.py
@@ -21,7 +21,7 @@
Typical usage example:
-quant_config = quant_recipes.full_linear_int8_dynamic_recipe()
+quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
model, (tokens, input_pos), quant_config=quant_config
)
@@ -32,7 +32,7 @@
from ai_edge_torch.quantize import quant_config


-def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
+def full_int8_dynamic_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
28 changes: 13 additions & 15 deletions ai_edge_torch/generative/test/test_quantize.py
@@ -19,7 +19,7 @@
import torch

import ai_edge_torch
-from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache # NOQA
+from ai_edge_torch.generative.examples.test_models import toy_model # NOQA
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.generative.quantize import quant_recipes
@@ -93,35 +93,33 @@ def test_verify_valid_recipes(
class TestQuantizeConvert(unittest.TestCase):
"""Test conversion with quantization."""

-def _attention_1_int8_dynamic_recipe() -> quant_config.QuantConfig:
+def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
-attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
)
)

-def _feedforward_0_int8_dynamic_recipe() -> quant_config.QuantConfig:
+def _feedforward_int8_dynamic_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
-feedforward={0: quant_recipe_utils.create_layer_quant_int8_dynamic()},
+feedforward=quant_recipe_utils.create_layer_quant_int8_dynamic(),
)
)

@parameterized.expand(
[
-(quant_recipes.full_fp16_recipe(), 0.75),
-(quant_recipes.full_linear_int8_dynamic_recipe(), 0.64),
-(_attention_1_int8_dynamic_recipe(), 0.95),
-(_feedforward_0_int8_dynamic_recipe(), 0.87),
+(quant_recipes.full_fp16_recipe(), 0.65),
+(quant_recipes.full_int8_dynamic_recipe(), 0.47),
+(_attention_int8_dynamic_recipe(), 0.89),
+(_feedforward_int8_dynamic_recipe(), 0.72),
]
)
def test_quantize_convert_toy_sizes(self, quant_config, expected_compression):
+self.skipTest("b/346896669")
-config = toy_model_with_kv_cache.get_model_config()
-pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
-idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
-    [10], dtype=torch.int64
-)
+config = toy_model.get_model_config()
+pytorch_model = toy_model.ToySingleLayerModel(config)
+idx = torch.unsqueeze(torch.arange(0, 100), 0)
+input_pos = torch.arange(0, 100)

quantized_model = ai_edge_torch.convert(
pytorch_model, (idx, input_pos), quant_config=quant_config
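The module-level recipes in this test switch from the indexed-dict spelling to the uniform spelling; `translate_to_ai_edge_recipe` above still handles both. A sketch of the two forms side by side:

```
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.quantize import quant_config

# Uniform form (the updated test): quantize attention in every block.
uniform = quant_config.QuantConfig(
    generative_recipe=quant_recipe.GenerativeQuantRecipe(
        attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
    )
)

# Indexed form (the test's previous spelling): block 1's attention only.
indexed = quant_config.QuantConfig(
    generative_recipe=quant_recipe.GenerativeQuantRecipe(
        attention={1: quant_recipe_utils.create_layer_quant_int8_dynamic()},
    )
)
```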
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-ai-edge-quantizer-nightly==0.0.1.dev202406070009
+ai-edge-quantizer-nightly==0.0.1.dev202407012233
scipy
numpy
tabulate