From 1a11857e44644ffcc03b1db61be7110653eac091 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Fri, 10 Jan 2025 11:53:37 -0800
Subject: [PATCH] Add convert path for quantize_ QAT API

Summary: https://github.com/pytorch/ao/pull/1415 added a quantize_ QAT API
for the prepare path. This commit adds the remaining convert path for users
to actually perform end-to-end QAT using the quantize_ API. The new flow
will look like:

```
from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    my_model,
    intx_quantization_aware_training(activation_config, weight_config),
)
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
```

Test Plan:
python test/quantization/test_qat.py -k test_quantize_api_convert_path

ghstack-source-id: e6ea0427d2a307baa138afd2a4058298a21710b0
Pull Request resolved: https://github.com/pytorch/ao/pull/1540
---
 test/quantization/test_qat.py         | 65 +++++++++++++++++++++++++++
 torchao/quantization/qat/__init__.py  |  2 +
 torchao/quantization/qat/api.py       | 40 ++++++++++++++++-
 torchao/quantization/qat/embedding.py | 18 ++++++++
 torchao/quantization/qat/linear.py    | 11 +++++
 5 files changed, 134 insertions(+), 2 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 42900c54f1..642f0bd4ad 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -25,6 +25,7 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
     intx_quantization_aware_training,
 )
 from torchao.quantization.qat.embedding import (
@@ -42,6 +43,9 @@
     _GenericFakeQuantize,
     _get_qmin_qmax,
 )
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int4_weight,
+)
 from torchao.quantization.quant_primitives import (
     MappingType,
     TorchAODType,
 )
@@ -1262,6 +1266,67 @@ def test_quantize_api_errors(self):
             lambda m, _: isinstance(m, torch.nn.ReLU),
         )
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
+    )
+    def test_quantize_api_convert_path(self):
+        """
+        Test that the following:
+
+            quantize_(model, intx_quantization_aware_training(...))
+            quantize_(model, from_intx_quantization_aware_training(...))
+            quantize_(model, int8_dynamic_activation_int4_weight())
+
+        can produce the same results as `Int8DynActInt4WeightQATQuantizer` prepare + convert.
+        """
+        from torchao.quantization.qat import (
+            Int8DynActInt4WeightQATQuantizer,
+        )
+
+        group_size = 16
+        torch.manual_seed(self.SEED)
+        m = M()
+        baseline_model = copy.deepcopy(m)
+
+        # Baseline prepare
+        baseline_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
+        baseline_model = baseline_quantizer.prepare(baseline_model)
+
+        # quantize_ prepare
+        activation_config = FakeQuantizeConfig(
+            torch.int8,
+            "per_token",
+            is_symmetric=False,
+        )
+        weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
+        quantize_(
+            m,
+            intx_quantization_aware_training(activation_config, weight_config),
+        )
+
+        # Compare prepared values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
+        # Baseline convert
+        baseline_model = baseline_quantizer.convert(baseline_model)
+
+        # quantize_ convert
+        quantize_(m, from_intx_quantization_aware_training())
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))
+
+        # Compare converted values
+        torch.manual_seed(self.SEED)
+        x = m.example_inputs()
+        x2 = copy.deepcopy(x)
+        out = m(*x)
+        baseline_out = baseline_model(*x2)
+        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 75ba6f22db..15008e03ea 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -1,6 +1,7 @@
 from .api import (
     ComposableQATQuantizer,
     FakeQuantizeConfig,
+    from_intx_quantization_aware_training,
     intx_quantization_aware_training,
 )
 from .embedding import (
@@ -18,4 +19,5 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
     "intx_quantization_aware_training",
+    "from_intx_quantization_aware_training",
 ]
diff --git a/torchao/quantization/qat/api.py b/torchao/quantization/qat/api.py
index 8f0244a858..cd3813291f 100644
--- a/torchao/quantization/qat/api.py
+++ b/torchao/quantization/qat/api.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
-from typing import Any, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 
@@ -242,7 +242,7 @@ def __setattr__(self, name: str, value: Any):
 def intx_quantization_aware_training(
     activation_config: Optional[FakeQuantizeConfig] = None,
     weight_config: Optional[FakeQuantizeConfig] = None,
-) -> torch.nn.Module:
+) -> Callable:
     """
     Return a function that applies fake quantization to a `torch.nn.Module`.
     to be used with :func:`~torchao.quantization.quant_api.quantize_`.
@@ -295,6 +295,42 @@ def _insert_fake_quantize(mod: torch.nn.Module):
     return _insert_fake_quantize
 
 
+def from_intx_quantization_aware_training() -> Callable:
+    """
+    Return a function that converts a model with fake quantized modules,
+    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
+    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
+    back to model with the original, corresponding modules without
+    fake quantization. This should be used with
+    :func:`~torchao.quantization.quant_api.quantize_`.
+
+    Example usage::
+
+        from torchao.quantization import quantize_
+        quantize_(
+            model_with_fake_quantized_linears,
+            from_intx_quantization_aware_training(),
+        )
+    """
+
+    def _remove_fake_quantize(mod: torch.nn.Module):
+        """
+        If the given module is a fake quantized module, return the original
+        corresponding version of the module without fake quantization.
+        """
+        from .embedding import FakeQuantizedEmbedding
+        from .linear import FakeQuantizedLinear
+
+        if isinstance(mod, FakeQuantizedLinear):
+            return mod.to_linear()
+        elif isinstance(mod, FakeQuantizedEmbedding):
+            return mod.to_embedding()
+        else:
+            return mod
+
+    return _remove_fake_quantize
+
+
 class ComposableQATQuantizer(TwoStepQuantizer):
     """
     Composable quantizer that users can use to apply multiple QAT quantizers easily.
diff --git a/torchao/quantization/qat/embedding.py b/torchao/quantization/qat/embedding.py
index ff580ac1d3..cc63c5181d 100644
--- a/torchao/quantization/qat/embedding.py
+++ b/torchao/quantization/qat/embedding.py
@@ -82,6 +82,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.sparse,
         )
 
+    def to_embedding(self) -> torch.nn.Embedding:
+        new_embedding = torch.nn.Embedding(
+            self.num_embeddings,
+            self.embedding_dim,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+            device=self.weight.device,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_embedding.weight = self.weight
+        return new_embedding
+
     @classmethod
     def from_embedding(
         cls,
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index 153e324838..fafda68d58 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -105,6 +105,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w = self.weight
         return F.linear(x, w)
 
+    def to_linear(self) -> torch.nn.Linear:
+        new_linear = torch.nn.Linear(
+            self.in_features, self.out_features, self.bias, device=self.weight.device
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if self.weight.device != torch.device("meta"):
+            new_linear.weight = self.weight
+        return new_linear
+
     @classmethod
     def from_linear(
         cls,
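
For reference, a minimal end-to-end sketch of the prepare -> train -> convert flow this patch enables. Only the torchao imports, the configs, and the three quantize_ calls come from the summary and the new test above; the toy model, synthetic data, and training loop are illustrative assumptions, and torch.int4 mirrors the summary (the new test uses TorchAODType.INT4 instead for older torch versions).

```
import torch
import torch.nn.functional as F

from torchao.quantization import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)

# Toy model and data (illustrative only)
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 10),
)
x = torch.randn(16, 64)
labels = torch.randint(0, 10, (16,))

# Prepare: swap in fake quantized linears
# (int8 asymmetric per-token activations, int4 symmetric per-group weights)
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# Train as usual; fake quantization runs in the forward pass
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), labels)
    loss.backward()
    optimizer.step()

# Convert: restore plain nn.Linear modules, then apply the real quantization
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))
```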