diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 43d57b7d12..0939c49f5d 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -16,7 +16,7 @@
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
 )
-from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
@@ -24,7 +24,9 @@
 )


-def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
+def get_quantization_functions(
+    do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
+):
     base_functions = [
         int8_weight_only(),
         int8_dynamic_activation_int4_weight(),
@@ -36,6 +38,14 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu
             base_functions.append(
                 int4_weight_only(group_size=32, layout=Int4CPULayout())
             )
+            if int4_zp_int:
+                base_functions.append(
+                    int4_weight_only(
+                        group_size=32,
+                        layout=Int4CPULayout(),
+                        zero_point_domain=ZeroPointDomain.INT,
+                    )
+                )
         else:
             base_functions.append(int4_weight_only(group_size=32))

@@ -71,7 +81,9 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    @common_utils.parametrize(
+        "apply_quant", get_quantization_functions(True, True, "cuda", True)
+    )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         ql = apply_quant(linear)
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index a3fef29fea..102e76cb1a 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -57,7 +57,13 @@ def check_idempotent(self, fn, *args, **kwargs):


 # Legacy tinygemm ops
-def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
+def _get_groupwise_affine_qparams(
+    w,
+    n_bit=4,
+    groupsize=128,
+    dtype=torch.bfloat16,
+    zero_point_domain=ZeroPointDomain.FLOAT,
+):
     if groupsize > w.shape[-1]:
         groupsize = w.shape[-1]
     assert groupsize > 1
@@ -70,21 +76,25 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1
     max_val = to_quant.amax(dim=1, keepdim=True)
     min_val = to_quant.amin(dim=1, keepdim=True)
     max_int = 2**n_bit - 1
+    quant_min = 0
+    quant_max = max_int
     scales = (max_val - min_val).clamp(min=1e-6) / max_int
-    zeros = min_val + scales * (2 ** (n_bit - 1))
-    return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
-        dtype=dtype
-    ).reshape(w.shape[0], -1)
+    if zero_point_domain == ZeroPointDomain.FLOAT:
+        zeros = min_val + scales * (2 ** (n_bit - 1))
+        zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
+    else:
+        zeros = quant_min - torch.round(min_val / scales)
+        zeros = torch.clamp(zeros, quant_min, quant_max)
+        zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
+    scales = scales.to(dtype=dtype).reshape(w.shape[0], -1)
+    return scales, zeros


 def _groupwise_affine_quantize_tensor_from_qparams(
-    w,
-    scales,
-    zeros,
-    n_bit=4,
-    groupsize=128,
+    w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
 ):
     assert groupsize > 1
+    assert n_bit == 4
     # needed for GPTQ single column quantize
     if groupsize > w.shape[-1] and scales.shape[-1] == 1:
         groupsize = w.shape[-1]
@@ -97,17 +107,28 @@ def _groupwise_affine_quantize_tensor_from_qparams(
     scales = scales.reshape(-1, 1)
     zeros = zeros.reshape(-1, 1)
-    min_val = zeros - scales * (2 ** (n_bit - 1))
     max_int = 2**n_bit - 1
     min_int = 0
-    w_int4x8 = (
-        to_quant.sub(min_val)
-        .div(scales)
-        .round()
-        .clamp_(min_int, max_int)
-        .to(torch.int32)
-        .reshape_as(w)
-    )
+    if zero_point_domain == ZeroPointDomain.FLOAT:
+        min_val = zeros - scales * (2 ** (n_bit - 1))
+        w_int4x8 = (
+            to_quant.sub(min_val)
+            .div(scales)
+            .round()
+            .clamp_(min_int, max_int)
+            .to(torch.int32)
+            .reshape_as(w)
+        )
+    else:
+        w_int4x8 = (
+            to_quant.div(scales)
+            .round()
+            .add(zeros)
+            .clamp_(min_int, max_int)
+            .to(torch.int32)
+            .reshape_as(w)
+        )
+
     if TORCH_VERSION_AT_LEAST_2_5:
         if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
             w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
@@ -121,6 +142,7 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
     zeros,
     n_bit=4,
     groupsize=128,
+    zero_point_domain=ZeroPointDomain.FLOAT,
 ):
     assert groupsize > 1
     # needed for GPTQ single column dequantize
@@ -133,12 +155,15 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
     scales = scales.reshape(-1, 1)
     zeros = zeros.reshape(-1, 1)

-    w_dq = (
-        w_int4x8_grouped.sub(2 ** (n_bit - 1))
-        .mul(scales)
-        .add(zeros)
-        .reshape_as(w_int4x8)
-    )
+    if zero_point_domain == ZeroPointDomain.FLOAT:
+        w_dq = (
+            w_int4x8_grouped.sub(2 ** (n_bit - 1))
+            .mul(scales)
+            .add(zeros)
+            .reshape_as(w_int4x8)
+        )
+    else:
+        w_dq = w_int4x8_grouped.sub(zeros).mul(scales).reshape_as(w_int4x8)
     return w_dq


@@ -650,10 +675,8 @@ def test_not_preserve_zero_not_supported(self):
     def test_get_groupwise_affine_qparams(self):
         input = torch.randn(10, 256)
         n_bit = 4
-        scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
-            input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16
-        )

+        zero_point_domains = [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]
         mapping_type = MappingType.ASYMMETRIC
         dtype = torch.int8
         block_size = (1, 128)
@@ -662,19 +685,27 @@ def test_get_groupwise_affine_qparams(self):
         eps = 1e-6
         scale_dtype = torch.bfloat16
         zero_point_dtype = torch.bfloat16
-        scale, zero_point = choose_qparams_affine(
-            input,
-            mapping_type,
-            block_size,
-            dtype,
-            quant_min,
-            quant_max,
-            eps,
-            scale_dtype=scale_dtype,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=False,
-            zero_point_domain=ZeroPointDomain.FLOAT,
-        )
+        for zero_point_domain in zero_point_domains:
+            scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
+                input,
+                n_bit=n_bit,
+                groupsize=128,
+                dtype=torch.bfloat16,
+                zero_point_domain=zero_point_domain,
+            )
+            scale, zero_point = choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=zero_point_domain == ZeroPointDomain.INT,
+                zero_point_domain=zero_point_domain,
+            )

         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zero_point_ref))
@@ -686,14 +717,15 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self):
         n_bit = 4
         groupsize = 128

-        w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
-            input, scales, zeros, n_bit, groupsize
-        )
-        w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
-            input, scales, zeros, n_bit, groupsize
-        )
+        for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
+            w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
+                input, scales, zeros, n_bit, groupsize, zero_point_domain
+            )
+            w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
+                input, scales, zeros, n_bit, groupsize, zero_point_domain
+            )

-        self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
+            self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))

     def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
@@ -702,20 +734,27 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         n_bit = 4
         groupsize = 128

-        if TORCH_VERSION_AT_LEAST_2_5:
-            input_tmp = input
-            if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
-                input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
-                input_tmp, scales, zeros, n_bit, groupsize
+        for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
+            if zero_point_domain == ZeroPointDomain.INT:
+                zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
+            if TORCH_VERSION_AT_LEAST_2_5:
+                input_tmp = input
+                if not (
+                    is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
+                ):
+                    input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+                w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
+                    input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
+                )
+            else:
+                if zero_point_domain == ZeroPointDomain.INT:
+                    continue
+                w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
+                    input, scales, zeros, n_bit, groupsize
+                )
+            w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
+                input, scales, zeros, n_bit, groupsize, zero_point_domain
             )
-        else:
-            w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
-                input, scales, zeros, n_bit, groupsize
-            )
-        w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
-            input, scales, zeros, n_bit, groupsize
-        )

         self.assertTrue(torch.equal(w_bf16, w_bf16_ref))
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 80f4dd689c..360831c76d 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -202,6 +202,17 @@ We also have a unified quantized tensor subclass that implements how to get a qu
 #### Layouts
 We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layout. `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and also as a default layout. `tensor_core_tiled` layout is used for `int4_weight_only` quantization and is packing the weights in a format that is compatible with tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.

+### Zero Point Domains
+```ZeroPointDomain``` controls the data type of the zero points. ```ZeroPointDomain.NONE``` means the zero_point is None, ```ZeroPointDomain.FLOAT``` means the zero_point is in the floating-point domain, and ```ZeroPointDomain.INT``` means the zero_point is in the integer domain. For the detailed handling of each zero point domain, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py).
+The following support matrix shows which zero point domains each layout supports; it may change as backend support evolves:
+
+|Layout|None(Symmetric)|Float|Int|
+|------|---------------|-----|---|
+|TensorCoreTiledLayout| Yes | Yes(Default) | No |
+|Int4CPULayout | Yes | Yes(Default) | No |
+|MarlinSparseLayout | No | No | Yes(Default) |
+
+
 ### Full Affine Quantization Flow Example
 Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul as an example:
@@ -239,6 +250,8 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 group_size = 32
 # only works for torch 2.4+
 quantize_(m, int4_weight_only(group_size=group_size))
+## If a different zero_point_domain is needed
+# quantize_(m, int4_weight_only(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))

 # temporary workaround for tensor subclass + torch.compile
 # NOTE: this is only need for torch version < 2.5+
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 03fb8812b1..5161f1657e 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -28,6 +28,7 @@
 from torchao.dtypes import (
     AffineQuantizedTensor,
     Float8Layout,
+    Int4CPULayout,
     MarlinQQQLayout,
     MarlinSparseLayout,
     PlainLayout,
@@ -115,6 +116,19 @@
     "Int8DynActInt4WeightGPTQQuantizer",
 ]

+# update according to the support matrix
+LAYOUT_TO_ZERO_POINT_DOMAIN = {
+    TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
+    MarlinSparseLayout: [ZeroPointDomain.INT],
+    Int4CPULayout: [ZeroPointDomain.FLOAT],
+}
+
+LAYOUT_TO_PRESERVE_ZEROS = {
+    TensorCoreTiledLayout: False,
+    MarlinSparseLayout: True,
+    Int4CPULayout: False,
+}
+
 ######
 # TO BE DEPRECATED START
@@ -662,7 +676,10 @@ def gemlite_uintx_weight_only(


 def int4_weight_only(
-    group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
+    group_size=128,
+    layout=TensorCoreTiledLayout(inner_k_tiles=8),
+    use_hqq=False,
+    zero_point_domain=None,
 ):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -682,6 +699,7 @@
            size is more fine grained, choices are [256, 128, 64, 32]
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
+        `zero_point_domain`: data type of zero points, choices are [None (then the value is determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
     """

     def apply_int4_weight_only_quant(weight):
@@ -697,17 +715,26 @@ def apply_int4_weight_only_quant(weight):
         quant_min = 0
         quant_max = 15
         eps = 1e-6
-        preserve_zero = False
+        preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
         zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
+
+        nonlocal zero_point_domain
+        assert (
+            type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
+        ), f"Only support layouts: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
+        if zero_point_domain is None:
+            # the first value is the default one
+            zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
+        else:
+            assert (
+                zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
+            ), f"Layout only supports {LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]}"

         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,
         # we should consider moving this logic somewhere else.
         if isinstance(layout, MarlinSparseLayout):
             mapping_type = MappingType.SYMMETRIC
-            preserve_zero = True
-            zero_point_domain = ZeroPointDomain.INT
             assert (
                 group_size == 128 or group_size == weight.shape[-1]
             ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index e1cf98b549..74c136ad00 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -373,11 +373,7 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):


 def groupwise_affine_quantize_tensor_from_qparams(
-    w,
-    scales,
-    zeros,
-    n_bit=4,
-    groupsize=128,
+    w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
 ):
     assert groupsize > 1
     # needed for GPTQ single column quantize
@@ -400,7 +396,7 @@ def groupwise_affine_quantize_tensor_from_qparams(
         output_dtype,
         quant_min,
         quant_max,
-        zero_point_domain=ZeroPointDomain.FLOAT,
+        zero_point_domain=zero_point_domain,
     )
     if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
         if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
@@ -414,6 +410,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
     zeros,
     n_bit=4,
     groupsize=128,
+    zero_point_domain=ZeroPointDomain.FLOAT,
 ):
     assert groupsize > 1
     assert w_int4x8.dim() == 2
@@ -452,7 +449,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
         input_dtype,
         quant_min,
         quant_max,
-        zero_point_domain=ZeroPointDomain.FLOAT,
+        zero_point_domain=zero_point_domain,
         output_dtype=scales.dtype,
     )
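
As a supplement to the diff above, here is a minimal usage sketch of the `zero_point_domain` argument added to `int4_weight_only`. It assumes a build of torchao with this patch applied and a CUDA device (the default `TensorCoreTiledLayout` targets the tinygemm kernels); the model `m` is illustrative only.

```python
# Sketch only: exercises the zero_point_domain argument added to int4_weight_only above.
# Assumes torchao with this patch applied and a CUDA device for TensorCoreTiledLayout.
import torch

from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()

# zero_point_domain=None (the default) picks the first entry of
# LAYOUT_TO_ZERO_POINT_DOMAIN for the layout, i.e. FLOAT for TensorCoreTiledLayout.
# Passing a domain the layout does not support trips the assert added in this diff.
quantize_(m, int4_weight_only(group_size=32, zero_point_domain=ZeroPointDomain.FLOAT))
```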
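The arithmetic difference between the two domains, as implemented by the reference helpers in `test_quant_primitives.py` above, can be summarized in a short standalone sketch. All names here are local to the example (single group, 4-bit); this mirrors the unpacked reference path, not the packed kernels.

```python
# Standalone sketch of the two zero point domains for one 4-bit group, mirroring
# _get_groupwise_affine_qparams / *_from_qparams above (not the packed kernels).
import torch

w = torch.randn(1, 32)
n_bit = 4
quant_min, quant_max = 0, 2**n_bit - 1
min_val, max_val = w.amin(dim=1, keepdim=True), w.amax(dim=1, keepdim=True)
scale = (max_val - min_val).clamp(min=1e-6) / quant_max

# ZeroPointDomain.INT: the zero point is an integer value in [quant_min, quant_max].
zp_int = torch.clamp(quant_min - torch.round(min_val / scale), quant_min, quant_max)
q_int = torch.clamp(torch.round(w / scale) + zp_int, quant_min, quant_max)
dq_int = (q_int - zp_int) * scale  # approximately reconstructs w

# ZeroPointDomain.FLOAT (tinygemm convention): the zero point is a floating point
# offset, and the mid-point 2**(n_bit - 1) is subtracted before scaling back.
zp_float = min_val + scale * (2 ** (n_bit - 1))
q_float = torch.clamp(torch.round((w - min_val) / scale), quant_min, quant_max)
dq_float = (q_float - 2 ** (n_bit - 1)) * scale + zp_float  # approximately reconstructs w
```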