[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
drisspg committed Sep 5, 2024
1 parent 591472c commit dac3ae9
Showing 5 changed files with 382 additions and 36 deletions.
2 changes: 2 additions & 0 deletions ruff.toml
@@ -8,4 +8,6 @@ include = [
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_observer_tensor.py",
"test/quantization/test_observer.py",
]
112 changes: 108 additions & 4 deletions test/quantization/test_observer.py
@@ -1,5 +1,6 @@
import re
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
@@ -9,13 +10,23 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.linear_observer_tensor import (
insert_observers_,
)
from torch.testing._internal import common_utils
import unittest

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
def _test_obs_helper(self, obs1, obs2):
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
example_inputs = [
torch.randn(10, 2048),
torch.randn(10, 2048),
torch.randn(10, 2048),
]
for example_input in example_inputs:
obs1(example_input)
obs2(example_input)
@@ -26,13 +37,29 @@ def _test_obs_helper(self, obs1, obs2):
self.assertTrue(torch.allclose(zero_point1, zero_point2))

def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
self._test_obs_helper(obs, ref_obs)

def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = PerChannelMinMaxObserver(
dtype=torch.uint8, qscheme=torch.per_channel_affine
)
self._test_obs_helper(obs, ref_obs)

def test_block_size_calc_success(self):
@@ -109,5 +136,82 @@ def test_block_size_row_errors(self):
obs(example_input)


class TestLinearObserver(TestCase):
@common_utils.parametrize("observe_weight", [True, False])
def test_linear_observer_tensor(self, observe_weight: bool):
# Create a simple linear layer
in_features, out_features = 10, 5
linear = nn.Linear(in_features, out_features)

# Create observers
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
if observe_weight:
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
else:
weight_observer = None

# Wrap the weight with LinearActivationWeightObservedTensor
insert_observers_(linear, input_observer, weight_observer)

# Create some example inputs
example_inputs = [torch.randn(5, in_features) for _ in range(3)]
max_val = 42.1234
min_val = -39.760
big_tensor = torch.full((6, in_features), max_val)
small_tensor = torch.full((40, in_features), min_val)
example_inputs.extend([big_tensor, small_tensor])

# Run forward passes
for example_input in example_inputs:
_ = linear(example_input)

input_observer = linear.weight.input_observer

# Check that the observers have recorded statistics
assert input_observer.min_val == min_val
assert input_observer.max_val == max_val

# Calculate qparams and ensure they're not None
input_scale, input_zero_point = input_observer.calculate_qparams()

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
self.assertEqual(
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)

if observe_weight:
weight_observer = linear.weight.weight_observer
weight_scale, weight_zero_point = weight_observer.calculate_qparams()
torch.testing.assert_close(
weight_scale,
torch.max(linear.weight.original_weight_tensor) / max_fp8,
atol=5e-5,
rtol=0.0,
)
self.assertIsNotNone(weight_zero_point)
else:
self.assertIsNone(linear.weight.weight_observer)


common_utils.instantiate_parametrized_tests(TestLinearObserver)

if __name__ == "__main__":
unittest.main()
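
Note: the expected-scale checks in the new test follow from symmetric per-tensor float8 quantization, where the scale is the observed absolute maximum divided by the largest value representable by the target dtype (448.0 for torch.float8_e4m3fn). A minimal sketch of that arithmetic, reusing the constants from the test above:

```
import torch

max_val = 42.1234                               # largest calibration value fed to the input observer
max_fp8 = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for float8_e4m3fn
expected_scale = max_val / max_fp8              # ~0.0940, the value calculate_qparams() should return
```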
236 changes: 236 additions & 0 deletions torchao/quantization/linear_observer_tensor.py
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
_is_linear,
)
from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
"LinearActivationWeightObservedTensor",
"insert_observers_",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearActivationWeightObservedTensor(TorchAOBaseTensor):
"""
This subclass of Tensor is used in conjunction with a static calibration flow.
The flow is broken up into 3 parts:
1. Insert the LinearActivationWeightObservedTensor subclass into the model's nn.Linear layers
2. Run the model with a calibration dataset; the observers will record the min/max of the input and weight
3. quantize_ the model to a static quantized model using the statistics recorded by the observers
This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observers
will first calculate statistics on BOTH the input and weight, and then run the linear op.
"""

original_weight_tensor: torch.Tensor
input_observer: Optional[AffineQuantizedObserverBase]
weight_observer: Optional[AffineQuantizedObserverBase]

def __new__(
cls,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
kwargs = {}
dtype = original_weight_tensor.dtype
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["device"] = original_weight_tensor.device
shape = original_weight_tensor.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
self.original_weight_tensor = original_weight_tensor
self.input_observer = input_observer
self.weight_observer = weight_observer

def __repr__(self):
return (
f"LinearActivationWeightObservedTensor(\n"
f"original_weight={self.original_weight_tensor}\n"
f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
)

def __tensor_flatten__(self):
return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, Tensor],
tensor_attributes,
outer_size,
outer_stride,
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
(input_observer, weight_observer) = tensor_attributes
return cls(original_weight_tensor, input_observer, weight_observer)

@classmethod
def from_float(
cls,
original_weight_tensor: Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
return cls(original_weight_tensor, input_observer, weight_observer)

def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to the tensor component of the LinearActivationWeightObservedTensor"""
return self.__class__(
fn(self.original_weight_tensor),
self.input_observer,
self.weight_observer,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self._apply_fn_to_data(lambda x: x.to(**kwargs))


implements = LinearActivationWeightObservedTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if weight_tensor.input_observer is not None:
input_tensor = weight_tensor.input_observer(input_tensor)
if weight_tensor.weight_observer is not None:
weight_tensor = weight_tensor.weight_observer(
weight_tensor.original_weight_tensor
)
else:
weight_tensor = weight_tensor.original_weight_tensor

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearActivationWeightObservedTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearActivationWeightObservedTensor])


def insert_observers_(
model: nn.Module,
input_observer: Optional[AffineQuantizedObserverBase],
weight_observer: Optional[AffineQuantizedObserverBase],
*,
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
):
"""
Converts the weight of a linear module to a LinearActivationWeightObservedTensor.
This function wraps the weight of the given linear module with a LinearActivationWeightObservedTensor,
which enables observation of both input and weight tensors during forward passes.
The wrapped weight is then re-wrapped as a nn.Parameter to maintain compatibility
with PyTorch's module system.
Example::
```
import torch
import torch.nn as nn
from torchao.quantization.linear_observer_tensor import insert_observers_
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
PerTensor,
MappingType
)
# Create observers
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
# Create a linear module
linear_module = nn.Linear(10, 20)
# Convert the linear module's weight to an observed tensor
insert_observers_(linear_module, input_observer, weight_observer=None)
# The linear_module can now be used as usual, with observers calculating statistics
output = linear_module(torch.randn(10, 10))
```
Args:
model (nn.Module): The nn.Module to convert.
input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor.
weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor.
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): Filter function to select which modules to convert.
If not provided, all linear modules will be converted.
Returns:
nn.Linear: The modified linear module with its weight wrapped in a LinearActivationWeightObservedTensor.
"""

def convert_to_linear_observer(linear_module: nn.Linear):
# Wrap the weight with LinearActivationWeightObservedTensor and then with nn.Parameter
linear_module.weight = nn.Parameter(
LinearActivationWeightObservedTensor.from_float(
linear_module.weight,
input_observer=input_observer,
weight_observer=weight_observer,
),
requires_grad=linear_module.weight.requires_grad,
)
return linear_module

_replace_with_custom_fn_if_matches_filter(
model,
convert_to_linear_observer,
_is_linear if filter_fn is None else filter_fn,
)
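
For context, the calibration flow described in the LinearActivationWeightObservedTensor docstring can be sketched end to end. This is a minimal sketch assuming a single nn.Linear with arbitrary sizes; steps 1 and 2 are shown concretely, while step 3 (the static quantize_ call that consumes the recorded statistics) is not shown in this diff:

```
import torch
import torch.nn as nn
from torchao.quantization.linear_observer_tensor import insert_observers_
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
)
from torchao.quantization.quant_primitives import MappingType

# Step 1: wrap the linear module's weight with LinearActivationWeightObservedTensor
linear = nn.Linear(128, 64)
input_observer = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)
insert_observers_(linear, input_observer, weight_observer=None)

# Step 2: run calibration data; each forward pass records min/max on the input observer
for _ in range(8):
    linear(torch.randn(32, 128))

# The recorded statistics are what a later static quantize_ step would consume
scale, zero_point = linear.weight.input_observer.calculate_qparams()
```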