Add new QAT API through quantize_ (#1415)
**Summary:** This commit adds a new QAT API that can be used with the existing `quantize_` function. It is an alternative to the old QAT `*Quantizer` APIs, which are much less flexible. The new API can be used as follows:

```
import torch

from torchao import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
my_model = ...
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
```
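The same configs compose with `quantize_`'s `filter_fn` argument to target specific module types. As a sketch based on the new `test_quantize_api` test, embedding layers (which only support weight-only fake quantization) can be handled in a second pass:

```
# Sketch based on test_quantize_api: apply weight-only fake quantization to
# embedding layers only (embeddings do not support activation fake quantization).
quantize_(
    my_model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
```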

**Test Plan:**
python test/quantization/test_qat.py -k test_quantize_api
Pull Request resolved: #1415
Approved by: https://github.com/jerryzh168
andrewor14 authored and pytorchmergebot committed Dec 16, 2024
1 parent 46b8796 commit 200589b
Showing 7 changed files with 223 additions and 4 deletions.
106 changes: 106 additions & 0 deletions test/quantization/test_qat.py
@@ -14,6 +14,7 @@
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401

from torchao import quantize_
from torchao.quantization.GPTQ import _replace_linear_8da4w, _replace_linear_int4
from torchao.quantization.granularity import (
    PerAxis,
@@ -24,6 +25,7 @@
from torchao.quantization.qat.api import (
    ComposableQATQuantizer,
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
from torchao.quantization.qat.embedding import (
    FakeQuantizedEmbedding,
@@ -104,6 +106,25 @@ def forward(self, x):
        return self.embedding(x)


class M3(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(10, 512)
        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
        self.relu = torch.nn.ReLU()

    def example_inputs(self):
        return (torch.randint(1, 10, (1, 512)),)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.relu(x)
        return x


class TestQAT(unittest.TestCase):
    SEED = 123

@@ -1156,6 +1177,91 @@ def test_qat_prototype_bc(self):
            Int8DynActInt4WeightQATQuantizer,
        )

    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
    )
    def test_quantize_api(self):
        """
        Test that the following:
            quantize_(model, intx_quantization_aware_training(...))
        can produce the same results as `ComposableQATQuantizer`.
        """
        from torchao.quantization.qat import (
            ComposableQATQuantizer,
            Int4WeightOnlyEmbeddingQATQuantizer,
            Int8DynActInt4WeightQATQuantizer,
        )

        group_size = 16
        torch.manual_seed(self.SEED)
        m = M3()
        baseline_model = copy.deepcopy(m)

        # Baseline quantizer
        baseline_quantizer = ComposableQATQuantizer(
            [
                Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
                Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
            ]
        )
        baseline_model = baseline_quantizer.prepare(baseline_model)

        # quantize_ API
        activation_config = FakeQuantizeConfig(
            torch.int8,
            "per_token",
            is_symmetric=False,
        )
        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
        quantize_(
            m,
            intx_quantization_aware_training(activation_config, weight_config),
        )
        quantize_(
            m,
            intx_quantization_aware_training(weight_config=weight_config),
            filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
        )

        # Compare model values
        torch.manual_seed(self.SEED)
        x = m.example_inputs()
        x2 = copy.deepcopy(x)
        out = m(*x)
        baseline_out = baseline_model(*x2)
        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)

    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is lower than 2.4"
    )
    def test_quantize_api_errors(self):
        """
        Test that we throw exceptions with helpful error messages if `quantize_`
        runs into unexpected configurations.
        """
        my_config = FakeQuantizeConfig(torch.int8, group_size=32)
        m = M3()

        # Embedding currently only supports weight-only quantization
        with self.assertRaisesRegex(
            ValueError, "Activation fake quantization is not supported for embedding"
        ):
            quantize_(
                m,
                intx_quantization_aware_training(my_config, my_config),
                lambda m, _: isinstance(m, torch.nn.Embedding),
            )

        # Only linear and embedding are supported currently
        with self.assertRaisesRegex(ValueError, "does not have QAT support"):
            quantize_(
                m,
                intx_quantization_aware_training(my_config, my_config),
                lambda m, _: isinstance(m, torch.nn.ReLU),
            )


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -53,6 +53,7 @@
    int8_dynamic_activation_int8_semi_sparse_weight,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    intx_quantization_aware_training,
    quantize_,
    swap_conv2d_1x1_to_linear,
    uintx_weight_only,
@@ -103,6 +104,7 @@
    "int8_dynamic_activation_int8_semi_sparse_weight",
    "int4_weight_only",
    "int8_weight_only",
    "intx_quantization_aware_training",
    "float8_weight_only",
    "float8_dynamic_activation_float8_weight",
    "float8_static_activation_float8_weight",
4 changes: 4 additions & 0 deletions torchao/quantization/qat/__init__.py
@@ -1,5 +1,7 @@
from .api import (
    ComposableQATQuantizer,
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)
from .embedding import (
    Int4WeightOnlyEmbeddingQATQuantizer,
@@ -11,7 +13,9 @@

__all__ = [
    "ComposableQATQuantizer",
    "FakeQuantizeConfig",
    "Int4WeightOnlyQATQuantizer",
    "Int4WeightOnlyEmbeddingQATQuantizer",
    "Int8DynActInt4WeightQATQuantizer",
    "intx_quantization_aware_training",
]
58 changes: 57 additions & 1 deletion torchao/quantization/qat/api.py
@@ -46,7 +46,7 @@ class FakeQuantizeConfig:
        scale_precision: scale dtype (default torch.fp32)
        zero_point_precision: zero point dtype (default torch.int32)
        zero_point_domain: whether zero point is in integer (default) or float domain
        is_dynamic: whether to use dynamic (defualt) or static scale and zero points
        is_dynamic: whether to use dynamic (default) or static scale and zero points
        range_learning: whether to learn scale and zero points during training (coming soon)
    kwargs (optional):
@@ -239,6 +239,62 @@ def __setattr__(self, name: str, value: Any):
        super().__setattr__(name, value)


def intx_quantization_aware_training(
    activation_config: Optional[FakeQuantizeConfig] = None,
    weight_config: Optional[FakeQuantizeConfig] = None,
) -> torch.nn.Module:
    """
    Return a function that applies fake quantization to a `torch.nn.Module`,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    Example usage::

        from torchao.quantization import quantize_
        from torchao.quantization.qat import FakeQuantizeConfig
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False,
        )
        weight_config = FakeQuantizeConfig(
            torch.int4, group_size=32, is_symmetric=True,
        )
        quantize_(
            model,
            intx_quantization_aware_training(activation_config, weight_config),
        )

    Note: If the returned function is applied on a module that is not
    `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
    `torch.nn.Embedding` with an activation config, then we will raise
    ValueError as these are not supported.
    """

    def _insert_fake_quantize(mod: torch.nn.Module):
        """
        Swap the given module with its corresponding fake quantized version.
        """
        from .embedding import FakeQuantizedEmbedding
        from .linear import FakeQuantizedLinear

        if isinstance(mod, torch.nn.Linear):
            return FakeQuantizedLinear.from_linear(
                mod,
                activation_config,
                weight_config,
            )
        elif isinstance(mod, torch.nn.Embedding):
            if activation_config is not None:
                raise ValueError(
                    "Activation fake quantization is not supported for embedding"
                )
            return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
        else:
            raise ValueError(
                "Module of type '%s' does not have QAT support" % type(mod)
            )

    return _insert_fake_quantize


class ComposableQATQuantizer(TwoStepQuantizer):
    """
    Composable quantizer that users can use to apply multiple QAT quantizers easily.
31 changes: 28 additions & 3 deletions torchao/quantization/qat/embedding.py
@@ -9,9 +9,6 @@
import torch
import torch.nn.functional as F

from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
)
from torchao.quantization.quant_primitives import TorchAODType
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric
@@ -85,6 +82,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            self.sparse,
        )

    @classmethod
    def from_embedding(
        cls,
        mod: torch.nn.Embedding,
        weight_config: Optional[FakeQuantizeConfig] = None,
    ):
        new_embedding = FakeQuantizedEmbedding(
            mod.num_embeddings,
            mod.embedding_dim,
            mod.padding_idx,
            mod.max_norm,
            mod.norm_type,
            mod.scale_grad_by_freq,
            mod.sparse,
            weight_config=weight_config,
            device=mod.weight.device,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device != torch.device("meta"):
            new_embedding.weight = mod.weight
        return new_embedding


# ======================================
# | Embedding int4 weight-only QAT |
@@ -115,6 +136,10 @@ def prepare(
        """
        Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
        """
        # avoid circular imports
        from torchao.quantization.quant_api import (
            _replace_with_custom_fn_if_matches_filter,
        )

        def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
            return isinstance(child, torch.nn.Embedding)
22 changes: 22 additions & 0 deletions torchao/quantization/qat/linear.py
@@ -105,6 +105,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            w = self.weight
        return F.linear(x, w)

    @classmethod
    def from_linear(
        cls,
        mod: torch.nn.Linear,
        activation_config: Optional[FakeQuantizeConfig] = None,
        weight_config: Optional[FakeQuantizeConfig] = None,
    ):
        new_linear = FakeQuantizedLinear(
            mod.in_features,
            mod.out_features,
            mod.bias,
            activation_config=activation_config,
            weight_config=weight_config,
            device=mod.weight.device,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device != torch.device("meta"):
            new_linear.weight = mod.weight
        return new_linear


class _LegacyQATQuantizer(TwoStepQuantizer):
    """
4 changes: 4 additions & 0 deletions torchao/quantization/quant_api.py
@@ -72,6 +72,9 @@
    LinearActivationQuantizedTensor,
    to_linear_activation_quantized,
)
from .qat import (
    intx_quantization_aware_training,
)
from .quant_primitives import (
    MappingType,
    ZeroPointDomain,
@@ -101,6 +104,7 @@
    "int8_dynamic_activation_int8_semi_sparse_weight",
    "int4_weight_only",
    "int8_weight_only",
    "intx_quantization_aware_training",
    "float8_weight_only",
    "uintx_weight_only",
    "fpx_weight_only",
