[StaticQuant] add a linear observer class and test
stack-info: PR: #807, branch: drisspg/stack/8
import torch
import torch.nn as nn
from typing import Callable, Optional, Dict
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import (
    TorchAOBaseTensor,
    TORCH_VERSION_AT_LEAST_2_5,
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)
from torchao.quantization.observer import AffineQuantizedObserverBase

__all__ = [
    "LinearObserverTensor",
    "to_linear_observer_tensor",
    "insert_observers",
]

aten = torch.ops.aten
Tensor = torch.Tensor


class LinearObserverTensor(TorchAOBaseTensor):
    """
    This subclass of Tensor is used in conjunction with a static calibration flow.
    The flow is broken up into three parts:
    1. Insert the LinearObserverTensor subclass into the model's nn.Linear layers
    2. Run the model with a calibration dataset; the observers record the min/max of the input and weight
    3. quantize_ the model to a static quantization scheme using the statistics recorded by the observers
    This subclass wraps the original weight tensor on the nn.Linear layer. When forward is called, the observers
    first calculate statistics on BOTH the input and weight, and then the linear op runs.
    (An end-to-end sketch of this flow is shown after the module code below.)
    """

    original_weight_tensor: torch.Tensor
    input_observer: Optional[AffineQuantizedObserverBase]
    weight_observer: Optional[AffineQuantizedObserverBase]

    def __new__(
        cls,
        original_weight_tensor: torch.Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        kwargs = {}
        kwargs["dtype"] = original_weight_tensor.dtype
        kwargs["requires_grad"] = False
        kwargs["device"] = original_weight_tensor.device
        shape = original_weight_tensor.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        original_weight_tensor: torch.Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        self.original_weight_tensor = original_weight_tensor
        self.input_observer = input_observer
        self.weight_observer = weight_observer

    def __repr__(self):
        return (
            f"LinearObserverTensor(\n"
            f"original_weight={self.original_weight_tensor}\n"
            f"input_observer={self.input_observer.__class__.__name__ if self.input_observer else None}\n"
            f"weight_observer={self.weight_observer.__class__.__name__ if self.weight_observer else None}\n)"
        )

    def __tensor_flatten__(self):
        return ["original_weight_tensor"], [self.input_observer, self.weight_observer]

    @classmethod
    def __tensor_unflatten__(
        cls,
        tensor_data_dict: Dict[str, Tensor],
        tensor_attributes,
        outer_size,
        outer_stride,
    ):
        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
        (input_observer, weight_observer) = tensor_attributes
        return cls(original_weight_tensor, input_observer, weight_observer)

    @classmethod
    def from_float(
        cls,
        original_weight_tensor: Tensor,
        input_observer: Optional[AffineQuantizedObserverBase] = None,
        weight_observer: Optional[AffineQuantizedObserverBase] = None,
    ):
        return cls(original_weight_tensor, input_observer, weight_observer)

    def _apply_fn_to_data(self, fn: Callable):
        """Applies a fn to the tensor component of the LinearObserverTensor"""
        return self.__class__(
            fn(self.original_weight_tensor),
            self.input_observer,
            self.weight_observer,
        )

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self._apply_fn_to_data(lambda x: x.to(**kwargs))

    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


implements = LinearObserverTensor.implements


@implements(torch.nn.functional.linear)
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if weight_tensor.input_observer is not None:
        input_tensor = weight_tensor.input_observer(input_tensor)
    if weight_tensor.weight_observer is not None:
        weight_tensor = weight_tensor.weight_observer(
            weight_tensor.original_weight_tensor
        )
    else:
        # Unwrap to the plain weight so the F.linear call below does not
        # re-dispatch back into this handler and recurse.
        weight_tensor = weight_tensor.original_weight_tensor

    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements(aten.detach.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )


@implements(aten.clone.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
    )


@implements(aten._to_copy.default)
def _(func, types, args, kwargs):
    return return_and_correct_aliasing(
        func,
        args,
        kwargs,
        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
    )


to_linear_observer_tensor = LinearObserverTensor.from_float

if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with LinearObserverTensor weights to be loaded with `weights_only=True`
    torch.serialization.add_safe_globals([LinearObserverTensor])


def insert_observers(
    input_observer: Optional[AffineQuantizedObserverBase],
    weight_observer: Optional[AffineQuantizedObserverBase],
) -> Callable:
    """
    Returns a function that converts the weight of a linear module to a LinearObserverTensor.
    The returned function wraps the weight of a given linear module with a
    LinearObserverTensor, which enables observation of both the input and weight
    tensors during forward passes. The wrapped weight is re-wrapped as an
    nn.Parameter to maintain compatibility with PyTorch's module system.
    Usage:
    ```
    linear_module = nn.Linear(10, 20)
    quantize_(linear_module, insert_observers(input_observer, weight_observer))
    ```
    Args:
        input_observer (Optional[AffineQuantizedObserverBase]): Observer for the input tensor.
        weight_observer (Optional[AffineQuantizedObserverBase]): Observer for the weight tensor.
    Returns:
        Callable[[nn.Linear], nn.Linear]: A function that takes a linear module and returns
        it with its weight wrapped in a LinearObserverTensor.
    """

    def convert_to_linear_observer(linear_module: nn.Linear):
        # Wrap the weight with LinearObserverTensor and then with nn.Parameter
        linear_module.weight = nn.Parameter(
            to_linear_observer_tensor(
                linear_module.weight,
                input_observer=input_observer,
                weight_observer=weight_observer,
            ),
            requires_grad=linear_module.weight.requires_grad,
        )
        return linear_module

    return convert_to_linear_observer
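
For context, here is a minimal sketch of the three-step static calibration flow the class docstring describes. It is an illustration under stated assumptions, not a tested recipe from this PR: `make_observer()` is a hypothetical stand-in for constructing a concrete `AffineQuantizedObserverBase` subclass, and the final conversion to a static quantization scheme is handled by later PRs in this stack.

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))

# Step 1: wrap every nn.Linear weight in a LinearObserverTensor.
# make_observer() is a placeholder for building a concrete observer
# (any AffineQuantizedObserverBase subclass); it is not defined in this PR.
quantize_(model, insert_observers(make_observer(), make_observer()))

# Step 2: run calibration data; each forward pass records min/max statistics
# for both the activation and the weight of every observed linear layer.
with torch.no_grad():
    for _ in range(16):
        model(torch.randn(4, 32))

# Step 3: quantize_ the model to static quantization using the recorded
# statistics (the conversion entry point lands later in this stack).
```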