
[WIP] Use int_scaled_matmul with int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC) #1402

Draft · wants to merge 7 commits into base: main
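For context, a minimal sketch of how the option added in this PR might be invoked; the weight_zp_domain parameter name is taken from the diff below, and the exact API may still change while the PR is a draft:

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()

# Asymmetric per-token activation quantization with symmetric, zero-point-free
# weights, so the int_scaled_matmul path added in this PR can be dispatched.
quantize_(
    model,
    int8_dynamic_activation_int8_weight(
        act_mapping_type=MappingType.ASYMMETRIC,
        weight_zp_domain=ZeroPointDomain.NONE,
    ),
)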
36 changes: 27 additions & 9 deletions test/integration/test_integration.py
@@ -42,6 +42,7 @@
quantize_affine,
dequantize_affine,
MappingType,
ZeroPointDomain,
)
from torchao.quantization.utils import (
dequantize_per_channel,
@@ -105,6 +106,10 @@

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]

WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def _int8wo_api(mod):
@@ -125,9 +130,16 @@ def _int8wo_groupwise_api(mod):
group_size = 32
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)

def _int8da_int8w_api(mod):
def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC, weight_zero_point_domain=ZeroPointDomain.INT):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(
mod,
int8_dynamic_activation_int8_weight(
act_mapping_type=act_mapping_type,
weight_zp_domain=weight_zero_point_domain
),
set_inductor_config=False
)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -871,24 +883,30 @@ def _test_lin_weight_subclass_api_impl(
api(mod)

test = mod(x)

self.assertGreater(
SQNR(ref_f, test),
min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
min_sqnr, f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
)

mod_qc = torch.compile(mod, mode="max-autotune")
test_comp = mod_qc(x)
self.assertGreater(
SQNR(ref_f, test_comp), min_sqnr,
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
_int8da_int8w_api, device, 35, test_dtype=dtype
@parameterized.expand(
list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, ACT_MAPPING_TYPES, WEIGHT_ZERO_POINT_DOMAINS))
)
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping, weight_zero_point_domain):
from functools import partial
api = partial(
_int8da_int8w_api,
act_mapping_type=act_mapping,
weight_zero_point_domain=weight_zero_point_domain
)
self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
26 changes: 26 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -799,6 +799,32 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

# ZeroPointDomain.NONE should work
def test_none_zero_point_domain(self):
input = torch.randn(10, 256)
n_bit = 8
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = None
quant_max = None
eps = 1e-6
scale_dtype = torch.float32
zero_point_dtype = torch.int64
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.NONE,
)
self.assertTrue(zero_point is None)

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -251,7 +251,7 @@ def from_hp_to_intx(
zero_point_domain,
)
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
if zero_point_domain is None:
if zero_point_domain == ZeroPointDomain.NONE or zero_point_domain is None:
zero_point = None
data = quantize_affine(
input_float,
15 changes: 12 additions & 3 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -32,8 +32,10 @@
PlainAQTTensorImpl,
_linear_fp_act_int8_weight_check,
_linear_fp_act_int8_weight_impl,
_linear_int8_act_int8_weight_check,
_linear_int8_act_int8_weight_impl,
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl,
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
)
from torchao.dtypes.uintx.semi_sparse_layout import (
_linear_int8_act_int8_weight_semi_structured_sparse_check,
@@ -110,7 +112,14 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl
),
(
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
),
(
_linear_int8_act_int8_weight_semi_structured_sparse_check,
_linear_int8_act_int8_weight_semi_structured_sparse_impl,
72 changes: 62 additions & 10 deletions torchao/dtypes/uintx/plain_layout.py
@@ -38,7 +38,7 @@ def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
self.int_data = int_data
@@ -64,7 +64,10 @@ def __init__(
self._layout = _layout

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], [self._layout]
if self.zero_point is not None:
return ["int_data", "scale", "zero_point"], [self._layout]
else:
return ["int_data", "scale"], [self._layout]

@classmethod
def __tensor_unflatten__(
@@ -73,7 +76,7 @@ def __tensor_unflatten__(
int_data, scale, zero_point = (
tensor_data_dict["int_data"],
tensor_data_dict["scale"],
tensor_data_dict["zero_point"],
tensor_data_dict.get("zero_point", None),
)
(_layout,) = tensor_attributes
return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +86,15 @@ def to(self, *args, **kwargs):
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]) if self.zero_point is not None else None,
self._layout,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
fn(self.zero_point) if self.zero_point is not None else None,
self._layout,
)

@@ -134,7 +137,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return PlainAQTTensorImpl(
aten.slice.Tensor(self.int_data, dim, start, end, step),
self.scale.view(-1),
self.zero_point.view(-1),
self.zero_point.view(-1) if self.zero_point is not None else None,
self._layout,
)
else:
@@ -148,7 +151,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.int_data, self.scale, self.zero_point

def get_layout(self) -> Layout:
@@ -220,7 +223,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
return y


def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
Expand All @@ -231,7 +234,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
)


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
@@ -266,3 +269,52 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
if bias is not None:
y += bias
return y


def _linear_asym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8(input_tensor)
and weight_tensor.zero_point_domain == ZeroPointDomain.NONE
and isinstance(weight_tensor, AffineQuantizedTensor)
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, PlainLayout)
)


def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
#
# 2. rescale the output and apply compensation for zero point of A
#
x_vals_int8 = input_tensor.tensor_impl.int_data
x_zps = input_tensor.tensor_impl.zero_point.reshape(-1, 1)
x_scales = input_tensor.tensor_impl.scale.reshape(-1, 1)
w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
x_scales_dtype = x_scales.dtype
# Cast fp16 scale to float to avoid overflow in int_scaled_matmul
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
y_dot_scaled = int_scaled_matmul(
tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
)
y_dot_scaled = y_dot_scaled.to(x_scales_dtype) * w_scales

# Compute compensation
w_col_sum = w_vals_int8_t.to(torch.float).sum(dim=0)
a_compensation = ((x_scales * w_scales) * x_zps.to(intermediate_dtype)) * w_col_sum

y = (y_dot_scaled - a_compensation).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
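For reference, the compensation term above follows from expanding the dequantized asymmetric-activation matmul: dequant(x)[i, k] = x_scale[i] * (x_q[i, k] - x_zp[i]), so the zero-point contribution factors out as x_scale[i] * w_scale[j] * x_zp[i] * sum_k w_q_t[k, j]. A small self-contained check of this algebra (hypothetical tensors and shapes, not PR code) is sketched below:

import torch

# Hypothetical shapes/values purely to sanity-check the compensation algebra.
m, k, n = 4, 16, 8
x_q = torch.randint(-128, 127, (m, k), dtype=torch.int8)
w_q_t = torch.randint(-128, 127, (k, n), dtype=torch.int8)  # weight, already transposed
x_scale = torch.rand(m, 1)
x_zp = torch.randint(-10, 10, (m, 1))
w_scale = torch.rand(n)

# Reference: dequantize both operands, then matmul.
ref = (x_scale * (x_q.float() - x_zp)) @ (w_q_t.float() * w_scale)

# Fast path: integer matmul rescaled by the scales, minus the zero-point compensation.
w_col_sum = w_q_t.float().sum(dim=0)
fast = (x_scale * w_scale) * (x_q.float() @ w_q_t.float()) \
       - (x_scale * w_scale) * x_zp.float() * w_col_sum

torch.testing.assert_close(ref, fast, rtol=1e-4, atol=1e-3)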
5 changes: 4 additions & 1 deletion torchao/quantization/quant_api.py
@@ -739,7 +739,9 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:


def int8_dynamic_activation_int8_weight(
layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
layout=PlainLayout(),
act_mapping_type=MappingType.SYMMETRIC,
weight_zp_domain=ZeroPointDomain.INT
):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -781,6 +783,7 @@ def get_weight_block_size(x):
eps=eps,
zero_point_dtype=zero_point_dtype,
_layout=layout,
zero_point_domain=weight_zp_domain
)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight
11 changes: 8 additions & 3 deletions torchao/quantization/quant_primitives.py
@@ -888,14 +888,19 @@ def _choose_qparams_affine(
"preserve_zero == False is not supported for symmetric quantization"
)
if (
zero_point_domain is not None
zero_point_domain != ZeroPointDomain.NONE.name
and zero_point_domain != None
Contributor commented:

I feel we can probably remove support for None, since it's the same as ZeroPointDomain.NONE.name.

sanchitintel (Contributor, Author) replied on Dec 13, 2024:

Thanks again for reviewing!

Some other places in the codebase also use ZeroPointDomain.NONE.name and None separately, e.g.:

), f"dequantiztion with no zero point domain is only supported with FP8 types, got {input.dtype}"

I unified these two cases in the latest commit, but I'm not sure whether the __tensor_flatten__ and __tensor_unflatten__ methods of some classes elsewhere in the codebase may also need changes so they can handle a None zero point when TorchDynamo is used. I'll run CUDA-only UTs on my end tomorrow morning to verify.

EDIT: Haven't gotten access to an Nvidia GPU until now.

and zero_point_domain != ZeroPointDomain.INT.name
):
raise ValueError(
"zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization"
"Except a None value for zero_point_domain, Only ZeroPointDomain.NONE and ZeroPointDomain.INT"
" are supported for symmetric quantization."
)
if zero_point_domain == ZeroPointDomain.NONE.name:
Contributor commented:

Thanks for the fix; it looks like this path was not tested before. Can you add a test for the new code path?

Also, this op is becoming too complicated; we want to split it.

sanchitintel (Contributor, Author) replied:

> can you add a test for the new code path?

This case is covered by a UT I added in test/quantization/test_quant_primitives.py.

> also this op is becoming too complicated..we want to split

Please advise if you're referring to splitting _choose_qparams_affine. If so, I could split it up into smaller methods. Thanks!

jerryzh168 (Contributor) replied on Dec 13, 2024:

Yeah, I meant splitting choose_qparams_affine/quantize_affine/dequantize: not into smaller methods, but into different variations that reduce the complexity of the most common path (and remove these if/else checks). This includes removing the preserve_zero and zero_point_domain args and just having different variations of choose_qparams_affine/quantize_affine/dequantize. This should be done separately, though, since it will be a large change.
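As a rough illustration of that proposal, a sketch of per-variation wrappers is shown below; the wrapper names are hypothetical and not part of this PR or the current torchao API, and the None zero point for ZeroPointDomain.NONE assumes the behavior this PR introduces:

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

# Hypothetical per-variation wrappers around today's single op, so the common
# path needs no preserve_zero / zero_point_domain branching at the call site.
def choose_qparams_affine_symmetric(input, block_size, target_dtype, eps=None):
    # Symmetric quantization with an integer zero point (the most common path).
    return choose_qparams_affine(
        input, MappingType.SYMMETRIC, block_size, target_dtype,
        eps=eps, zero_point_domain=ZeroPointDomain.INT,
    )

def choose_qparams_affine_symmetric_no_zero_point(input, block_size, target_dtype, eps=None):
    # Symmetric quantization with no zero point; only the scale is returned.
    scale, _ = choose_qparams_affine(
        input, MappingType.SYMMETRIC, block_size, target_dtype,
        eps=eps, zero_point_domain=ZeroPointDomain.NONE,
    )
    return scale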

zero_point = None
else:
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
scale = torch.clamp(scale, min=eps)
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)