[WIP] Use int_scaled_matmul with int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC) #1402

Draft
wants to merge 7 commits into base: main
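A minimal usage sketch of the path this PR exercises — assuming quantize_ and int8_dynamic_activation_int8_weight are importable from torchao.quantization, and the act_mapping_type keyword shown in the diff below:

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.quantization.quant_primitives import MappingType

# Dynamically quantize activations to asymmetric int8 and weights to symmetric int8.
# The asymmetric activation case is the one this PR routes through int_scaled_matmul
# plus a zero-point compensation term.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()
quantize_(model, int8_dynamic_activation_int8_weight(act_mapping_type=MappingType.ASYMMETRIC))
out = model(torch.randn(8, 1024))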
36 changes: 27 additions & 9 deletions test/integration/test_integration.py
@@ -42,6 +42,7 @@
quantize_affine,
dequantize_affine,
MappingType,
ZeroPointDomain,
)
from torchao.quantization.utils import (
dequantize_per_channel,
@@ -105,6 +106,10 @@

COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]

ACT_MAPPING_TYPES = [MappingType.ASYMMETRIC, MappingType.SYMMETRIC]

WEIGHT_ZERO_POINT_DOMAINS = [ZeroPointDomain.NONE, ZeroPointDomain.INT]

COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()

def _int8wo_api(mod):
@@ -125,9 +130,16 @@ def _int8wo_groupwise_api(mod):
group_size = 32
quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)

def _int8da_int8w_api(mod):
def _int8da_int8w_api(mod, act_mapping_type=MappingType.SYMMETRIC, weight_zero_point_domain=ZeroPointDomain.INT):
if TORCH_VERSION_AT_LEAST_2_4:
quantize_(mod, int8_dynamic_activation_int8_weight(), set_inductor_config=False)
quantize_(
mod,
int8_dynamic_activation_int8_weight(
act_mapping_type=act_mapping_type,
weight_zp_domain=weight_zero_point_domain
),
set_inductor_config=False
)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)
else:
@@ -871,24 +883,30 @@ def _test_lin_weight_subclass_api_impl(
api(mod)

test = mod(x)

self.assertGreater(
SQNR(ref_f, test),
min_sqnr, f"{api.__name__} failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
min_sqnr, f"API failed, no compile dtype={test_dtype}, (m, k, n)={test_shape}"
)

mod_qc = torch.compile(mod, mode="max-autotune")
test_comp = mod_qc(x)
self.assertGreater(
SQNR(ref_f, test_comp), min_sqnr,
f"{api.__name__} failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
f"API failed when compiled with dtype={test_dtype}, (m, k, n)={test_shape}"
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
def test_int8_dynamic_quant_subclass_api(self, device, dtype):
self._test_lin_weight_subclass_api_impl(
_int8da_int8w_api, device, 35, test_dtype=dtype
@parameterized.expand(
list(itertools.product(COMMON_DEVICES, COMMON_DTYPES, ACT_MAPPING_TYPES, WEIGHT_ZERO_POINT_DOMAINS))
)
def test_int8_dynamic_quant_subclass_api(self, device, dtype, act_mapping, weight_zero_point_domain):
from functools import partial
api = partial(
_int8da_int8w_api,
act_mapping_type=act_mapping,
weight_zero_point_domain=weight_zero_point_domain
)
self._test_lin_weight_subclass_api_impl(api, device, 35, test_dtype=dtype)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(is_fbcode(), "broken in fbcode")
15 changes: 7 additions & 8 deletions test/quantization/test_observer.py
@@ -21,6 +21,7 @@
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)


@@ -74,7 +75,7 @@ def test_block_size_calc_success(self):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
example_inputs = [
torch.randn(10, 2048),
@@ -93,7 +94,7 @@ def test_block_size_calc_success(self):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
for example_input in example_inputs:
obs(example_input)
@@ -108,7 +109,7 @@ def test_block_size_row_errors(self):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
example_inputs = [
torch.randn(10, 2048),
@@ -127,7 +128,7 @@ def test_block_size_row_errors(self):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
example_inputs = [
torch.randn(10, 2048),
@@ -155,7 +156,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
if observe_weight:
weight_observer = AffineQuantizedMinMaxObserver(
@@ -165,7 +166,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
)
else:
weight_observer = None
@@ -199,7 +200,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
input_scale.item(),
max_val / max_fp8,
)
self.assertIsNotNone(input_zero_point)

if observe_weight:
weight_observer = linear.weight.weight_observer
@@ -210,7 +210,6 @@ def test_linear_observer_tensor(self, observe_weight: bool):
atol=5e-5,
rtol=0.0,
)
self.assertIsNotNone(weight_zero_point)
else:
self.assertIsNone(linear.weight.weight_observer)

26 changes: 26 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -799,6 +799,32 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

# ZeroPointDomain.NONE should work
def test_none_zero_point_domain(self):
input = torch.randn(10, 256)
n_bit = 8
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = None
quant_max = None
eps = 1e-6
scale_dtype = torch.float32
zero_point_dtype = torch.int64
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.NONE,
)
self.assertTrue(zero_point is None)

if __name__ == "__main__":
unittest.main()
6 changes: 3 additions & 3 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -251,7 +251,7 @@ def from_hp_to_intx(
zero_point_domain,
)
# choose_qparams_affine is a custom op that does not support returning optional Tensors. We thus set the zero_point to None here if its domain is NONE
if zero_point_domain is None:
if zero_point_domain == ZeroPointDomain.NONE or zero_point_domain is None:
zero_point = None
data = quantize_affine(
input_float,
@@ -349,7 +349,7 @@ def from_hp_to_floatx(
scale_dtype=scale_dtype,
zero_point_dtype=None,
preserve_zero=True,
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
_layout=_layout,
use_hqq=False,
)
@@ -376,7 +376,7 @@ def from_hp_to_floatx_static(
target_dtype=target_dtype,
quant_min=math.ceil(torch.finfo(target_dtype).min),
quant_max=math.ceil(torch.finfo(target_dtype).max),
zero_point_domain=None,
zero_point_domain=ZeroPointDomain.NONE,
_layout=_layout,
)
else:
15 changes: 12 additions & 3 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -32,8 +32,10 @@
PlainAQTTensorImpl,
_linear_fp_act_int8_weight_check,
_linear_fp_act_int8_weight_impl,
_linear_int8_act_int8_weight_check,
_linear_int8_act_int8_weight_impl,
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl,
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
)
from torchao.dtypes.uintx.semi_sparse_layout import (
_linear_int8_act_int8_weight_semi_structured_sparse_check,
@@ -110,7 +112,14 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
# so that these can be shared by F.linear, aten.mm, aten.addmm dispatches
def _register_aqt_quantized_linear_dispatches():
for dispatch_condition, impl in [
(_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl),
(
_linear_sym_int8_act_sym_int8_weight_check,
_linear_sym_int8_act_sym_int8_weight_impl
),
(
_linear_asym_int8_act_sym_int8_weight_check,
_linear_asym_int8_act_sym_int8_weight_impl
),
(
_linear_int8_act_int8_weight_semi_structured_sparse_check,
_linear_int8_act_int8_weight_semi_structured_sparse_impl,
72 changes: 62 additions & 10 deletions torchao/dtypes/uintx/plain_layout.py
@@ -38,7 +38,7 @@ def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
kwargs = {}
@@ -55,7 +55,7 @@ def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
self.int_data = int_data
@@ -64,7 +64,7 @@ def __init__(
self._layout = _layout

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point"], [self._layout]
if self.zero_point is not None:
return ["int_data", "scale", "zero_point"], [self._layout]
else:
return ["int_data", "scale"], [self._layout]

@classmethod
def __tensor_unflatten__(
@@ -73,7 +76,7 @@ def __tensor_unflatten__(
int_data, scale, zero_point = (
tensor_data_dict["int_data"],
tensor_data_dict["scale"],
tensor_data_dict["zero_point"],
tensor_data_dict.get("zero_point", None),
)
(_layout,) = tensor_attributes
return cls(int_data, scale, zero_point, _layout)
@@ -83,15 +86,15 @@ def to(self, *args, **kwargs):
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]) if self.zero_point is not None else None,
self._layout,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
fn(self.zero_point) if self.zero_point is not None else None,
self._layout,
)

@@ -134,7 +137,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return PlainAQTTensorImpl(
aten.slice.Tensor(self.int_data, dim, start, end, step),
self.scale.view(-1),
self.zero_point.view(-1),
self.zero_point.view(-1) if self.zero_point is not None else None,
self._layout,
)
else:
@@ -148,7 +151,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

__torch_function__ = torch._C._disabled_torch_function_impl

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.int_data, self.scale, self.zero_point

def get_layout(self) -> Layout:
@@ -220,7 +223,7 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):
return y


def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8_reduced_range(input_tensor)
@@ -231,7 +234,7 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
)


def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
def _linear_sym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
@@ -266,3 +269,52 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
if bias is not None:
y += bias
return y


def _linear_asym_int8_act_sym_int8_weight_check(input_tensor, weight_tensor, bias):
return (
isinstance(input_tensor, AffineQuantizedTensor)
and _aqt_is_int8(input_tensor)
and isinstance(weight_tensor, AffineQuantizedTensor)
and weight_tensor.zero_point_domain == ZeroPointDomain.NONE
and input_tensor.dtype == weight_tensor.dtype
and isinstance(input_tensor._layout, PlainLayout)
and isinstance(weight_tensor._layout, PlainLayout)
)


def _linear_asym_int8_act_sym_int8_weight_impl(input_tensor, weight_tensor, bias):
#
# 1. do the matrix form of dot(X_i, W_j)
#
#
# 2. rescale the output and apply compensation for zero point of A
#
x_vals_int8 = input_tensor.tensor_impl.int_data
x_zps = input_tensor.tensor_impl.zero_point.reshape(-1, 1)
x_scales = input_tensor.tensor_impl.scale.reshape(-1, 1)
w_vals_int8_t = weight_tensor.tensor_impl.int_data.contiguous().t()
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
x_scales_dtype = x_scales.dtype
# Cast fp16 scale to float to avoid overflow in int_scaled_matmul
intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype
y_dot_scaled = int_scaled_matmul(
tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
)
y_dot_scaled = y_dot_scaled.to(x_scales_dtype) * w_scales

# Compute compensation
w_col_sum = w_vals_int8_t.to(torch.float).sum(dim=0)
a_compensation = ((x_scales * w_scales) * x_zps.to(intermediate_dtype)) * w_col_sum

y = (y_dot_scaled - a_compensation).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
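For reference, a sketch of the dequantization algebra the asymmetric-activation implementation above appears to follow, assuming per-row activation scale $s_{x,i}$ and zero point $z_{x,i}$ and per-output-channel symmetric weight scale $s_{w,j}$ (symbol names are illustrative, not taken from the code):

Y_{ij} \approx \sum_k s_{x,i}\,(X^{q}_{ik} - z_{x,i})\; s_{w,j}\, W^{q}_{jk}
       = s_{x,i}\, s_{w,j} \sum_k X^{q}_{ik} W^{q}_{jk} \;-\; s_{x,i}\, s_{w,j}\, z_{x,i} \sum_k W^{q}_{jk}

The first term corresponds to y_dot_scaled (the int_scaled_matmul output rescaled by w_scales); the second term corresponds to a_compensation, built from x_zps and the per-column sum w_col_sum.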