Make int8 dynamic quant in autoquant serializable (#1484)
* Make int8 dynamic quant in autoquant serializable

Summary:
A lambda function is not supported by serialization, so we reuse the non-lambda function that already supports it: https://github.com/pytorch/ao/blob/00a8d290aab354985fce8c880e1fded22bc48e30/torchao/quantization/quant_api.py#L1263C5-L1268
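
For context, a minimal sketch of why the lambda fails (illustrative names, not code from this PR):

    import pickle

    def _named_act_quant_func(x):
        # stand-in for a module-level quant function such as
        # _int8_symm_per_token_reduced_range_quant
        return x

    pickle.dumps(_named_act_quant_func)  # works: pickled by qualified name
    pickle.dumps(lambda x: x)  # raises pickle.PicklingError: a lambda has
                               # no importable qualified name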

Note that this PR only covers int8 dynamic quant; float8 will need to be tested and supported separately (on H100 machines).

Test Plan:
Tested locally with transformers' push_to_hub: https://huggingface.co/jerryzh168/llama3-8b-autoquant/tree/main
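
A rough sketch of the round trip this unblocks (assumes torchao's top-level autoquant API; the tiny model here is a placeholder for the llama3-8b checkpoint above):

    import torch
    import torchao

    model = torch.nn.Sequential(torch.nn.Linear(16, 16)).to("cuda")
    model = torchao.autoquant(model)
    model(torch.randn(4, 16, device="cuda"))  # run once so kernels get picked

    torch.save(model.state_dict(), "autoquant_ckpt.pt")
    # the add_safe_globals(ALL_AUTOQUANT_CLASS_LIST) call added in this diff
    # lets the autoquant tensor subclasses pass torch.load's weights_only check
    state_dict = torch.load("autoquant_ckpt.pt", weights_only=True)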

Reviewers:

Subscribers:

Tasks:

Tags:

* fix

* fixes

* fix
jerryzh168 authored Jan 3, 2025
1 parent 00a8d29 commit 3f36c78
Showing 2 changed files with 15 additions and 25 deletions.
.pre-commit-config.yaml (3 changes: 2 additions & 1 deletion)
@@ -26,4 +26,5 @@ repos:
         alias: ruff-isolated
         args:
           - --isolated
-          - select F821,F823,W191
+          - --select
+          - F821,F823,W191
torchao/quantization/autoquant.py (37 changes: 13 additions & 24 deletions)
@@ -426,6 +426,12 @@ def from_float(cls, weight):

         # avoid circular dep
         from torchao.dtypes import to_affine_quantized_intx
+        from torchao.quantization.quant_api import (
+            _int8_symm_per_token_reduced_range_quant,
+        )
+
+        # input settings
+        input_quant_func = _int8_symm_per_token_reduced_range_quant
 
         # weight settings
         mapping_type = MappingType.SYMMETRIC
@@ -436,32 +442,9 @@ def get_weight_block_size(x):
         target_dtype = torch.int8
         eps = torch.finfo(torch.float32).eps
         zero_point_dtype = torch.int64
 
-        # input settings
-        def get_per_token_block_size(x):
-            block_size = list(x.shape)
-            for i in range(len(block_size) - 1):
-                block_size[i] = 1
-            return block_size
-
-        input_mapping_type = MappingType.SYMMETRIC
-        input_target_dtype = torch.int8
-        input_eps = 1e-5
-        input_quant_min = -127
-        input_quant_max = 127
-        _layout = cls.layout
-        input_quant_func = lambda x: to_affine_quantized_intx(
-            x,
-            input_mapping_type,
-            get_per_token_block_size(x),
-            input_target_dtype,
-            eps=input_eps,
-            quant_min=input_quant_min,
-            quant_max=input_quant_max,
-            scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
-        )
-
+        block_size = get_weight_block_size(weight)
 
         weight = to_affine_quantized_intx(
             weight,
             mapping_type,
@@ -937,6 +920,7 @@ def get_per_token_block_size(x):

         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
         input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
@@ -980,6 +964,7 @@ def get_weight_block_size(x):

         input_target_dtype = torch.float8_e4m3fn
         _layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
+        # TODO: make this serializable
         input_quant_func = lambda x: _input_activation_quant_func_fp8(
             x=x,
             activation_granularity=cls.activation_granularity,
@@ -1287,3 +1272,7 @@ def finalize_autoquant():
         model(*example_input)
 
     return model
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    torch.serialization.add_safe_globals(ALL_AUTOQUANT_CLASS_LIST)
