diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index fcab07c913..17ea7bde78 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -18,12 +18,18 @@
     get_symmetric_quantization_config,
 )
 
+from torchao.quantization.subclass import (
+    to_aqt,
+    AffineQuantizedTensor,
+    LinearActQuantizedTensor,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
+    TensorSubclassQuantizer,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -92,8 +98,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1):
+        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -425,18 +431,29 @@ def get_per_token_block_size(x):
         input_target_dtype = torch.int8
         input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
-
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps
+        )
+        dynamic_act_quantizer = TensorSubclassQuantizer(LinearActQuantizedTensor.from_float, input_quant_func=input_quant_func)
+
+        # note: order is important
+        m = weight_quantizer.quantize(m)
+        m = dynamic_act_quantizer.quantize(m)
+
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
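Reviewer note (not part of the patch): the hunk above keeps the "order is important" comment because the weight quantizer must run before the dynamic-activation quantizer, so that `LinearActQuantizedTensor` ends up wrapping the `AffineQuantizedTensor` — exactly what the two new `original_weight_tensor` asserts check. A minimal sketch of that nesting on a single weight, mirroring the manual code this patch removes (it assumes `to_aqt` is a convenience alias for `AffineQuantizedTensor.from_float`, which this excerpt does not show):

```python
import torch
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
    AffineQuantizedTensor,
    LinearActQuantizedTensor,
)

linear = torch.nn.Linear(64, 32, bias=False)

# step 1: the weight becomes an AffineQuantizedTensor
# (symmetric int8, one quantization block per output channel)
weight = AffineQuantizedTensor.from_float(
    linear.weight,
    MappingType.SYMMETRIC,
    (1, linear.weight.shape[1]),
    torch.int8,
    eps=torch.finfo(torch.float32).eps,
)

# step 2: a LinearActQuantizedTensor wraps it and records how to quantize
# activations on the fly; per-token blocks quantize each token separately
input_quant_func = lambda x: AffineQuantizedTensor.from_float(
    x, MappingType.SYMMETRIC, (1,) * (x.dim() - 1) + (x.shape[-1],), torch.int8
)
weight = LinearActQuantizedTensor.from_float(weight, input_quant_func)

# the nesting the test asserts: LAQT on the outside, AQT inside
assert isinstance(weight, LinearActQuantizedTensor)
assert isinstance(weight.original_weight_tensor, AffineQuantizedTensor)
```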
@@ -475,16 +492,19 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
-        def to_quantized(weight):
-            return AffineQuantizedTensor.from_float(
-                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=preserve_zero,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-            )
-
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
+        m = weight_quantizer.quantize(m)
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -515,12 +535,20 @@ def test_quantized_tensor_subclass_int8(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        def to_quantized(weight):
-            block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        def get_block_size(x):
+            return (1, x.shape[1])
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            get_block_size=get_block_size,
+            target_dtype=target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype
+        )
+
+        m = weight_quantizer.quantize(m)
 
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -537,7 +565,7 @@ def to_quantized(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import to_aqt
         from torchao.quantization.subclass import LinearActQuantizedTensor
         from torchao.quantization.quant_primitives import MappingType
         from torchao.quantization.quant_primitives import ZeroPointDomain
@@ -568,15 +596,21 @@ def get_per_token_block_size(x):
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        # setting batch_size to 20 to be compatible with the kernel
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            get_block_size=get_weight_block_size,
+            target_dtype=target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype
+        )
+        dynamic_act_quantizer = TensorSubclassQuantizer(LinearActQuantizedTensor.from_float, input_quant_func=input_quant_func)
+        m = weight_quantizer.quantize(m)
+        m = dynamic_act_quantizer.quantize(m)
 
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +625,14 @@ def dynamic_quant(linear):
 
         self.assertTrue(torch.equal(res, ref))
 
+        # workaround for export path
+        from torchao.quantization.quant_api import _unwrap_tensor_subclass
+        m = _unwrap_tensor_subclass(m)
+        m = torch.export.export(m, example_inputs).module()
+        exported_model_res = m(*example_inputs)
+
+        self.assertTrue(torch.equal(exported_model_res, ref))
+
 
 if __name__ == "__main__":
     unittest.main()
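Reviewer note (not part of the patch): the test changes above all funnel through one new surface, `TensorSubclassQuantizer.quantize`, which replaces the per-module weight swapping. Because a block size such as `(1, in_features)` differs per layer, the quantizer also accepts a `get_block_size` callable that is evaluated against each linear's weight. A hedged usage sketch on a toy model — the argument set mirrors the int8 test above, and any defaults for omitted keywords are an assumption:

```python
import torch
from torchao.quantization.quant_api import TensorSubclassQuantizer
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import AffineQuantizedTensor, to_aqt

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16, bias=False),
    torch.nn.Linear(16, 8, bias=False),
)

weight_quantizer = TensorSubclassQuantizer(
    to_aqt,                                    # factory: float weight -> AffineQuantizedTensor
    mapping_type=MappingType.SYMMETRIC,
    get_block_size=lambda w: (1, w.shape[1]),  # per-output-channel, evaluated per layer
    target_dtype=torch.int8,
    eps=torch.finfo(torch.float32).eps,
)
model = weight_quantizer.quantize(model)       # swaps every nn.Linear weight in place

assert isinstance(model[0].weight, AffineQuantizedTensor)
assert isinstance(model[1].weight, AffineQuantizedTensor)
```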
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a5a3a2b3db..0039368c66 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,6 +18,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
+from typing import Any
 
 from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
 from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -27,6 +29,10 @@
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
+    to_aqt,
+    LinearActQuantizedTensor,
+    ConstructTensorSubclassAQT,
+    ConstructTensorSubclassLAQT,
 )
 from .weight_only import WeightOnlyInt8QuantLinear
 from .unified import Quantizer, TwoStepQuantizer
@@ -48,6 +54,7 @@
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
+    "TensorSubclassQuantizer",
     "autoquant"
 ]
 
@@ -214,3 +221,83 @@ def replace_conv2d_1x1(conv):
     _replace_with_custom_fn_if_matches_filter(
         model, replace_conv2d_1x1, filter_fn=filter_fn
     )
+
+class UnwrapTensorSubclass(nn.Module):
+    def forward(self, *tensors):
+        todo = list(tensors)
+        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
+            nb_tensor = len(inner_tensors)
+            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
+            todo = todo[nb_tensor:]
+
+            @torch._dynamo.allow_in_graph
+            def tmp():
+                return tp.__tensor_unflatten__(inner_tensors, meta, None, None)
+
+            rebuilt = tmp()
+            todo.append(rebuilt)
+
+        assert len(todo) == 1
+        return todo[0]
+
+    def right_inverse(self, tensor):
+        assert type(tensor) is not torch.Tensor
+        rebuild_stack = []
+        plain_tensors = []
+        todo = [tensor]
+        while todo:
+            obj = todo.pop()
+            inner_tensors, metadata = obj.__tensor_flatten__()
+            rebuild_stack.append((type(obj), metadata, inner_tensors))
+            for attr_name in inner_tensors:
+                val = getattr(obj, attr_name)
+                if type(val) is torch.Tensor:
+                    plain_tensors.append(val)
+                else:
+                    assert isinstance(val, torch.Tensor)
+                    todo.append(val)
+
+        self.rebuild_stack = rebuild_stack
+
+        return plain_tensors
+
+def _unwrap_tensor_subclass(model, filter_fn=None):
+    def insert_parametrization(lin):
+        parametrize.register_parametrization(lin, "weight", UnwrapTensorSubclass())
+        return lin
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        insert_parametrization,
+        _is_linear if filter_fn is None else filter_fn,
+    )
+
+    return model
+
+
+def _get_linear_subclass_inserter(constructor, **kwargs):
+    def insert_subclass(lin):
+        # so that we don't modify the original kwargs
+        copied_kwargs = dict(kwargs)
+        get_block_size = copied_kwargs.pop("get_block_size", None)
+        if get_block_size is not None:
+            block_size = get_block_size(lin.weight)
+            copied_kwargs["block_size"] = block_size
+        lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
+        return lin
+
+    return insert_subclass
+
+class TensorSubclassQuantizer(Quantizer):
+    def __init__(self, factory_fn, **kwargs):
+        super().__init__()
+        self.factory_fn = factory_fn
+        self.kwargs = kwargs
+
+    def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            _get_linear_subclass_inserter(self.factory_fn, **self.kwargs),
+            _is_linear if filter_fn is None else filter_fn,
+        )
+        return model
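Reviewer note (not part of the patch): `UnwrapTensorSubclass` is built on the `torch.nn.utils.parametrize` contract — `right_inverse` runs once at registration and its returned plain tensors become the stored values (with the rebuild recipe stashed on the module), while `forward` runs on every weight access and reassembles the subclass via `__tensor_unflatten__`. A self-contained sketch of that contract with an ordinary tensor, unrelated to quantization:

```python
import torch
import torch.nn.utils.parametrize as parametrize

class SplitInHalf(torch.nn.Module):
    def forward(self, a, b):        # rebuild: called on every weight access
        return torch.cat([a, b])

    def right_inverse(self, w):     # flatten: called once at registration
        half = w.shape[0] // 2
        return w[:half], w[half:]   # multiple plain tensors are stored

lin = torch.nn.Linear(4, 4, bias=False)
w = lin.weight.detach().clone()
parametrize.register_parametrization(lin, "weight", SplitInHalf())

# the round trip is lossless: forward(right_inverse(w)) == w
assert torch.equal(lin.weight, w)
```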
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 8d0af8b369..5144228f4b 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -35,6 +35,7 @@
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
     "AffineQuantizedTensor",
+    "LinearActQuantizedTensor",
 ]
 
 
@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
         return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
-        self.q_scales = q_scales
         super().__init__(int_data, transposed)
 
 
@@ -777,7 +777,7 @@ def dequantize(self, output_dtype=None):
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
@@ -1091,7 +1091,7 @@
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         original_weight_tensor = tensor_data_dict["original_weight_tensor"]
-        input_quant_func = tensor_attributes
+        input_quant_func, = tensor_attributes
         return cls(
             original_weight_tensor,
             input_quant_func,
@@ -1176,3 +1176,23 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         raise NotImplementedError(
             f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
+
+
+# this is a workaround for tensor subclass https://github.com/pytorch/pytorch/issues/124735
+@torch._dynamo.allow_in_graph
+def aqt_from_qtensor_components(*args, **kwargs):
+    return AffineQuantizedTensor(*args, **kwargs)
+
+
+class ConstructTensorSubclassAQT(ConstructTensorSubclass):
+    def forward(self, int_data, scale, zero_point):
+        return aqt_from_qtensor_components(int_data, scale, zero_point, *self.args, **self.kwargs)
+
+@torch._dynamo.allow_in_graph
+def laqt_from_qtensor_components(*args, **kwargs):
+    return LinearActQuantizedTensor(*args, **kwargs)
+
+
+class ConstructTensorSubclassLAQT(ConstructTensorSubclass):
+    def forward(self, original_weight_tensor):
+        return laqt_from_qtensor_components(original_weight_tensor, *self.args, **self.kwargs)
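Reviewer note (not part of the patch): the last hunk routes subclass construction through plain functions marked `@torch._dynamo.allow_in_graph`, so dynamo/export treats the call as opaque instead of tracing into `__new__`/`__init__` — the failure mode tracked in pytorch/pytorch#124735. The `ConstructTensorSubclass` base class these modules extend is referenced but not shown in this excerpt; judging from the `self.args`/`self.kwargs` usage in the `forward` methods, its assumed shape is roughly the sketch below (an assumption, not the library's definition):

```python
import torch
import torch.nn as nn

# Assumed shape of the ConstructTensorSubclass base: it captures the
# non-tensor constructor arguments up front, so that forward() only needs
# the flattened plain tensors when rebuilding the subclass.
class ConstructTensorSubclass(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.args = args
        self.kwargs = kwargs

# The patch's subclasses then pair the captured arguments with the tensors,
# e.g. ConstructTensorSubclassAQT above:
#
#     def forward(self, int_data, scale, zero_point):
#         return aqt_from_qtensor_components(
#             int_data, scale, zero_point, *self.args, **self.kwargs)
```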