Expose zero_point_domain as arguments (#1401)
* export zero_point_domain as arguments

* assert for combination of TensorCoreTiledLayout and integer zero points

* change the default zero_point_domain to None

* maintain layout and zero_point_domain in a dict

* nit

* fix key errors

* nit

* add zero_point_domain arguments in documentation

* update documents

* Apply automatic Ruff fixes

---------

Co-authored-by: Ruff Auto-fixes <[email protected]>
airMeng and ruff authored Dec 17, 2024
1 parent 9472a17 commit ace7219
Showing 5 changed files with 164 additions and 76 deletions.
18 changes: 15 additions & 3 deletions test/dtypes/test_affine_quantized.py
@@ -16,15 +16,17 @@
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
)


def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
def get_quantization_functions(
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
):
base_functions = [
int8_weight_only(),
int8_dynamic_activation_int4_weight(),
@@ -36,6 +38,14 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu
base_functions.append(
int4_weight_only(group_size=32, layout=Int4CPULayout())
)
if int4_zp_int:
base_functions.append(
int4_weight_only(
group_size=32,
layout=Int4CPULayout(),
zero_point_domain=ZeroPointDomain.INT,
)
)
else:
base_functions.append(int4_weight_only(group_size=32))

@@ -71,7 +81,9 @@ def test_tensor_core_layout_transpose(self):
self.assertEqual(aqt_shape, shape)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
@common_utils.parametrize(
"apply_quant", get_quantization_functions(True, True, "cuda", True)
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
161 changes: 100 additions & 61 deletions test/quantization/test_quant_primitives.py
@@ -57,7 +57,13 @@ def check_idempotent(self, fn, *args, **kwargs):


# Legacy tinygemm ops
def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):
def _get_groupwise_affine_qparams(
w,
n_bit=4,
groupsize=128,
dtype=torch.bfloat16,
zero_point_domain=ZeroPointDomain.FLOAT,
):
if groupsize > w.shape[-1]:
groupsize = w.shape[-1]
assert groupsize > 1
@@ -70,21 +76,25 @@ def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat1
max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
max_int = 2**n_bit - 1
quant_min = 0
quant_max = max_int
scales = (max_val - min_val).clamp(min=1e-6) / max_int
zeros = min_val + scales * (2 ** (n_bit - 1))
return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to(
dtype=dtype
).reshape(w.shape[0], -1)
if zero_point_domain == ZeroPointDomain.FLOAT:
zeros = min_val + scales * (2 ** (n_bit - 1))
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
else:
zeros = quant_min - torch.round(min_val / scales)
zeros = torch.clamp(zeros, quant_min, quant_max)
zeros = zeros.to(dtype=dtype).reshape(w.shape[0], -1)
scales = scales.to(dtype=dtype).reshape(w.shape[0], -1)
return scales, zeros


def _groupwise_affine_quantize_tensor_from_qparams(
w,
scales,
zeros,
n_bit=4,
groupsize=128,
w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
):
assert groupsize > 1
assert n_bit == 4
# needed for GPTQ single column quantize
if groupsize > w.shape[-1] and scales.shape[-1] == 1:
groupsize = w.shape[-1]
@@ -97,17 +107,28 @@ def _groupwise_affine_quantize_tensor_from_qparams(

scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
min_val = zeros - scales * (2 ** (n_bit - 1))
max_int = 2**n_bit - 1
min_int = 0
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
min_val = zeros - scales * (2 ** (n_bit - 1))
w_int4x8 = (
to_quant.sub(min_val)
.div(scales)
.round()
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)
else:
w_int4x8 = (
to_quant.div(scales)
.round()
.add(zeros)
.clamp_(min_int, max_int)
.to(torch.int32)
.reshape_as(w)
)

if TORCH_VERSION_AT_LEAST_2_5:
if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
@@ -121,6 +142,7 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
zeros,
n_bit=4,
groupsize=128,
zero_point_domain=ZeroPointDomain.FLOAT,
):
assert groupsize > 1
# needed for GPTQ single column dequantize
@@ -133,12 +155,15 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)

w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
if zero_point_domain == ZeroPointDomain.FLOAT:
w_dq = (
w_int4x8_grouped.sub(2 ** (n_bit - 1))
.mul(scales)
.add(zeros)
.reshape_as(w_int4x8)
)
else:
w_dq = w_int4x8_grouped.sub(zeros).mul(scales).reshape_as(w_int4x8)
return w_dq


@@ -650,10 +675,8 @@ def test_not_preserve_zero_not_supported(self):
def test_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16
)

zero_point_domains = [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]
mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
@@ -662,19 +685,27 @@ def test_get_groupwise_affine_qparams(self):
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
zero_point_domain=ZeroPointDomain.FLOAT,
)
for zero_point_domain in zero_point_domains:
scale_ref, zero_point_ref = _get_groupwise_affine_qparams(
input,
n_bit=n_bit,
groupsize=128,
dtype=torch.bfloat16,
zero_point_domain=zero_point_domain,
)
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=zero_point_domain == ZeroPointDomain.INT,
zero_point_domain=zero_point_domain,
)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))
@@ -686,14 +717,15 @@ def test_groupwise_affine_quantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)

self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))
self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref))

def test_groupwise_affine_dequantize_tensor_from_qparams(self):
input = torch.randint(0, 15, (10, 256), dtype=torch.int32)
@@ -702,20 +734,27 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
n_bit = 4
groupsize = 128

if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize
for zero_point_domain in [ZeroPointDomain.FLOAT, ZeroPointDomain.INT]:
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
if TORCH_VERSION_AT_LEAST_2_5:
input_tmp = input
if not (
is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
else:
if zero_point_domain == ZeroPointDomain.INT:
continue
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize, zero_point_domain
)
else:
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)
w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(
input, scales, zeros, n_bit, groupsize
)

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

13 changes: 13 additions & 0 deletions torchao/quantization/README.md
@@ -202,6 +202,17 @@ We also have a unified quantized tensor subclass that implements how to get a qu
#### Layouts
We extended the `layout` concept to represent different packing formats for a tensor. `AffineQuantizedTensor` supports `plain` and `tensor_core_tiled` layouts. The `plain` layout is used for `int8_weight_only` and `int8_dynamic_activation_int8_weight` and is also the default layout. The `tensor_core_tiled` layout is used for `int4_weight_only` quantization and packs the weights in a format that is compatible with the tinygemm [int4mm](https://github.com/pytorch/pytorch/blob/39357ba06f48cda7d293a4995aa5eba2a46598b5/aten/src/ATen/native/native_functions.yaml#L4138) kernels.
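
For example, a layout can be selected by passing it to the quantization workflow. Here is a minimal sketch; the toy module is illustrative only, and it assumes a CUDA device since `TensorCoreTiledLayout` targets the tinygemm CUDA kernels:

```python
import torch
from torchao.dtypes import TensorCoreTiledLayout
from torchao.quantization import int4_weight_only, quantize_

# toy module purely for illustration
m = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
)
# explicitly request the tinygemm-compatible packing (this is also the default layout)
quantize_(m, int4_weight_only(group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8)))
```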

### Zero Point Domains
```ZeroPointDomain``` controls the data type and domain of the zero points. ```ZeroPointDomain.NONE``` means the zero_point is None (symmetric quantization), ```ZeroPointDomain.FLOAT``` means the zero_point lives in the floating-point domain, and ```ZeroPointDomain.INT``` means the zero_point is an integer in the quantized domain. For the exact arithmetic of each convention, refer to [the reference implementation](../../test/quantization/test_quant_primitives.py).
The following support matrix shows which zero point domains each layout supports; it may change as backend support evolves. A short sketch of the two non-trivial conventions follows the table:

|Layout|None (symmetric)|Float|Int|
|------|----------------|-----|---|
|TensorCoreTiledLayout| Yes | Yes (default) | No |
|Int4CPULayout | Yes | Yes (default) | No |
|MarlinSparseLayout | No | No | Yes (default) |
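
For intuition, here is a minimal sketch of how the two non-trivial domains change 4-bit dequantization. The helper names are illustrative only (not torchao APIs); the arithmetic mirrors the reference helpers in [`test/quantization/test_quant_primitives.py`](../../test/quantization/test_quant_primitives.py):

```python
import torch

def dequant_float_zero_domain(q, scale, zero_point, n_bit=4):
    # ZeroPointDomain.FLOAT (tinygemm convention): zero_point is a floating-point
    # value that already folds in the mid-point offset of the quantized range.
    return (q.to(torch.float32) - 2 ** (n_bit - 1)) * scale + zero_point

def dequant_int_zero_domain(q, scale, zero_point):
    # ZeroPointDomain.INT (classic affine convention): zero_point is an integer
    # inside the quantized range and is subtracted before scaling.
    return (q.to(torch.float32) - zero_point) * scale
```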


### Full Affine Quantization Flow Example
Let's use int4 weight only quantization that's targeting tinygemm int4 weight only quantized matmul
as an example:
@@ -239,6 +250,8 @@ m_bf16 = torch.compile(m_bf16, mode='max-autotune')
group_size = 32
# only works for torch 2.4+
quantize_(m, int4_weight_only(group_size=group_size))
## If a different zero_point_domain is needed
# quantize_(m, int4_weight_only(group_size=group_size, zero_point_domain=ZeroPointDomain.FLOAT))
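## Note: the chosen zero_point_domain must be one the layout supports (see the support matrix
## in the Zero Point Domains section above), e.g. MarlinSparseLayout only accepts ZeroPointDomain.INT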

# temporary workaround for tensor subclass + torch.compile
# NOTE: this is only need for torch version < 2.5+
37 changes: 32 additions & 5 deletions torchao/quantization/quant_api.py
@@ -28,6 +28,7 @@
from torchao.dtypes import (
AffineQuantizedTensor,
Float8Layout,
Int4CPULayout,
MarlinQQQLayout,
MarlinSparseLayout,
PlainLayout,
@@ -115,6 +116,19 @@
"Int8DynActInt4WeightGPTQQuantizer",
]

# update according to the support matrix
LAYOUT_TO_ZERO_POINT_DOMAIN = {
TensorCoreTiledLayout: [ZeroPointDomain.FLOAT],
MarlinSparseLayout: [ZeroPointDomain.INT],
Int4CPULayout: [ZeroPointDomain.FLOAT],
}

LAYOUT_TO_PRESERVE_ZEROS = {
TensorCoreTiledLayout: False,
MarlinSparseLayout: True,
Int4CPULayout: False,
}


######
# TO BE DEPRECATED START
@@ -662,7 +676,10 @@ def gemlite_uintx_weight_only(


def int4_weight_only(
group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
group_size=128,
layout=TensorCoreTiledLayout(inner_k_tiles=8),
use_hqq=False,
zero_point_domain=None,
):
"""
Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
@@ -682,6 +699,7 @@
size is more fine grained, choices are [256, 128, 64, 32]
`layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
`use_hqq`: whether to use hqq or default quantization mode, default is False
`zero_point_domain`: data type of the zero points; choices are [None (the value is then determined by the layout), ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
"""

def apply_int4_weight_only_quant(weight):
@@ -697,17 +715,26 @@ def apply_int4_weight_only_quant(weight):
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = False
preserve_zero = LAYOUT_TO_PRESERVE_ZEROS[type(layout)]
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.FLOAT

nonlocal zero_point_domain
assert (
type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys()
), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}"
if zero_point_domain is None:
# the first value is the default one
zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0]
else:
assert (
zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)]
), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}"

# Sparse Marlin only supports symmetric quantization.
# NOTE: If we start having lots of layouts that require different configurations,
# we should consider moving this logic somewhere else.
if isinstance(layout, MarlinSparseLayout):
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
assert (
group_size == 128 or group_size == weight.shape[-1]
), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}"