Migrate quantization to use AI Edge Quantizer (#46)
* Rename some quant variables to generative

* Migrate quantization to use AI Edge Quantizer

BUG=b/335285041

* Formatting

* ai-edge-quantizer requirements

* Formatting

* Refactor imports

* Use Union for py3.9

* Use newer tf-nightly and specify ops explicitly

* Pin to the latest nightly

* PR comments
paulinesho authored Jun 10, 2024
1 parent 1ef63c7 commit 913831c
Showing 11 changed files with 325 additions and 50 deletions.
22 changes: 17 additions & 5 deletions ai_edge_torch/convert/conversion_utils.py
@@ -24,6 +24,7 @@
import torch
from torch_xla import stablehlo

from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
from ai_edge_torch.quantize import quant_config as qcfg

try:
@@ -249,11 +250,6 @@ def _set_tfl_converter_quant_flags(
converter._experimental_qdq_conversion_mode = "DYNAMIC"
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.PT2E_STATIC:
converter._experimental_qdq_conversion_mode = "STATIC"
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_DYNAMIC:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
elif quantizer_mode == qcfg.QuantConfig._QuantizerMode.TFLITE_FP16:
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]


def convert_stablehlo_to_tflite(
@@ -323,8 +319,24 @@ def convert_stablehlo_to_tflite(
converter._experimental_enable_composite_direct_lowering = True

_set_tfl_converter_quant_flags(converter, quant_config)
if (
quant_config is not None
and quant_config._quantizer_mode
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
):
translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
quant_config.generative_recipe
)

_apply_tfl_backdoor_flags(converter, _tfl_converter_flags)

tflite_model = converter.convert()

if (
quant_config is not None
and quant_config._quantizer_mode
== quant_config._QuantizerMode.AI_EDGE_QUANTIZER
):
tflite_model = translate_recipe.quantize_model(tflite_model, translated_recipe)

return tflite_model
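
With this change, selecting the AI_EDGE_QUANTIZER mode translates the generative recipe before conversion and applies it to the converted flatbuffer afterwards. Below is a minimal sketch of how a caller would exercise this path, assuming the public `ai_edge_torch.convert` entry point forwards `quant_config` into this conversion routine; the toy model and output path are illustrative only.

```python
import torch
import ai_edge_torch
from ai_edge_torch.generative.quantize import quant_recipes


class TinyModel(torch.nn.Module):
  """Toy model; its fully connected layer matches the default '.*' recipe scope."""

  def __init__(self):
    super().__init__()
    self.fc = torch.nn.Linear(16, 16)

  def forward(self, x):
    return self.fc(x)


# Build a QuantConfig that routes quantization through the AI Edge Quantizer.
quant_config = quant_recipes.full_linear_int8_dynamic_recipe()

sample_inputs = (torch.randn(1, 16),)
edge_model = ai_edge_torch.convert(
    TinyModel().eval(), sample_inputs, quant_config=quant_config
)
edge_model.export('/tmp/tiny_int8.tflite')
```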
Empty file.
@@ -0,0 +1,164 @@
# Copyright 2024 The AI Edge Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import json

from ai_edge_quantizer import quantizer

from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import quant_recipe

_OpExecutionMode = quantizer.qtyping.OpExecutionMode
_OpName = quantizer.qtyping.TFLOperationName
_TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
_OpQuantConfig = quantizer.qtyping.OpQuantizationConfig

_DEFAULT_REGEX_STR = '.*'
_ATTENTION_IDX_REGEX_STR = (
'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.attention'
)
_FEEDFORWARD_IDX_REGEX_STR = (
'transformer_blocks\[{}\]/ai_edge_torch.generative.layers.feed_forward'
)
_EMBEDDING_REGEX_STR = 'Embedding_tok_embedding'
_ANY_TWO_DIGITS_REGEX_STR = '\d{1,2}'


def _get_nbits_from_dtype(dtype: quant_attrs.Dtype) -> int:
if dtype == quant_attrs.Dtype.FP32:
return 32
elif dtype == quant_attrs.Dtype.FP16:
return 16
elif dtype == quant_attrs.Dtype.INT8:
return 8
raise ValueError('Unimplemented number of bits')


def _get_dtype_from_dtype(dtype: quant_attrs.Dtype) -> quantizer.qtyping.TensorDataType:
if dtype == quant_attrs.Dtype.FP32 or dtype == quant_attrs.Dtype.FP16:
return quantizer.qtyping.TensorDataType.FLOAT
else:
return quantizer.qtyping.TensorDataType.INT


def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
if mode == quant_attrs.Mode.DYNAMIC_RANGE:
return _OpExecutionMode.DRQ
elif mode == quant_attrs.Mode.WEIGHT_ONLY:
return _OpExecutionMode.WEIGHT_ONLY
raise ValueError('Unimplemented execution mode')


def _get_channelwise_from_granularity(granularity: quant_attrs.Granularity) -> bool:
if granularity == quant_attrs.Granularity.CHANNELWISE:
return True
elif granularity == quant_attrs.Granularity.NONE:
return False
raise ValueError('Unimplemented granularity')


def _get_algorithm_key_from_algorithm(algo: quant_attrs.Algorithm) -> str:
if algo == quant_attrs.Algorithm.MIN_MAX:
return quantizer.algorithm_manager.AlgorithmName.MIN_MAX_UNIFORM_QUANT
elif algo == quant_attrs.Algorithm.FLOAT_CAST:
return quantizer.algorithm_manager.AlgorithmName.FLOAT_CASTING
raise ValueError('Unimplemented algorithm')


def _set_quant_config(
rm: quantizer.recipe_manager.RecipeManager,
layer_recipe: quant_recipe.LayerQuantRecipe,
regex: str,
):
support_op_list = [_OpName.FULLY_CONNECTED, _OpName.CONV_2D]
if layer_recipe.algorithm == quant_attrs.Algorithm.MIN_MAX:
support_op_list += [_OpName.BATCH_MATMUL, _OpName.EMBEDDING_LOOKUP]
for op_name in support_op_list:
rm.add_quantization_config(
regex=regex,
operation_name=op_name,
op_config=_OpQuantConfig(
weight_tensor_config=_TensorQuantConfig(
num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
symmetric=True,
channel_wise=_get_channelwise_from_granularity(
layer_recipe.granularity
),
dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
),
execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
),
algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
override_algorithm=True,
)


def translate_to_ai_edge_recipe(
recipe: quant_recipe.GenerativeQuantRecipe,
) -> quantizer.recipe_manager.ModelQuantizationRecipe:
rm = quantizer.recipe_manager.RecipeManager()

if recipe.default is not None:
_set_quant_config(rm, recipe.default, _DEFAULT_REGEX_STR)

if recipe.embedding is not None:
_set_quant_config(rm, recipe.embedding, _EMBEDDING_REGEX_STR)

if recipe.attention is not None:
if isinstance(recipe.attention, dict):
for idx, layer in recipe.attention.items():
_set_quant_config(rm, layer, _ATTENTION_IDX_REGEX_STR.format(idx))
else:
_set_quant_config(
rm,
recipe.attention,
_ATTENTION_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
)

if recipe.feedforward is not None:
if isinstance(recipe.feedforward, dict):
for idx, layer in recipe.feedforward.items():
_set_quant_config(rm, layer, _FEEDFORWARD_IDX_REGEX_STR.format(idx))
else:
_set_quant_config(
rm,
recipe.feedforward,
_FEEDFORWARD_IDX_REGEX_STR.format(_ANY_TWO_DIGITS_REGEX_STR),
)

return rm.get_quantization_recipe()


def quantize_model(
model: bytearray, recipe: quantizer.recipe_manager.ModelQuantizationRecipe
) -> bytearray:
# TODO(b/336599483): Remove tempfile and use bytearray instead
tmp_model_path = '/tmp/tmp.tflite'
tmp_recipe_path = '/tmp/recipe.json'
with open(tmp_model_path, 'wb') as fp:
fp.write(model)
with open(tmp_recipe_path, 'w') as rp:
rp.write(json.dumps(recipe))

qt = quantizer.Quantizer(tmp_model_path, tmp_recipe_path)
result = qt.quantize()

# TODO(b/336599483): Remove tempfile and use bytearray instead
import os

os.remove(tmp_model_path)
os.remove(tmp_recipe_path)

return result.quantized_model
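
A sketch of how the two helpers above are meant to compose outside the converter, assuming a float TFLite flatbuffer is already available on disk; the file paths are illustrative.

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe

# Author a generative recipe and translate it into the AI Edge Quantizer format.
recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_int8_dynamic()
)
ai_edge_recipe = translate_recipe.translate_to_ai_edge_recipe(recipe)

# Quantize an already-converted float model and write the result back out.
with open('/tmp/float_model.tflite', 'rb') as f:
  float_model = bytearray(f.read())

quantized_model = translate_recipe.quantize_model(float_model, ai_edge_recipe)
with open('/tmp/quantized_model.tflite', 'wb') as f:
  f.write(quantized_model)
```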
2 changes: 2 additions & 0 deletions ai_edge_torch/generative/quantize/quant_attrs.py
@@ -32,9 +32,11 @@ class Algorithm(enum.Enum):
Attributes:
MIN_MAX: Maps the min/max of floating point space to the min/max of
quantized space and quantize uniformly.
FLOAT_CAST: Casts a float to another float of a different type.
"""

MIN_MAX = enum.auto()
FLOAT_CAST = enum.auto()


@enum.unique
53 changes: 49 additions & 4 deletions ai_edge_torch/generative/quantize/quant_recipe.py
@@ -14,8 +14,7 @@
# ==============================================================================

from dataclasses import dataclass
import enum
from typing import Optional
from typing import Optional, Union

from ai_edge_torch.generative.quantize import quant_attrs
from ai_edge_torch.generative.quantize import supported_schemes
@@ -80,18 +79,50 @@ def verify(self):


@dataclass
class TransformerQuantRecipe:
class GenerativeQuantRecipe:
"""Quantization recipe for a model composed of the Edge Generative API layers.
Some layers can be specified with different `LayerQuantRecipe` for each block by
providing a dictionary keyed by the TransformerBlock index, e.g. attention
and feedforward. For example,
```
default = LayerQuantRecipeA
attention = { 2: LayerQuantRecipeB }
feedforward = { 3: LayerQuantRecipeC }
```
will apply LayerQuantRecipeA to the entire model, overridden by
LayerQuantRecipeB for the TransformerBlock[2].attention layer and
LayerQuantRecipeC for the TransformerBlock[3].feedforward layer. Any config
with invalid indices will be ignored.
Attributes:
default: The quantization recipe for global scope of the model.
embedding: Recipe for the embedding table.
attention: Recipe for the attention blocks. This can be specified with a
different LayerQuantRecipe for each block by providing a dictionary
keyed by the TransformerBlock index.
feedforward: Recipe for the feedforward layers. This can be specified with a
different LayerQuantRecipe for each block by providing a dictionary
keyed by the TransformerBlock index.
"""

default: Optional[LayerQuantRecipe] = None
embedding: Optional[LayerQuantRecipe] = None
attention: Union[
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
] = None
feedforward: Union[
Optional[LayerQuantRecipe], Optional[dict[int, LayerQuantRecipe]]
] = None

def __str__(self):
return f"""TransformerQuantRecipe(
return f"""GenerativeQuantRecipe(
Default: {self.default}
Embedding: {self.embedding}
Attention: {self.attention}
Feedforward: {self.feedforward}
)"""

__repr__ = __str__
@@ -104,3 +135,17 @@ def verify(self):
"""
if self.default is not None:
self.default.verify()
if self.embedding is not None:
self.embedding.verify()
if self.attention is not None:
if isinstance(self.attention, dict):
for recipe in self.attention.values():
recipe.verify()
else:
self.attention.verify()
if self.feedforward is not None:
if isinstance(self.feedforward, dict):
for recipe in self.feedforward.values():
recipe.verify()
else:
self.feedforward.verify()
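
A sketch of the per-block overrides described in the docstring above, built with the helper constructors from quant_recipe_utils; which blocks to override is illustrative.

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils

# fp16 weights everywhere, with int8 dynamic-range overrides for the
# attention layer of block 2 and the feedforward layer of block 3.
recipe = quant_recipe.GenerativeQuantRecipe(
    default=quant_recipe_utils.create_layer_quant_fp16(),
    attention={2: quant_recipe_utils.create_layer_quant_int8_dynamic()},
    feedforward={3: quant_recipe_utils.create_layer_quant_int8_dynamic()},
)
recipe.verify()
```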
4 changes: 2 additions & 2 deletions ai_edge_torch/generative/quantize/quant_recipe_utils.py
@@ -22,7 +22,7 @@
1. Applying a single layer recipe to the entire model
quant_recipe.TransformerQuantRecipe(
quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
)
"""
@@ -46,6 +46,6 @@ def create_layer_quant_fp16() -> quant_recipe.LayerQuantRecipe:
activation_dtype=quant_attrs.Dtype.FP32,
weight_dtype=quant_attrs.Dtype.FP16,
mode=quant_attrs.Mode.WEIGHT_ONLY,
algorithm=quant_attrs.Algorithm.MIN_MAX,
algorithm=quant_attrs.Algorithm.FLOAT_CAST,
granularity=quant_attrs.Granularity.NONE,
)
6 changes: 3 additions & 3 deletions ai_edge_torch/generative/quantize/quant_recipes.py
@@ -34,15 +34,15 @@

def full_linear_int8_dynamic_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
transformer_recipe=quant_recipe.TransformerQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int8_dynamic()
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_int8_dynamic(),
)
)


def full_fp16_recipe() -> quant_config.QuantConfig:
return quant_config.QuantConfig(
transformer_recipe=quant_recipe.TransformerQuantRecipe(
generative_recipe=quant_recipe.GenerativeQuantRecipe(
default=quant_recipe_utils.create_layer_quant_fp16()
)
)
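
A sketch of a custom prebuilt recipe written in the same style as the two helpers above; the particular mix of fp16 default weights with int8 dynamic-range attention is illustrative and not a recipe shipped by the library.

```python
from ai_edge_torch.generative.quantize import quant_recipe
from ai_edge_torch.generative.quantize import quant_recipe_utils
from ai_edge_torch.quantize import quant_config


def mixed_precision_recipe() -> quant_config.QuantConfig:
  """fp16 weights by default, int8 dynamic-range for all attention blocks."""
  return quant_config.QuantConfig(
      generative_recipe=quant_recipe.GenerativeQuantRecipe(
          default=quant_recipe_utils.create_layer_quant_fp16(),
          attention=quant_recipe_utils.create_layer_quant_int8_dynamic(),
      )
  )
```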
3 changes: 2 additions & 1 deletion ai_edge_torch/generative/quantize/supported_schemes.py
@@ -27,5 +27,6 @@ def get_supported_layer_schemes():

return [
(_t.FP32, _t.INT8, _m.DYNAMIC_RANGE, _a.MIN_MAX, _g.CHANNELWISE),
(_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.NONE),
(_t.FP32, _t.INT8, _m.WEIGHT_ONLY, _a.MIN_MAX, _g.CHANNELWISE),
(_t.FP32, _t.FP16, _m.WEIGHT_ONLY, _a.FLOAT_CAST, _g.NONE),
]