From dac3ae9849b98cfbc30fa8594d90b5b2daf1c4bf Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 4 Sep 2024 12:11:28 -0700
Subject: [PATCH] [StaticQuant] add a linear observer class and test

stack-info: PR: https://github.com/pytorch/ao/pull/807, branch: drisspg/stack/8
---
 ruff.toml                                      |   2 +
 test/quantization/test_observer.py             | 112 ++++++++-
 .../quantization/linear_observer_tensor.py     | 236 ++++++++++++++++++
 torchao/quantization/observer.py               |  62 ++---
 torchao/quantization/quant_primitives.py       |   6 +-
 5 files changed, 382 insertions(+), 36 deletions(-)
 create mode 100644 torchao/quantization/linear_observer_tensor.py

diff --git a/ruff.toml b/ruff.toml
index dee9710df4..3decb1b809 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -8,4 +8,6 @@ include = [
     "torchao/dtypes/nf4tensor.py",
     "test/dtypes/test_nf4.py",
     "torchao/float8/float8_tensor.py",
+    "torchao/quantization/linear_observer_tensor.py",
+    "test/quantization/test_observer.py",
 ]

diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py
index e0c9257a96..dbed26db11 100644
--- a/test/quantization/test_observer.py
+++ b/test/quantization/test_observer.py
@@ -1,5 +1,6 @@
 import re
 import torch
+import torch.nn as nn
 from torch.testing._internal.common_utils import TestCase
 from torchao.quantization.observer import (
     AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
 from torchao.quantization.quant_primitives import (
     MappingType,
 )
+from torchao.quantization.linear_observer_tensor import (
+    insert_observers_,
+)
+from torch.testing._internal import common_utils
 import unittest
+
 # NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
 from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
+
 class TestQuantFlow(TestCase):
     def _test_obs_helper(self, obs1, obs2):
-        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+            torch.randn(10, 2048),
+        ]
         for example_input in example_inputs:
             obs1(example_input)
             obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
         self.assertTrue(torch.allclose(zero_point1, zero_point2))

     def test_min_max_per_tensor_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
         ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
         self._test_obs_helper(obs, ref_obs)

     def test_min_max_per_channel_affine(self):
-        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
-        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.ASYMMETRIC,
+            torch.uint8,
+            granularity_type=PerAxis(axis=0),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+        )
+        ref_obs = PerChannelMinMaxObserver(
+            dtype=torch.uint8, qscheme=torch.per_channel_affine
+        )
         self._test_obs_helper(obs, ref_obs)

     def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
             obs(example_input)

+
+class TestLinearObserver(TestCase):
+    @common_utils.parametrize("observe_weight", [True, False])
+    def test_linear_observer_tensor(self, observe_weight: bool):
+        # Create a simple linear layer
+        in_features, out_features = 10, 5
+        linear = nn.Linear(in_features, out_features)
+
+        # Create observers
+        input_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        if observe_weight:
+            weight_observer = AffineQuantizedMinMaxObserver(
+                MappingType.SYMMETRIC,
+                torch.float8_e4m3fn,
+                granularity_type=PerTensor(),
+                eps=torch.finfo(torch.float32).eps,
+                scale_dtype=torch.float,
+                zero_point_dtype=torch.int,
+                zero_point_domain=None,
+            )
+        else:
+            weight_observer = None
+
+        # Wrap the weight with a LinearActivationWeightObservedTensor
+        insert_observers_(linear, input_observer, weight_observer)
+
+        # Create some example inputs
+        example_inputs = [torch.randn(5, in_features) for _ in range(3)]
+        max_val = 42.1234
+        min_val = -39.760
+        big_tensor = torch.full((6, in_features), max_val)
+        small_tensor = torch.full((40, in_features), min_val)
+        example_inputs.extend([big_tensor, small_tensor])
+
+        # Run forward passes
+        for example_input in example_inputs:
+            _ = linear(example_input)
+
+        input_observer = linear.weight.input_observer
+
+        # Check that the observers have recorded statistics
+        assert input_observer.min_val == min_val
+        assert input_observer.max_val == max_val
+
+        # Calculate qparams and ensure they're not None
+        input_scale, input_zero_point = input_observer.calculate_qparams()
+
+        max_fp8 = torch.finfo(torch.float8_e4m3fn).max
+        self.assertEqual(
+            input_scale.item(),
+            max_val / max_fp8,
+        )
+        self.assertIsNotNone(input_zero_point)
+
+        if observe_weight:
+            weight_observer = linear.weight.weight_observer
+            weight_scale, weight_zero_point = weight_observer.calculate_qparams()
+            torch.testing.assert_close(
+                weight_scale,
+                torch.max(linear.weight.original_weight_tensor) / max_fp8,
+                atol=5e-5,
+                rtol=0.0,
+            )
+            self.assertIsNotNone(weight_zero_point)
+        else:
+            self.assertIsNone(linear.weight.weight_observer)
+
+
+common_utils.instantiate_parametrized_tests(TestLinearObserver)
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/linear_observer_tensor.py b/torchao/quantization/linear_observer_tensor.py
new file mode 100644
index 0000000000..3474472306
--- /dev/null
+++ b/torchao/quantization/linear_observer_tensor.py
@@ -0,0 +1,236 @@
+import torch
+import torch.nn as nn
+from typing import Callable, Optional, Dict
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.utils import (
+    TorchAOBaseTensor,
+    TORCH_VERSION_AT_LEAST_2_5,
+)
+
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    _is_linear,
+)
+from torchao.quantization.observer import AffineQuantizedObserverBase
+
+__all__ = [
+    "LinearActivationWeightObservedTensor",
+    "insert_observers_",
+]
+
+aten = torch.ops.aten
+Tensor = torch.Tensor
+
+
+class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
+    """
+    This subclass of Tensor is used in conjunction with a static calibration flow.
+    The flow is broken up into 3 parts:
+    1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
+    2. Run the model with a calibration dataset; the observer will record the min/max of the input and weight
+    3. quantize_ the model to a statically quantized model using the statistics recorded by the observer
+
+    This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
+    will first calculate statistics on BOTH the input and weight, and then run the linear op.
+    """
+
+    original_weight_tensor: torch.Tensor
+    input_observer: Optional[AffineQuantizedObserverBase]
+    weight_observer: Optional[AffineQuantizedObserverBase]
+
+    def __new__(
+        cls,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        kwargs = {}
+        dtype = original_weight_tensor.dtype
+        kwargs["dtype"] = dtype
+        kwargs["requires_grad"] = False
+        kwargs["device"] = original_weight_tensor.device
+        shape = original_weight_tensor.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        original_weight_tensor: torch.Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        self.original_weight_tensor = original_weight_tensor
+        self.input_observer = input_observer
+        self.weight_observer = weight_observer
+
+    def __repr__(self):
+        return (
+            f"LinearActivationWeightObservedTensor(\n"
+            f"original_weight={self.original_weight_tensor}\n"
+            f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
+            f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
+        )
+
+    def __tensor_flatten__(self):
+        return ["original_weight_tensor"], [self.input_observer, self.weight_observer]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls,
+        tensor_data_dict: Dict[str, Tensor],
+        tensor_attributes,
+        outer_size,
+        outer_stride,
+    ):
+        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
+        (input_observer, weight_observer) = tensor_attributes
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    @classmethod
+    def from_float(
+        cls,
+        original_weight_tensor: Tensor,
+        input_observer: Optional[AffineQuantizedObserverBase] = None,
+        weight_observer: Optional[AffineQuantizedObserverBase] = None,
+    ):
+        return cls(original_weight_tensor, input_observer, weight_observer)
+
+    def _apply_fn_to_data(self, fn: Callable):
+        """Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
+        return self.__class__(
+            fn(self.original_weight_tensor),
+            self.input_observer,
+            self.weight_observer,
+        )
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self._apply_fn_to_data(lambda x: x.to(**kwargs))
+
+
+implements = LinearActivationWeightObservedTensor.implements
+
+
+@implements(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    if weight_tensor.input_observer is not None:
+        input_tensor = weight_tensor.input_observer(input_tensor)
+    if weight_tensor.weight_observer is not None:
+        weight_tensor = weight_tensor.weight_observer(
+            weight_tensor.original_weight_tensor
+        )
+    else:
+        weight_tensor = weight_tensor.original_weight_tensor
+
+    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+if TORCH_VERSION_AT_LEAST_2_5:
+    # Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True`
+    torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])
+
+
+def insert_observers_(
+    model: nn.Module,
+    input_observer: Optional[AffineQuantizedObserverBase],
+    weight_observer: Optional[AffineQuantizedObserverBase],
+    *,
+    filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
+):
+    """
+    Converts the weight of each matching linear module in the model to a LinearActivationWeightObservedTensor.
+
+    This function wraps the weight of each selected linear module with a LinearActivationWeightObservedTensor,
+    which enables observation of both input and weight tensors during forward passes.
+    The wrapped weight is then re-wrapped as an nn.Parameter to maintain compatibility
+    with PyTorch's module system.
+
+    Example::
+
+        import torch
+        import torch.nn as nn
+        from torchao.quantization.linear_observer_tensor import insert_observers_
+        from torchao.quantization.observer import (
+            AffineQuantizedMinMaxObserver,
+            PerTensor,
+            MappingType,
+        )
+
+        # Create observers
+        input_observer = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+
+        # Create a linear module
+        linear_module = nn.Linear(10, 20)
+
+        # Convert the linear module's weight to an observed tensor
+        insert_observers_(linear_module, input_observer, weight_observer=None)
+
+        # The linear_module can now be used as usual, with observers calculating statistics
+        output = linear_module(torch.randn(10, 10))
+
+    Args:
+        model (nn.Module): The nn.Module to convert.
+        input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor.
+        weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor.
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert.
+            If not provided, all linear modules will be converted.
+
+    Returns:
+        None. The model is modified in place; each matching linear module's weight ends up
+        wrapped in a LinearActivationWeightObservedTensor.
+    """
+
+    def convert_to_linear_observer(linear_module: nn.Linear):
+        # Wrap the weight with a LinearActivationWeightObservedTensor and then with nn.Parameter
+        linear_module.weight = nn.Parameter(
+            LinearActivationWeightObservedTensor.from_float(
+                linear_module.weight,
+                input_observer=input_observer,
+                weight_observer=weight_observer,
+            ),
+            requires_grad=linear_module.weight.requires_grad,
+        )
+        return linear_module
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        convert_to_linear_observer,
+        _is_linear if filter_fn is None else filter_fn,
+    )
diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py
index 984f2a765e..33a6fc477e 100644
--- a/torchao/quantization/observer.py
+++ b/torchao/quantization/observer.py
@@ -8,9 +8,10 @@
 from abc import ABCMeta, abstractmethod
 from dataclasses import dataclass
-from typing import Callable, List, Tuple, Optional, Any
+from typing import Tuple, Optional, Any
 from functools import partial
 import logging
+
 logger = logging.getLogger(__name__)

@@ -19,27 +20,21 @@ class GranularityType:
     """
     Base class for representing the granularity of quantization.
-<<<<<<< Updated upstream
-=======
     This class serves as a parent for specific granularity types used in
     quantization operations, such as per-tensor or per-axis quantization.
     """
     pass
->>>>>>> Stashed changes

 @dataclass(frozen=True)
 class PerTensor(GranularityType):
     """
     Represents per-tensor granularity in quantization.
-<<<<<<< Updated upstream
-=======
     This granularity type calculates the quantization parameters based on the entire tensor.
     """
     pass
->>>>>>> Stashed changes

 @dataclass(frozen=True)
 class PerAxis(GranularityType):
     """
@@ -53,6 +48,7 @@ class PerAxis(GranularityType):
     """
     axis: int

+
 # borrowed from torch.ao.quantization.observer
 class _PartialWrapper:
     def __init__(self, p):
@@ -67,6 +63,7 @@ def __repr__(self):
         return self.p.__repr__()

     def with_args(self, *args, **kwargs):
         return _with_args(self, *args, **kwargs)

+
 def _with_args(cls_or_self, *args, **kwargs):
     r"""Wrapper that allows creation of class factories.

@@ -86,20 +83,7 @@ def _with_args(cls_or_self, *args, **kwargs):
     r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
     return r

-<<<<<<< Updated upstream
 def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
-=======
-
-def get_block_size(
-    input_shape: Tuple[int, ...], granularity_type: GranularityType
-) -> Tuple[int, ...]:
-    """Get the block size based on the input shape and granularity type.
-
-    Args:
-        input_shape: The input tensor shape possibly more than 2 dimensions
-        granularity_type: The granularity type of the quantization
-    """
->>>>>>> Stashed changes
     if isinstance(granularity_type, PerTensor):
         return input_shape
     elif isinstance(granularity_type, PerAxis):
         block_size = list(input_shape)
         block_size[granularity_type.axis] = 1
         return tuple(block_size)
     raise ValueError(f"Unsupported GranularityType: {granularity_type}")
@@ -108,8 +92,10 @@
+
 ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3:
+
 class AffineQuantizedObserverBase(ABC, torch.nn.Module):
     """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)
@@ -119,9 +105,11 @@ class AffineQuantizedObserverBase(ABC, torch.nn.Module):
         Currently supported granularity types are `PerTensor` and `PerAxis`
       other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
     """
+
     with_args = classmethod(_with_args)

-    def __init__(self,
+    def __init__(
+        self,
         mapping_type: MappingType,
         target_dtype: torch.dtype,
         block_size: Optional[Tuple[int, ...]] = None,
@@ -132,12 +120,16 @@ def __init__(self,
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain = ZeroPointDomain.INT,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
     ):
         super().__init__()
-        assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
+        assert (
+            block_size is not None or granularity_type is not None
+        ), "Must specify either block_size or granularity_type"
         if block_size is not None and granularity_type is not None:
-            logger.warning("Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}")
+            logger.warning(
+                f"Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}"
+            )
         self.mapping_type = mapping_type
         self.target_dtype = target_dtype
         self.block_size = block_size
@@ -152,7 +144,7 @@ def __init__(self,

     @abstractmethod
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        """ forward function should take the input tensor
+        """forward function should take the input tensor,
         update internal stats, and return the original input Tensor
         """
         pass
@@ -164,6 +156,7 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         pass

+
 class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
     def forward(self, input: torch.Tensor):
         if input.numel() == 0:
@@ -178,7 +171,9 @@ def forward(self, input: torch.Tensor):
         if self.block_size is None and not isinstance(self.granularity_type, PerTensor):
             self.block_size = block_size

-        shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size())
+        shape_for_reduction, reduction_dims = _get_reduction_params(
+            block_size, input_detached.size()
+        )
         input_detached = input_detached.view(shape_for_reduction)
         min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
         max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
@@ -196,12 +191,21 @@ def forward(self, input: torch.Tensor):
         return input

     def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
-        assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
+        assert (
+            hasattr(self, "min_val") and hasattr(self, "max_val")
+        ), "Expecting the observer to have min_val and max_val; please run the observer before calling calculate_qparams"
+        if self.block_size is None:
+            assert isinstance(
+                self.granularity_type, PerTensor
+            ), "block_size is None, but granularity_type is not PerTensor"
+            block_size = []
+        else:
+            block_size = self.block_size
         return choose_qparams_affine_with_min_max(
             self.min_val,
             self.max_val,
             self.mapping_type,
-            self.block_size,
+            block_size,
             self.target_dtype,
             self.quant_min,
             self.quant_max,
@@ -209,5 +213,5 @@ def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
             self.scale_dtype,
             self.zero_point_dtype,
             self.preserve_zero,
-            self.zero_point_domain
+            self.zero_point_domain,
         )
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a5fddcb98c..9b8d6378be 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -580,7 +580,7 @@ def choose_qparams_affine(
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
     preserve_zero: bool = True,
-    zero_point_domain = ZeroPointDomain.INT,
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -641,7 +641,7 @@ def choose_qparams_affine_with_min_max(
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
     preserve_zero: bool = True,
-    zero_point_domain = ZeroPointDomain.INT,
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
     operator that passes in min_val and max_val directly instead of deriving them from a single input.
@@ -664,7 +664,7 @@ def choose_qparams_affine_with_min_max(
         scale_dtype,
         zero_point_dtype,
         preserve_zero,
-        zero_point_domain.name,
+        zero_point_domain.name if zero_point_domain is not None else None,
         min_val,
         max_val,
     )
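
Usage sketch for reviewers: the three-step calibration flow described in the LinearActivationWeightObservedTensor docstring, run end to end using only APIs introduced in this patch. The layer sizes, batch size, and number of calibration iterations below are illustrative placeholders, and the final `quantize_` conversion from step 3 is elided since that step lands separately.

```python
import torch
import torch.nn as nn

from torchao.quantization.linear_observer_tensor import insert_observers_
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType

# Step 1: wrap the linear weight so that forward passes are observed
linear = nn.Linear(64, 128)
input_observer = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)
insert_observers_(linear, input_observer, weight_observer=None)

# Step 2: calibrate -- ordinary forward passes update the observer's min/max
with torch.no_grad():
    for _ in range(8):
        linear(torch.randn(16, 64))

# Step 3 (first half): read back the qparams recorded during calibration;
# a follow-up quantize_ call would consume these to build the static model
scale, zero_point = linear.weight.input_observer.calculate_qparams()
```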
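The get_block_size contract that AffineQuantizedMinMaxObserver relies on, illustrated with concrete shapes. The values follow directly from the observer.py hunk above: PerTensor keeps the whole shape as a single block, while PerAxis(axis=k) collapses dimension k to 1, yielding one scale/zero_point pair per slice along that axis.

```python
from torchao.quantization.observer import PerAxis, PerTensor, get_block_size

shape = (32, 4096)

# PerTensor: one block spanning the entire tensor -> a single qparam pair
assert get_block_size(shape, PerTensor()) == (32, 4096)

# PerAxis(axis=0): size-1 blocks along axis 0 -> one qparam pair per row
assert get_block_size(shape, PerAxis(axis=0)) == (1, 4096)
```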
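observer.py also carries over `with_args = classmethod(_with_args)` from torch.ao, which nothing in this patch exercises directly. A short sketch of the intended use, assuming only the constructor arguments already shown in the tests above: a pre-configured factory that stamps out fresh observer instances with independent min/max state, e.g. one per linear layer instead of a single shared observer.

```python
import torch

from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType

# Factory with the constructor arguments baked in
observer_factory = AffineQuantizedMinMaxObserver.with_args(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
)

# Each call produces a fresh observer with its own running statistics
obs_a, obs_b = observer_factory(), observer_factory()
obs_a(torch.randn(4, 16))
assert hasattr(obs_a, "min_val")      # obs_a has recorded stats
assert not hasattr(obs_b, "min_val")  # obs_b has seen no data yet
```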