[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
drisspg committed Sep 4, 2024
1 parent 848e123 commit e7ffec3
Showing 5 changed files with 354 additions and 19 deletions.
2 changes: 2 additions & 0 deletions ruff.toml
@@ -8,4 +8,6 @@ include = [
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_observer_tensor.py",
"test/quantization/test_observer.py",
]
102 changes: 98 additions & 4 deletions test/quantization/test_observer.py
@@ -1,20 +1,31 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
AffineQuantizedMinMaxObserver,
PerTensor,
PerAxis,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.linear_observer_tensor import (
insert_observers,
)
import unittest

# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
def _test_obs_helper(self, obs1, obs2):
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
example_inputs = [
torch.randn(10, 2048),
torch.randn(10, 2048),
torch.randn(10, 2048),
]
for example_input in example_inputs:
obs1(example_input)
obs2(example_input)
@@ -25,15 +36,98 @@ def _test_obs_helper(self, obs1, obs2):
self.assertTrue(torch.allclose(zero_point1, zero_point2))

def test_min_max_per_tensor_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
self._test_obs_helper(obs, ref_obs)

def test_min_max_per_channel_affine(self):
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
obs = AffineQuantizedMinMaxObserver(
MappingType.ASYMMETRIC,
torch.uint8,
granularity_type=PerAxis(axis=0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
)
ref_obs = PerChannelMinMaxObserver(
dtype=torch.uint8, qscheme=torch.per_channel_affine
)
self._test_obs_helper(obs, ref_obs)


class TestLinearObserver(TestCase):
def test_linear_observer_tensor(self):
# Create a simple linear layer
in_features, out_features = 10, 5
linear = nn.Linear(in_features, out_features)

# Create observers
input_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
weight_observer = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)

# Wrap the weight with LinearObserverTensor
quantize_(linear, insert_observers(input_observer, weight_observer))

# Create some example inputs
example_inputs = [torch.randn(5, in_features) for _ in range(3)]
max_val = 42.1234
min_val = -39.760
big_tensor = torch.full((6, in_features), max_val)
small_tensor = torch.full((40, in_features), min_val)
example_inputs.extend([big_tensor, small_tensor])

# Run forward passes
for example_input in example_inputs:
_ = linear(example_input)

input_observer = linear.weight.input_observer
weight_observer = linear.weight.weight_observer

# Check that the observers have recorded statistics
assert input_observer.min_val == min_val
assert input_observer.max_val == max_val

# Calculate qparams and ensure they're not None
input_scale, input_zero_point = input_observer.calculate_qparams()
weight_scale, weight_zero_point = weight_observer.calculate_qparams()

max_fp8 = torch.finfo(torch.float8_e4m3fn).max
self.assertEqual(
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)
torch.testing.assert_close(
weight_scale,
torch.max(linear.weight.original_weight_tensor) / max_fp8,
atol=5e-5,
rtol=0.0,
)
self.assertIsNotNone(weight_zero_point)


if __name__ == "__main__":
unittest.main()
205 changes: 205 additions & 0 deletions torchao/quantization/linear_observer_tensor.py
@@ -0,0 +1,205 @@
import torch
import torch.nn as nn
from torchao.utils import (
_implements,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
)
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
TorchAOBaseTensor,
TORCH_VERSION_AT_LEAST_2_5,
)

from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
"LinearObserverTensor",
"to_linear_observer_tensor",
"insert_observers",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearObserverTensor(TorchAOBaseTensor):
"""
This subclass of Tensor is used in conjuction with a static calibration flow.
The flow is broken up into 3 parts;
1. Insert the LinearObserverTensor subclass into the model's nn.Linear layers
2. Run the model with a calibration dataset, the observer will record the min/max of the input and weight
3. quantize_ the model to static using the statistics recorded by the observer
This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observer
will first calculat statistics on BOTH the input and weight, and then run the linear op.
"""

original_weight_tensor: torch.Tensor
input_observer: Optional[AffineQuantizedObserverBase]
weight_observer: Optional[AffineQuantizedObserverBase]

def __new__(
cls,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
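# Build a wrapper subclass that mirrors the wrapped weight's dtype, device,
# and shape but owns no storage of its own; the real data lives in
# original_weight_tensor, attached in __init__.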
kwargs = {}
dtype = original_weight_tensor.dtype
kwargs["dtype"] = dtype
kwargs["requires_grad"] = False
kwargs["device"] = original_weight_tensor.device
shape = original_weight_tensor.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
original_weight_tensor: torch.Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
self.original_weight_tensor = original_weight_tensor
self.input_observer = input_observer
self.weight_observer = weight_observer

def __repr__(self):
return (
f"LinearObserverTensor(\n"
f"original_weight={self.original_weight_tensor}\n"
f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
)

def __tensor_flatten__(self):
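# The wrapped weight is the only tensor component; the observers travel as
# non-tensor attributes so they survive flatten/unflatten round trips.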
return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

@classmethod
def __tensor_unflatten__(
cls,
tensor_data_dict: Dict[str, Tensor],
tensor_attributes,
outer_size,
outer_stride,
):
original_weight_tensor = tensor_data_dict["original_weight_tensor"]
(input_observer, weight_observer) = tensor_attributes
return cls(original_weight_tensor, input_observer, weight_observer)

@classmethod
def from_float(
cls,
original_weight_tensor: Tensor,
input_observer: Optional[AffineQuantizedObserverBase] = None,
weight_observer: Optional[AffineQuantizedObserverBase] = None,
):
return cls(original_weight_tensor, input_observer, weight_observer)

def _apply_fn_to_data(self, fn: Callable):
"""Applies a fn to the tensor component of the LinearObserverTensor"""
return self.__class__(
fn(self.original_weight_tensor),
self.input_observer,
self.weight_observer,
)

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self._apply_fn_to_data(lambda x: x.to(**kwargs))

implements = classmethod(_implements)
__torch_function__ = classmethod(_dispatch__torch_function__)
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


implements = LinearObserverTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
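# Record min/max statistics on the activation and on the wrapped weight
# before running the actual linear op (see the class docstring above).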
if weight_tensor.input_observer is not None:
input_tensor = weight_tensor.input_observer(input_tensor)
if weight_tensor.weight_observer is not None:
weight_tensor = weight_tensor.weight_observer(
weight_tensor.original_weight_tensor
)
else:
# Unwrap to the plain weight so the call below does not re-dispatch
# back into this subclass when no weight observer is attached.
weight_tensor = weight_tensor.original_weight_tensor

return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)


to_linear_observer_tensor = LinearObserverTensor.from_float

if TORCH_VERSION_AT_LEAST_2_5:
# Allow a model with LinearObserverTensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([LinearObserverTensor])


def insert_observers(
input_observer: Optional[AffineQuantizedObserverBase],
weight_observer: Optional[AffineQuantizedObserverBase],
) -> Callable:
"""
Converts the weight of a linear module to a LinearObserverTensor.
This function wraps the weight of the given linear module with a LinearObserverTensor,
which enables observation of both input and weight tensors during forward passes.
The wrapped weight is then re-wrapped as a nn.Parameter to maintain compatibility
with PyTorch's module system.
Usaage:
```
linear_module = nn.Linear(10, 20)
quantize_(linear_module, convert_linear_weight_to_observer))
Args:
linear_module (nn.Linear): The linear module to be converted.
input_observer (Optional[AffineQuantizedObserverBase]): Observer for input tensor.
weight_observer (Optional[AffineQuantizedObserverBase]): Observer for weight tensor.
Returns:
nn.Linear: The modified linear module with its weight wrapped in a LinearObserverTensor.
"""

def convert_to_linear_observer(linear_module: nn.Linear):
# Wrap the weight with LinearObserverTensor and then with nn.Parameter
linear_module.weight = nn.Parameter(
to_linear_observer_tensor(
linear_module.weight,
input_observer=input_observer,
weight_observer=weight_observer,
),
requires_grad=linear_module.weight.requires_grad,
)
return linear_module

return convert_to_linear_observer
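
For reference, here is a minimal end-to-end calibration sketch assembled from the test and docstrings in this commit. The observer arguments mirror `test_linear_observer_tensor`, and the final conversion of the recorded statistics into a statically quantized layer is not part of this change:

```
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.quantization.linear_observer_tensor import insert_observers
from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType


def make_observer():
    # Same observer configuration used in test_linear_observer_tensor
    return AffineQuantizedMinMaxObserver(
        MappingType.SYMMETRIC,
        torch.float8_e4m3fn,
        granularity_type=PerTensor(),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float,
        zero_point_dtype=torch.int,
        zero_point_domain=None,
    )


# Step 1: wrap the linear weight with observers
linear = nn.Linear(10, 5)
quantize_(linear, insert_observers(make_observer(), make_observer()))

# Step 2: run calibration data; observers record min/max as a side effect
for _ in range(3):
    linear(torch.randn(8, 10))

# Step 3: read back the recorded statistics / quantization parameters
input_scale, input_zero_point = linear.weight.input_observer.calculate_qparams()
weight_scale, weight_zero_point = linear.weight.weight_observer.calculate_qparams()
```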