Skip to content

Commit

Permalink
Add WeightQuantizer and DynamicActQuantizer
Browse files Browse the repository at this point in the history
Summary:
This exposes the AffineQuantizedTensor and LinearActQuantizedTensor
subclass as a model level API that will replace the weights of linear layers
This is in preparation to replace existing tensor subclass APIs such as `change_linear_weights_to_int4_woqtensors`
but currently we can't combine the two quantizers due to some problem with parametrization/nn.Parameter
the error is:

raise KeyError(f"attribute '{name}' already exists")
KeyError: "attribute 'weight' already exists"

happens in
```
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
```

Test Plan:
regression tests:
```
python test/quantization/test_quant_api.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 18, 2024
1 parent 5741aa2 commit ec109ef
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 16 deletions.
54 changes: 39 additions & 15 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
apply_weight_only_int8_quant,
Quantizer,
TwoStepQuantizer,
WeightQuantizer,
DynamicActQuantizer,
)
from torchao.quantization.utils import (
TORCH_VERSION_AFTER_2_3,
Expand Down Expand Up @@ -475,16 +477,18 @@ def test_quantized_tensor_subclass_int4(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

def to_quantized(weight):
return AffineQuantizedTensor.from_float(
weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
weight_quantizer = WeightQuantizer(
mapping_type=mapping_type,
block_size=block_size,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=ZeroPointDomain.FLOAT,
)
m = weight_quantizer.quantize(m)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

Expand Down Expand Up @@ -515,12 +519,19 @@ def test_quantized_tensor_subclass_int8(self):
m_copy = copy.deepcopy(m)
example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

def to_quantized(weight):
block_size = (1, weight.shape[1])
return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
def get_block_size(x):
return (1, x.shape[1])

weight_quantizer = WeightQuantizer(
mapping_type=mapping_type,
get_block_size=get_block_size,
target_dtype=target_dtype,
eps=eps,
zero_point_dtype=zero_point_dtype
)

m = weight_quantizer.quantize(m)

m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

Expand Down Expand Up @@ -577,6 +588,19 @@ def dynamic_quant(linear):

dynamic_quant(m.linear1)
dynamic_quant(m.linear2)

# TODO: this doesn't work with parametrization because we can't register
# parameter twice: `lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)`
# weight_quantizer = WeightQuantizer(
# mapping_type=mapping_type,
# get_block_size=get_weight_block_size,
# target_dtype=target_dtype,
# eps=eps,
# zero_point_dtype=zero_point_dtype
# )
# dynamic_act_quantizer = DynamicActQuantizer(input_quant_func=input_quant_func)
# m = weight_quantizer.quantize(m)
# m = dynamic_act_quantizer.quantize(m)
assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
Expand Down
50 changes: 50 additions & 0 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
Expand All @@ -27,6 +28,10 @@
Int8DynamicallyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
QuantizedLinearWeightBase,
to_aqt,
LinearActQuantizedTensor,
ConstructTensorSubclassAQT,
ConstructTensorSubclassLAQT,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
Expand All @@ -48,6 +53,8 @@
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
"WeightQuantizer",
"DynamicActQuantizer",
"autoquant"
]

Expand Down Expand Up @@ -214,3 +221,46 @@ def replace_conv2d_1x1(conv):
_replace_with_custom_fn_if_matches_filter(
model, replace_conv2d_1x1, filter_fn=filter_fn
)


def _get_linear_subclass_inserter(constructor, subclass_constructor, **kwargs):
def insert_subclass(lin):
# so that we don't modify the original kwargs
copied_kwargs = dict(kwargs)
get_block_size = copied_kwargs.pop("get_block_size", None)
if get_block_size:
block_size = get_block_size(lin.weight)
copied_kwargs["block_size"] = block_size
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
_, args = lin.weight.__tensor_flatten__()
parametrize.register_parametrization(lin, "weight", subclass_constructor(*args))
return lin

return insert_subclass

class WeightQuantizer(Quantizer):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs

def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(to_aqt, ConstructTensorSubclassAQT, **self.kwargs),
_is_linear if filter_fn is None else filter_fn,
)
return model


class DynamicActQuantizer(Quantizer):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs

def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
_replace_with_custom_fn_if_matches_filter(
model,
_get_linear_subclass_inserter(LinearActQuantizedTensor.from_float, ConstructTensorSubclassLAQT, **self.kwargs),
_is_linear if filter_fn is None else filter_fn,
)
return model
23 changes: 22 additions & 1 deletion torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"AffineQuantizedTensor",
"LinearActQuantizedTensor",
]


Expand Down Expand Up @@ -777,7 +778,7 @@ def dequantize(self, output_dtype=None):
return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

def __tensor_flatten__(self):
return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

@classmethod
def __tensor_unflatten__(
Expand Down Expand Up @@ -1176,3 +1177,23 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
raise NotImplementedError(
f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)


# this is a workaround for tensor subclass https://github.com/pytorch/pytorch/issues/124735
@torch._dynamo.allow_in_graph
def aqt_from_qtensor_components(*args, **kwargs):
return AffineQuantizedTensor(*args, **kwargs)


class ConstructTensorSubclassAQT(ConstructTensorSubclass):
def forward(self, int_data, scale, zero_point):
return aqt_from_qtensor_components(int_data, scale, zero_point, *self.args, **self.kwargs)

@torch._dynamo.allow_in_graph
def laqt_from_qtensor_components(*args, **kwargs):
return LinearActQuantizedTensor(*args, **kwargs)


class ConstructTensorSubclassLAQT(ConstructTensorSubclass):
def forward(self, original_weight_tensor):
return laqt_from_qtensor_components(original_weight_tensor, *self.args, **self.kwargs)

0 comments on commit ec109ef

Please sign in to comment.