Skip to content

Commit

Permalink
Add to function and decorator for AffineQuantizedTensor (pytorch#251)
Browse files Browse the repository at this point in the history
Summary:
att
Next: we can move AffineQuantizedTensor to dtypes and make nf4tensor to use the same implements decorator

Test Plan:
python test/quantization/test_quant_api.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored May 17, 2024
1 parent 9b25ecc commit 5741aa2
Showing 1 changed file with 218 additions and 144 deletions.
362 changes: 218 additions & 144 deletions torchao/quantization/subclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
)
from torchao.kernel.intmm import int_scaled_matmul
from .utils import find_multiple
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, Dict, Any
from collections import defaultdict
import functools


__all__ = [
Expand Down Expand Up @@ -627,6 +629,63 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles

def to_aqt(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min = None,
quant_max = None,
eps = None,
scale_dtype = None,
zero_point_dtype = None,
preserve_zero = True,
zero_point_domain = ZeroPointDomain.INT,
):
return AffineQuantizedTensor.from_float(
input_float,
mapping_type,
block_size,
target_dtype,
quant_min=quant_min,
quant_max=quant_max,
eps=eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain
)

# TODO: merge with nf4 implements decorator
# aten op to their __torch_dispatch__ implemnetations for the tensor subclass
_ATEN_OPS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)

def implements_aten_ops(cls, aten_ops):
"""Use this decorator to implement a function for an aten op in __torch_dispatch__"""

def decorator(func):
for op in aten_ops:
_ATEN_OPS_TABLE[cls][op] = func
return func

return decorator

_TORCH_FUNCTIONS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)

def implements_torch_function(cls, torch_function):
def decorator(func):
functools.update_wrapper(func, torch_function)
_TORCH_FUNCTIONS_TABLE[cls][torch_function] = func
return func

return decorator

def implements_aqt_aten_ops(aten_ops):
return implements_aten_ops(AffineQuantizedTensor, aten_ops)

def implements_aqt_torch_function(torch_function):
return implements_torch_function(AffineQuantizedTensor, torch_function)


class AffineQuantizedTensor(torch.Tensor):
"""
Expand Down Expand Up @@ -772,101 +831,8 @@ def from_float(
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func is torch.nn.functional.linear:
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
weight_is_int8 = _aqt_is_int8(weight_qtensor)
weight_is_uint4 = _aqt_is_uint4(weight_qtensor)

if isinstance(input_tensor, AffineQuantizedTensor):
# if input tensor is quantized, either dispatch to the int8 mm kernel
# or just dequantize the input tensor
input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
input_tensor_dtype_is_expected = input_tensor.dtype in [
torch.float,
torch.bfloat16
]
if (
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected
):
#
# 1. do the matrix form of dot(X_i, W_j)
#
#
# 2. rescale the output
#
# in cases with large matrices, y_dot_int32 can grow sufficiently
# large that y_dot_int32 * a float16 scale is greater than the maximum
# value of a float 16, (which results in a value of inf even if multiplying
# by the other scale would bring it within the expected range)

x_vals_int8 = input_tensor.int_data
x_scales = input_tensor.scale
w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
w_scales = weight_qtensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))

y = (y_dot_scaled * w_scales).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
else:
input_tensor = input_tensor.dequantize()

# weight only quantization
# TODO: enable cpu and mps path as well
# TODO: make sure weight dimension matches the expectation of the int4mm kernel
# TODO: move this to TinygemmAffineQuantizedTensor
if (
is_cuda and
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
):
# groupwise int4 quantization
# TODO: currently doing packing on the fly, we'll need to figure out
# the API to do packing before hand
# TODO: expose the arg
innerKTiles = 8
packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
groupsize = weight_qtensor.block_size[-1]
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
elif (
is_cpu and
weight_is_int8 and
len(weight_qtensor.shape) == 2 and
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1]
):
# TODO: enable mps path as well
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
else:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
if func in _TORCH_FUNCTIONS_TABLE[cls]:
return _TORCH_FUNCTIONS_TABLE[cls][func](*args, **kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
Expand Down Expand Up @@ -927,62 +893,170 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].device == torch.device("cpu")
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
args[1],
args[2],
args[0],

if func in _ATEN_OPS_TABLE[cls]:
return _ATEN_OPS_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)

@implements_aqt_torch_function(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
is_cuda = weight_qtensor.is_cuda
is_cpu = weight_qtensor.device == torch.device("cpu")
if isinstance(weight_qtensor, AffineQuantizedTensor):
weight_is_int8 = _aqt_is_int8(weight_qtensor)
weight_is_uint4 = _aqt_is_uint4(weight_qtensor)

if isinstance(input_tensor, AffineQuantizedTensor):
# if input tensor is quantized, either dispatch to the int8 mm kernel
# or just dequantize the input tensor
input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
input_tensor_dtype_is_expected = input_tensor.dtype in [
torch.float,
torch.bfloat16
]
if (
is_cuda and
input_is_int8 and
input_tensor_dtype_is_expected
):
#
# 1. do the matrix form of dot(X_i, W_j)
#
#
# 2. rescale the output
#
# in cases with large matrices, y_dot_int32 can grow sufficiently
# large that y_dot_int32 * a float16 scale is greater than the maximum
# value of a float 16, (which results in a value of inf even if multiplying
# by the other scale would bring it within the expected range)

x_vals_int8 = input_tensor.int_data
x_scales = input_tensor.scale
w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
w_scales = weight_qtensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))

y = (y_dot_scaled * w_scales).reshape(
*x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
)

# can downcast only at the very end
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
return y
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
)
input_tensor = input_tensor.dequantize()

# weight only quantization
# TODO: enable cpu and mps path as well
# TODO: make sure weight dimension matches the expectation of the int4mm kernel
# TODO: move this to TinygemmAffineQuantizedTensor
if (
is_cuda and
weight_is_uint4 and
weight_qtensor.dtype == torch.bfloat16 and
len(weight_qtensor.shape) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
):
# groupwise int4 quantization
# TODO: currently doing packing on the fly, we'll need to figure out
# the API to do packing before hand
# TODO: expose the arg
innerKTiles = 8
packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
groupsize = weight_qtensor.block_size[-1]
return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
elif (
is_cpu and
weight_is_int8 and
len(weight_qtensor.shape) == 2 and
len(weight_qtensor.block_size) == 2 and
weight_qtensor.block_size[0] == 1 and
weight_qtensor.block_size[1] == weight_qtensor.shape[1]
):
# TODO: enable mps path as well
# per channel int8 weight only quantizated mm
return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
else:
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
else:
if isinstance(input_tensor, AffineQuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
args[1],
args[2],
args[0],
)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
)
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)
@implements_aqt_aten_ops([aten.detach.default])
def detach(func, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten.t.default:
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
raise Exception("transpose not implemented yet")
@implements_aqt_aten_ops([aten.clone.default])
def clone(func, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten._to_copy.default:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

raise NotImplementedError(
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)
@implements_aqt_aten_ops([aten._to_copy.default])
def _to_copy(func, *args, **kwargs):
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

@implements_aqt_aten_ops([aten.t.default])
def t(func, *args, **kwargs):
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
raise Exception("transpose not implemented yet")


class LinearActQuantizedTensor(torch.Tensor):
Expand Down

0 comments on commit 5741aa2

Please sign in to comment.