diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index fcab07c913..d7a0f2dd0f 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -18,12 +18,19 @@
     get_symmetric_quantization_config,
 )
 
+from torchao.quantization.subclass import (
+    to_aqt,
+    to_laqt,
+    AffineQuantizedTensor,
+    LinearActQuantizedTensor,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
     apply_weight_only_int8_quant,
     Quantizer,
     TwoStepQuantizer,
+    TensorSubclassQuantizer,
 )
 from torchao.quantization.utils import (
     TORCH_VERSION_AFTER_2_3,
@@ -92,8 +99,8 @@ def __init__(self, m=64, n=32, k=64):
         self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
         self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
 
-    def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+    def example_inputs(self, batch_size=1):
+        return (torch.randn(batch_size, self.linear1.in_features).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -423,20 +430,31 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps
+        )
+        dynamic_act_quantizer = TensorSubclassQuantizer(to_laqt, input_quant_func=input_quant_func)
+
+        # note: order is important
+        m = weight_quantizer.quantize(m)
+        m = dynamic_act_quantizer.quantize(m)
+
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
 
         # reference
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
@@ -475,16 +493,19 @@ def test_quantized_tensor_subclass_int4(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
-        def to_quantized(weight):
-            return AffineQuantizedTensor.from_float(
-                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=preserve_zero,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-            )
-
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            block_size=block_size,
+            target_dtype=target_dtype,
+            quant_min=quant_min,
+            quant_max=quant_max,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=ZeroPointDomain.FLOAT,
+        )
+        m = weight_quantizer.quantize(m)
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -515,12 +536,20 @@ def test_quantized_tensor_subclass_int8(self):
         m_copy = copy.deepcopy(m)
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
 
-        def to_quantized(weight):
-            block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        def get_block_size(x):
+            return (1, x.shape[1])
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            get_block_size=get_block_size,
+            target_dtype=target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype
+        )
+
+        m = weight_quantizer.quantize(m)
 
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
 
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight, AffineQuantizedTensor)
@@ -537,7 +566,7 @@ def to_quantized(weight):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.subclass import to_aqt
         from torchao.quantization.subclass import LinearActQuantizedTensor
         from torchao.quantization.quant_primitives import MappingType
         from torchao.quantization.quant_primitives import ZeroPointDomain
@@ -563,20 +592,26 @@ def get_per_token_block_size(x):
         input_eps = 1e-5
         input_quant_min = -127
         input_quant_max = 127
-        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
+        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float)
 
         # use 1024 so that we don't need padding
         m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
-
-        def dynamic_quant(linear):
-            # note: order is important
-            linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False)
-            linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False)
+        # setting batch_size to 20 to be compatible with the kernel
+        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs(batch_size=20)))
+
+        weight_quantizer = TensorSubclassQuantizer(
+            to_aqt,
+            mapping_type=mapping_type,
+            get_block_size=get_weight_block_size,
+            target_dtype=target_dtype,
+            eps=eps,
+            zero_point_dtype=zero_point_dtype
+        )
+        dynamic_act_quantizer = TensorSubclassQuantizer(to_laqt, input_quant_func=input_quant_func)
+        m = weight_quantizer.quantize(m)
+        m = dynamic_act_quantizer.quantize(m)
 
-        dynamic_quant(m.linear1)
-        dynamic_quant(m.linear2)
         assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
         assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
         assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
@@ -591,6 +626,18 @@ def dynamic_quant(linear):
 
         self.assertTrue(torch.equal(res, ref))
 
+        # workaround for export path
+        from torchao.quantization.quant_api import _unwrap_tensor_subclass
+        m = _unwrap_tensor_subclass(m)
+        m = torch.export.export(m, example_inputs).module()
+        exported_model_res = m(*example_inputs)
+
+        self.assertTrue(torch.equal(exported_model_res, ref))
+
+        # make sure it compiles
+        torch._export.aot_compile(m, example_inputs)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index a5a3a2b3db..c0b908356d 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -18,6 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import Any
 
 from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
 from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
@@ -48,6 +49,7 @@
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
+    "TensorSubclassQuantizer",
     "autoquant"
 ]
@@ -214,3 +216,78 @@ def replace_conv2d_1x1(conv):
     _replace_with_custom_fn_if_matches_filter(
         model, replace_conv2d_1x1, filter_fn=filter_fn
     )
+
+class UnwrapTensorSubclass(nn.Module):
+    def forward(self, *tensors):
+        todo = list(tensors)
+        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
+            nb_tensor = len(inner_tensors)
+            inner_tensors = {a: b for a, b in zip(inner_tensors, todo[-nb_tensor:])}
+            todo = todo[nb_tensor:]
+            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
+            todo.append(rebuilt)
+
+        assert len(todo) == 1
+        return todo[0]
+
+    def right_inverse(self, tensor):
+        assert type(tensor) is not torch.Tensor
+        rebuild_stack = []
+        plain_tensors = []
+        todo = [tensor]
+        while todo:
+            obj = todo.pop()
+            inner_tensors, metadata = obj.__tensor_flatten__()
+            rebuild_stack.append((type(obj), metadata, inner_tensors))
+            for attr_name in inner_tensors:
+                val = getattr(obj, attr_name)
+                if type(val) is torch.Tensor:
+                    plain_tensors.append(val)
+                else:
+                    assert isinstance(val, torch.Tensor)
+                    todo.append(val)
+
+        self.rebuild_stack = rebuild_stack
+
+        return plain_tensors
+
+def _unwrap_tensor_subclass(model, filter_fn=None):
+    def insert_parametrization(lin):
+        parametrize.register_parametrization(lin, "weight", UnwrapTensorSubclass())
+        return lin
+
+    _replace_with_custom_fn_if_matches_filter(
+        model,
+        insert_parametrization,
+        _is_linear if filter_fn is None else filter_fn,
+    )
+
+    return model
+
+
+def _get_linear_subclass_inserter(constructor, **kwargs):
+    def insert_subclass(lin):
+        # so that we don't modify the original kwargs
+        copied_kwargs = dict(kwargs)
+        get_block_size = copied_kwargs.pop("get_block_size", None)
+        if get_block_size:
+            block_size = get_block_size(lin.weight)
+            copied_kwargs["block_size"] = block_size
+        lin.weight = torch.nn.Parameter(constructor(lin.weight, **copied_kwargs), requires_grad=False)
+        return lin
+
+    return insert_subclass
+
+class TensorSubclassQuantizer(Quantizer):
+    def __init__(self, factory_fn, **kwargs):
+        super().__init__()
+        self.factory_fn = factory_fn
+        self.kwargs = kwargs
+
+    def quantize(self, model: torch.nn.Module, filter_fn=None) -> torch.nn.Module:
+        _replace_with_custom_fn_if_matches_filter(
+            model,
+            _get_linear_subclass_inserter(self.factory_fn, **self.kwargs),
+            _is_linear if filter_fn is None else filter_fn,
+        )
+        return model
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 8d0af8b369..6e844530d4 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -35,6 +35,7 @@
     "Int8WeightOnlyQuantizedLinearWeight",
     "Int4WeightOnlyQuantizedLinearWeight",
     "AffineQuantizedTensor",
+    "LinearActQuantizedTensor",
 ]
@@ -266,7 +267,6 @@ def __new__(cls, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
         return super().__new__(cls, int_data, transposed, shape, **kwargs)  # type: ignore[attr-defined]
 
     def __init__(self, int_data, q_scales, transposed, shape, dtype=None, **kwargs):
-        self.q_scales = q_scales
         super().__init__(int_data, transposed)
@@ -629,32 +629,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
 
-def to_aqt(
-    input_float,
-    mapping_type,
-    block_size,
-    target_dtype,
-    quant_min = None,
-    quant_max = None,
-    eps = None,
-    scale_dtype = None,
-    zero_point_dtype = None,
-    preserve_zero = True,
-    zero_point_domain = ZeroPointDomain.INT,
-):
-    return AffineQuantizedTensor.from_float(
-        input_float,
-        mapping_type,
-        block_size,
-        target_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        eps=eps,
-        scale_dtype=scale_dtype,
-        zero_point_dtype=zero_point_dtype,
-        preserve_zero=preserve_zero,
-        zero_point_domain=zero_point_domain
-    )
 
 # TODO: merge with nf4 implements decorator
 # aten op to their __torch_dispatch__ implemnetations for the tensor subclass
@@ -777,7 +751,7 @@ def dequantize(self, output_dtype=None):
         return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
+        return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
@@ -1091,7 +1065,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         original_weight_tensor = tensor_data_dict["original_weight_tensor"]
-        input_quant_func = tensor_attributes
+        input_quant_func, = tensor_attributes
         return cls(
             original_weight_tensor,
             input_quant_func,
@@ -1176,3 +1150,6 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         raise NotImplementedError(
             f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
+
+to_aqt = AffineQuantizedTensor.from_float
+to_laqt = LinearActQuantizedTensor.from_float
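
Usage sketch (not part of the patch): the snippet below condenses how the `TensorSubclassQuantizer` / `to_aqt` API added above is exercised by the tests into a standalone example. It assumes the imports land where the diff places them; the specific `mapping_type`, `eps`, and `zero_point_dtype` values are illustrative assumptions mirroring the int8 weight-only test rather than values quoted from the diff.

```python
import torch
from torchao.quantization.quant_api import TensorSubclassQuantizer
from torchao.quantization.subclass import to_aqt, AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# any model with nn.Linear children works; _is_linear is the default filter
model = torch.nn.Sequential(torch.nn.Linear(64, 32, bias=False)).eval()

# per-row int8 weight-only quantization: block_size is derived per weight via
# get_block_size inside _get_linear_subclass_inserter, as in
# test_quantized_tensor_subclass_int8 above
# (mapping_type / eps / zero_point_dtype values here are assumed, not from the diff)
weight_quantizer = TensorSubclassQuantizer(
    to_aqt,
    mapping_type=MappingType.SYMMETRIC,
    get_block_size=lambda w: (1, w.shape[1]),
    target_dtype=torch.int8,
    eps=torch.finfo(torch.float32).eps,
    zero_point_dtype=torch.int64,
)

# swaps each Linear weight for an AffineQuantizedTensor-backed Parameter in place
model = weight_quantizer.quantize(model)
assert isinstance(model[0].weight, AffineQuantizedTensor)
```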