Add WeightQuantizer and DynamicActQuantizer
Summary:
This exposes the AffineQuantizedTensor and LinearActQuantizedTensor
subclasses as a model-level API that replaces the weights of linear layers.
This is in preparation for replacing existing tensor subclass APIs such as
`change_linear_weights_to_int4_woqtensors`; a usage sketch follows.
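
For context, here is a rough sketch of the intended usage, modeled on the
tests added in this commit; the quantization settings below are illustrative
placeholders, not recommended values:

```
import torch
from torchao.quantization.quant_api import TensorSubclassQuantizer
from torchao.quantization.subclass import (
    AffineQuantizedTensor,
    LinearActQuantizedTensor,
    to_aqt,
)
from torchao.quantization.quant_primitives import MappingType

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).eval()

# weight quantization: symmetric int8, one block per row (placeholder choice)
weight_quantizer = TensorSubclassQuantizer(
    to_aqt,
    mapping_type=MappingType.SYMMETRIC,
    get_block_size=lambda w: (1, w.shape[1]),
    target_dtype=torch.int8,
)
# dynamic activation quantization: per-token symmetric int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(
    x, MappingType.SYMMETRIC, (1, x.shape[-1]), torch.int8
)
dynamic_act_quantizer = TensorSubclassQuantizer(
    LinearActQuantizedTensor.from_float, input_quant_func=input_quant_func
)

# note: order is important
m = weight_quantizer.quantize(m)
m = dynamic_act_quantizer.quantize(m)
```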
Currently, however, we can't combine the two quantizers, due to a problem
with parametrization/nn.Parameter. The error is:

raise KeyError(f"attribute '{name}' already exists")
KeyError: "attribute 'weight' already exists"

and it happens in
```
lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
```
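
For reference, a minimal standalone sketch that reproduces this class of
error; the no-op parametrization here is illustrative only and is not part
of this commit:

```
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class Noop(nn.Module):
    def forward(self, x):
        return x

lin = nn.Linear(4, 4)
parametrize.register_parametrization(lin, "weight", Noop())

# "weight" is now a property on the parametrized module, so assigning a
# fresh nn.Parameter hits nn.Module's duplicate-attribute check:
lin.weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)
# KeyError: "attribute 'weight' already exists"
```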

Test Plan:
regression tests:
```
python test/quantization/test_quant_api.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 22, 2024
1 parent 5741aa2 commit dff6a45
Showing 3 changed files with 184 additions and 36 deletions.
108 changes: 75 additions & 33 deletions test/quantization/test_quant_api.py
@@ -18,12 +18,18 @@
    get_symmetric_quantization_config,
)

from torchao.quantization.subclass import (
    to_aqt,
    AffineQuantizedTensor,
    LinearActQuantizedTensor,
)
from torchao.quantization.quant_api import (
    _replace_with_custom_fn_if_matches_filter,
    apply_dynamic_quant,
    apply_weight_only_int8_quant,
    Quantizer,
    TwoStepQuantizer,
    TensorSubclassQuantizer,
)
from torchao.quantization.utils import (
    TORCH_VERSION_AFTER_2_3,
@@ -92,8 +98,8 @@ def __init__(self, m=64, n=32, k=64):
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)

    def example_inputs(self):
        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
    def example_inputs(self, batch_size=1):
        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)

    def forward(self, x):
        x = self.linear1(x)
@@ -425,18 +431,29 @@ def get_per_token_block_size(x):
        input_target_dtype = torch.int8
        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

        def dynamic_quant(linear):
            # note: order is important
            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)

        m = ToyLinearModel().eval()
        m_copy = copy.deepcopy(m)
        example_inputs = m.example_inputs()
        dynamic_quant(m.linear1)
        dynamic_quant(m.linear2)

        weight_quantizer = TensorSubclassQuantizer(
            to_aqt,
            mapping_type=mapping_type,
            block_size=block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=eps
        )
        dynamic_act_quantizer = TensorSubclassQuantizer(LinearActQuantizedTensor.from_float, input_quant_func=input_quant_func)

        # note: order is important
        m = weight_quantizer.quantize(m)
        m = dynamic_act_quantizer.quantize(m)

        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

        # reference
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -475,16 +492,19 @@ def test_quantized_tensor_subclass_int4(self):
        m_copy = copy.deepcopy(m)
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

        def to_quantized(weight):
            return AffineQuantizedTensor.from_float(
                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
                zero_point_dtype=zero_point_dtype,
                preserve_zero=preserve_zero,
                zero_point_domain=ZeroPointDomain.FLOAT,
            )

        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
        weight_quantizer = TensorSubclassQuantizer(
            to_aqt,
            mapping_type=mapping_type,
            block_size=block_size,
            target_dtype=target_dtype,
            quant_min=quant_min,
            quant_max=quant_max,
            eps=eps,
            zero_point_dtype=zero_point_dtype,
            preserve_zero=preserve_zero,
            zero_point_domain=ZeroPointDomain.FLOAT,
        )
        m = weight_quantizer.quantize(m)
        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -515,12 +535,20 @@ def test_quantized_tensor_subclass_int8(self):
        m_copy = copy.deepcopy(m)
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))

        def to_quantized(weight):
            block_size = (1, weight.shape[1])
            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
        def get_block_size(x):
            return (1, x.shape[1])

        weight_quantizer = TensorSubclassQuantizer(
            to_aqt,
            mapping_type=mapping_type,
            get_block_size=get_block_size,
            target_dtype=target_dtype,
            eps=eps,
            zero_point_dtype=zero_point_dtype
        )

        m = weight_quantizer.quantize(m)

        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
        assert isinstance(m.linear2.weight, AffineQuantizedTensor)

@@ -537,7 +565,7 @@ def to_quantized(weight):
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_quantized_tensor_subclass_int8_dyn_quant(self):
        from torchao.quantization.subclass import AffineQuantizedTensor
        from torchao.quantization.subclass import to_aqt
        from torchao.quantization.subclass import LinearActQuantizedTensor
        from torchao.quantization.quant_primitives import MappingType
        from torchao.quantization.quant_primitives import ZeroPointDomain
@@ -568,15 +596,21 @@ def get_per_token_block_size(x):
        # use 1024 so that we don't need padding
        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
        m_copy = copy.deepcopy(m)
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))

        def dynamic_quant(linear):
            # note: order is important
            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
        # setting batch_size to 20 to be compatible with the kernel
        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))

        weight_quantizer = TensorSubclassQuantizer(
            to_aqt,
            mapping_type=mapping_type,
            get_block_size=get_weight_block_size,
            target_dtype=target_dtype,
            eps=eps,
            zero_point_dtype=zero_point_dtype
        )
        dynamic_act_quantizer = TensorSubclassQuantizer(LinearActQuantizedTensor.from_float, input_quant_func=input_quant_func)
        m = weight_quantizer.quantize(m)
        m = dynamic_act_quantizer.quantize(m)

        dynamic_quant(m.linear1)
        dynamic_quant(m.linear2)
        assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
        assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +625,14 @@ def dynamic_quant(linear):

        self.assertTrue(torch.equal(res, ref))

        # workaround for export path
        from torchao.quantization.quant_api import _unwrap_tensor_subclass
        m = _unwrap_tensor_subclass(m)
        m = torch.export.export(m, example_inputs).module()
        exported_model_res = m(*example_inputs)

        self.assertTrue(torch.equal(exported_model_res, ref))


if __name__ == "__main__":
    unittest.main()
86 changes: 86 additions & 0 deletions torchao/quantization/quant_api.py
@@ -18,6 +18,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any
from torch.nn.utils import parametrize  # needed by _unwrap_tensor_subclass below

from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -27,6 +28,10 @@
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
    to_aqt,
    LinearActQuantizedTensor,
    ConstructTensorSubclassAQT,
    ConstructTensorSubclassLAQT,
)
from .weight_only import WeightOnlyInt8QuantLinear
from .unified import Quantizer, TwoStepQuantizer
@@ -48,6 +53,7 @@
    "TwoStepQuantizer",
    "Int4WeightOnlyGPTQQuantizer",
    "Int4WeightOnlyQuantizer",
    "TensorSubclassQuantizer",
    "autoquant"
]

@@ -214,3 +220,83 @@ def replace_conv2d_1x1(conv):
    _replace_with_custom_fn_if_matches_filter(
        model, replace_conv2d_1x1, filter_fn=filter_fn
    )


class UnwrapTensorSubclass(nn.Module):
    def forward(self, *tensors):
        # rebuild the (possibly nested) tensor subclass from its flattened
        # plain tensors, innermost subclass first
        todo = list(tensors)
        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
            nb_tensor = len(inner_tensors)
            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
            todo = todo[nb_tensor:]

            @torch._dynamo.allow_in_graph
            def tmp():
                return tp.__tensor_unflatten__(inner_tensors, meta, None, None)

            rebuilt = tmp()
            todo.append(rebuilt)

        assert len(todo) == 1
        return todo[0]

    def right_inverse(self, tensor):
        # flatten a tensor subclass into its plain inner tensors and record
        # the information needed to rebuild it in `forward`
        assert type(tensor) is not torch.Tensor
        rebuild_stack = []
        plain_tensors = []
        todo = [tensor]
        while todo:
            obj = todo.pop()
            inner_tensors, metadata = obj.__tensor_flatten__()
            rebuild_stack.append((type(obj), metadata, inner_tensors))
            for attr_name in inner_tensors:
                val = getattr(obj, attr_name)
                if type(val) is torch.Tensor:
                    plain_tensors.append(val)
                else:
                    # a nested tensor subclass: flatten it as well
                    assert isinstance(val, torch.Tensor)
                    todo.append(val)

        self.rebuild_stack = rebuild_stack

        return plain_tensors


def _unwrap_tensor_subclass(model, filter_fn=None):
    def insert_parametrization(lin):
        parametrize.register_parametrization(lin, "weight", UnwrapTensorSubclass())
        return lin

    _replace_with_custom_fn_if_matches_filter(
        model,
        insert_parametrization,
        _is_linear if filter_fn is None else filter_fn,
    )

    return model


def _get_linear_subclass_inserter(constructor, **kwargs):
    def insert_subclass(lin):
        # so that we don't modify the original kwargs
        copied_kwargs = dict(kwargs)
        # block_size can depend on the weight shape, so it is computed per layer
        get_block_size = copied_kwargs.pop("get_block_size", None)
        if get_block_size:
            block_size = get_block_size(lin.weight)
            copied_kwargs["block_size"] = block_size
        lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
        return lin

    return insert_subclass


class TensorSubclassQuantizer(Quantizer):
    def __init__(self, factory_fn, **kwargs):
        super().__init__()
        self.factory_fn = factory_fn
        self.kwargs = kwargs

    def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
        _replace_with_custom_fn_if_matches_filter(
            model,
            _get_linear_subclass_inserter(self.factory_fn, **self.kwargs),
            _is_linear if filter_fn is None else filter_fn,
        )
        return model
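
A usage note on `quantize`: the `filter_fn` argument makes it possible to
quantize only selected layers. A small sketch, under the assumption that
filter functions take `(module, fqn)` as in
`_replace_with_custom_fn_if_matches_filter`; the predicate itself is
hypothetical:

```
def only_linear1(module, fqn):
    # hypothetical predicate: quantize only the module named "linear1"
    return isinstance(module, torch.nn.Linear) and fqn == "linear1"

m = weight_quantizer.quantize(m, filter_fn=only_linear1)
```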
26 changes: 23 additions & 3 deletions torchao/quantization/subclass.py
@@ -35,6 +35,7 @@
    "Int8WeightOnlyQuantizedLinearWeight",
    "Int4WeightOnlyQuantizedLinearWeight",
    "AffineQuantizedTensor",
    "LinearActQuantizedTensor",
]


@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
        return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):

        self.q_scales = q_scales
        super().__init__(int_data, transposed)

@@ -777,7 +777,7 @@ def dequantize(self, output_dtype=None):
        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)

    def __tensor_flatten__(self):
        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]

    @classmethod
    def __tensor_unflatten__(
@@ -1091,7 +1091,7 @@ def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        original_weight_tensor = tensor_data_dict["original_weight_tensor"]
        input_quant_func = tensor_attributes
        input_quant_func, = tensor_attributes
        return cls(
            original_weight_tensor,
            input_quant_func,
@@ -1176,3 +1176,23 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        raise NotImplementedError(
            f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
        )


# this is a workaround for tensor subclass https://github.com/pytorch/pytorch/issues/124735
@torch._dynamo.allow_in_graph
def aqt_from_qtensor_components(*args, **kwargs):
    return AffineQuantizedTensor(*args, **kwargs)


class ConstructTensorSubclassAQT(ConstructTensorSubclass):
    def forward(self, int_data, scale, zero_point):
        return aqt_from_qtensor_components(int_data, scale, zero_point, *self.args, **self.kwargs)


@torch._dynamo.allow_in_graph
def laqt_from_qtensor_components(*args, **kwargs):
    return LinearActQuantizedTensor(*args, **kwargs)


class ConstructTensorSubclassLAQT(ConstructTensorSubclass):
    def forward(self, original_weight_tensor):
        return laqt_from_qtensor_components(original_weight_tensor, *self.args, **self.kwargs)
