diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index f3a1074ba5..8d0af8b369 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -25,7 +25,9 @@
 )
 from torchao.kernel.intmm import int_scaled_matmul
 from .utils import find_multiple
-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, Dict, Any
+from collections import defaultdict
+import functools
 
 
 __all__ = [
@@ -627,6 +629,81 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
         int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
 
+def to_aqt(
+    input_float,
+    mapping_type,
+    block_size,
+    target_dtype,
+    quant_min=None,
+    quant_max=None,
+    eps=None,
+    scale_dtype=None,
+    zero_point_dtype=None,
+    preserve_zero=True,
+    zero_point_domain=ZeroPointDomain.INT,
+):
+    """Create an `AffineQuantizedTensor` from `input_float`.
+
+    Thin functional wrapper: every argument is forwarded unchanged to
+    `AffineQuantizedTensor.from_float`, which defines its meaning.
+    """
+    return AffineQuantizedTensor.from_float(
+        input_float,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        eps=eps,
+        scale_dtype=scale_dtype,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+    )
+
+# TODO: merge with nf4 implements decorator
+# Per tensor subclass: aten op -> its __torch_dispatch__ implementation
+_ATEN_OPS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)
+
+def implements_aten_ops(cls, aten_ops):
+    """Use this decorator to implement a function for an aten op in __torch_dispatch__
+
+    The decorated function is registered under `cls` for every op in
+    `aten_ops` and is returned unchanged.
+    """
+
+    def decorator(func):
+        for op in aten_ops:
+            _ATEN_OPS_TABLE[cls][op] = func
+        return func
+
+    return decorator
+
+# Per tensor subclass: torch function -> its __torch_function__ implementation
+_TORCH_FUNCTIONS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)
+
+def implements_torch_function(cls, torch_function):
+    """Use this decorator to implement `torch_function` in __torch_function__
+
+    Note: `functools.update_wrapper` copies the metadata of `torch_function`
+    (`__name__`, `__doc__`, ...) onto the decorated function.
+    """
+
+    def decorator(func):
+        functools.update_wrapper(func, torch_function)
+        _TORCH_FUNCTIONS_TABLE[cls][torch_function] = func
+        return func
+
+    return decorator
+
+def implements_aqt_aten_ops(aten_ops):
+    """Shorthand for `implements_aten_ops(AffineQuantizedTensor, aten_ops)`."""
+    return implements_aten_ops(AffineQuantizedTensor, aten_ops)
+
+def implements_aqt_torch_function(torch_function):
+    """Shorthand for `implements_torch_function(AffineQuantizedTensor, torch_function)`."""
+    return implements_torch_function(AffineQuantizedTensor, torch_function)
+
 
 class AffineQuantizedTensor(torch.Tensor):
     """
@@ -772,101 +849,8 @@ def from_float(
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         kwargs = {} if kwargs is None else kwargs
 
-        if func is torch.nn.functional.linear:
-            input_tensor, weight_qtensor, bias = (
-                args[0],
-                args[1],
-                args[2] if len(args) > 2 else None,
-            )
-            is_cuda = weight_qtensor.is_cuda
-            is_cpu = weight_qtensor.device == torch.device("cpu")
-            if isinstance(weight_qtensor, AffineQuantizedTensor):
-                weight_is_int8 = _aqt_is_int8(weight_qtensor)
-                weight_is_uint4 = _aqt_is_uint4(weight_qtensor)
-
-                if isinstance(input_tensor, AffineQuantizedTensor):
-                    # if input tensor is quantized, either dispatch to the int8 mm kernel
-                    # or just dequantize the input tensor
-                    input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
-                    input_tensor_dtype_is_expected = input_tensor.dtype in [
-                        torch.float,
-                        torch.bfloat16
-                    ]
-                    if (
-                        is_cuda and
-                        input_is_int8 and
-                        input_tensor_dtype_is_expected
-                    ):
-                        #
-                        # 1. do the matrix form of dot(X_i, W_j)
-                        #
-                        #
-                        # 2. rescale the output
-                        #
-                        # in cases with large matrices, y_dot_int32 can grow sufficiently
-                        # large that y_dot_int32 * a float16 scale is greater than the maximum
-                        # value of a float 16, (which results in a value of inf even if multiplying
-                        # by the other scale would bring it within the expected range)
-
-                        x_vals_int8 = input_tensor.int_data
-                        x_scales = input_tensor.scale
-                        w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
-                        w_scales = weight_qtensor.scale
-                        tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-                        y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
-
-                        y = (y_dot_scaled * w_scales).reshape(
-                            *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
-                        )
-
-                        # can downcast only at the very end
-                        output_dtype = input_tensor.dtype
-                        y = y.to(output_dtype)
-                        if bias is not None:
-                            y += bias
-                        return y
-                    else:
-                        input_tensor = input_tensor.dequantize()
-
-                # weight only quantization
-                # TODO: enable cpu and mps path as well
-                # TODO: make sure weight dimension matches the expectation of the int4mm kernel
-                # TODO: move this to TinygemmAffineQuantizedTensor
-                if (
-                    is_cuda and
-                    weight_is_uint4 and
-                    weight_qtensor.dtype == torch.bfloat16 and
-                    len(weight_qtensor.shape) == 2 and
-                    weight_qtensor.block_size[0] == 1 and
-                    weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
-                ):
-                    # groupwise int4 quantization
-                    # TODO: currently doing packing on the fly, we'll need to figure out
-                    # the API to do packing before hand
-                    # TODO: expose the arg
-                    innerKTiles = 8
-                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
-                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
-                    groupsize = weight_qtensor.block_size[-1]
-                    return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
-                elif (
-                    is_cpu and
-                    weight_is_int8 and
-                    len(weight_qtensor.shape) == 2 and
-                    len(weight_qtensor.block_size) == 2 and
-                    weight_qtensor.block_size[0] == 1 and
-                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
-                ):
-                    # TODO: enable mps path as well
-                    # per channel int8 weight only quantizated mm
-                    return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
-                else:
-                    weight_tensor = weight_qtensor.dequantize()
-                    return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-            else:
-                if isinstance(input_tensor, AffineQuantizedTensor):
-                    input_tensor = input_tensor.dequantize()
-                return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+        if func in _TORCH_FUNCTIONS_TABLE[cls]:
+            return _TORCH_FUNCTIONS_TABLE[cls][func](*args, **kwargs)
 
         with torch._C.DisableTorchFunctionSubclass():
             return func(*args, **kwargs)
@@ -927,62 +911,177 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
         #     kernels in CPU as well, see the note above
         # 2 - we're given non-floats - quantizing long to int8 is crazy
-        if (
-            func in [aten.mm.default, aten.addmm.default]
-            and args[0].is_floating_point()
-            and args[0].device == torch.device("cpu")
-        ):
-            if func == aten.addmm.default:
-                assert args[1].shape[-1] == args[2].shape[0], (
-                    f"need mat1 shape: {args[1].shape} final"
-                    f"dim to match mat2 shape: {args[2].shape} first dim "
-                )
-                input_tensor, weight_qtensor, bias = (
-                    args[1],
-                    args[2],
-                    args[0],
+
+        if func in _ATEN_OPS_TABLE[cls]:
+            return _ATEN_OPS_TABLE[cls][func](func, *args, **kwargs)
+
+        raise NotImplementedError(
+            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
+        )
+
+@implements_aqt_torch_function(torch.nn.functional.linear)
+def functional_linear(*args, **kwargs):
+    """`torch.nn.functional.linear` override for AffineQuantizedTensor operands.
+
+    Dispatches to the int8 / int4 kernels when the operands qualify, otherwise
+    dequantizes and falls back to the regular `torch.nn.functional.linear`.
+    """
+    input_tensor, weight_qtensor = args[0], args[1]
+    # bias may arrive positionally or as a keyword (`F.linear(x, w, bias=b)`)
+    bias = args[2] if len(args) > 2 else kwargs.get("bias")
+    is_cuda = weight_qtensor.is_cuda
+    is_cpu = weight_qtensor.device == torch.device("cpu")
+    if isinstance(weight_qtensor, AffineQuantizedTensor):
+        weight_is_int8 = _aqt_is_int8(weight_qtensor)
+        weight_is_uint4 = _aqt_is_uint4(weight_qtensor)
+
+        if isinstance(input_tensor, AffineQuantizedTensor):
+            # if input tensor is quantized, either dispatch to the int8 mm kernel
+            # or just dequantize the input tensor
+            input_is_int8 = _aqt_is_int8_reduced_range(input_tensor)
+            input_tensor_dtype_is_expected = input_tensor.dtype in [
+                torch.float,
+                torch.bfloat16
+            ]
+            if (
+                is_cuda and
+                input_is_int8 and
+                input_tensor_dtype_is_expected
+            ):
+                #
+                # 1. do the matrix form of dot(X_i, W_j)
+                #
+                #
+                # 2. rescale the output
+                #
+                # in cases with large matrices, y_dot_int32 can grow sufficiently
+                # large that y_dot_int32 * a float16 scale is greater than the maximum
+                # value of a float 16, (which results in a value of inf even if multiplying
+                # by the other scale would bring it within the expected range)
+
+                x_vals_int8 = input_tensor.int_data
+                x_scales = input_tensor.scale
+                w_vals_int8_t = weight_qtensor.int_data.contiguous().t()
+                w_scales = weight_qtensor.scale
+                tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
+                y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
+
+                y = (y_dot_scaled * w_scales).reshape(
+                    *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
                 )
+
+                # can downcast only at the very end
+                output_dtype = input_tensor.dtype
+                y = y.to(output_dtype)
+                if bias is not None:
+                    y += bias
+                return y
             else:
-                assert args[0].shape[-1] == args[1].shape[0], (
-                    f"need mat1 shape: {args[0].shape} final dim"
-                    f"to match mat2 shape: {args[1].shape} first dim"
-                )
-                input_tensor, weight_qtensor, bias = (
-                    args[0],
-                    args[1],
-                    None if len(args) == 2 else args[2],
-                )
+                input_tensor = input_tensor.dequantize()
+
+        # weight only quantization
+        # TODO: enable cpu and mps path as well
+        # TODO: make sure weight dimension matches the expectation of the int4mm kernel
+        # TODO: move this to TinygemmAffineQuantizedTensor
+        if (
+            is_cuda and
+            weight_is_uint4 and
+            weight_qtensor.dtype == torch.bfloat16 and
+            len(weight_qtensor.shape) == 2 and
+            weight_qtensor.block_size[0] == 1 and
+            weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
+        ):
+            # groupwise int4 quantization
+            # TODO: currently doing packing on the fly, we'll need to figure out
+            # the API to do packing before hand
+            # TODO: expose the arg
+            innerKTiles = 8
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
+            scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
+            groupsize = weight_qtensor.block_size[-1]
+            return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
+        elif (
+            is_cpu and
+            weight_is_int8 and
+            len(weight_qtensor.shape) == 2 and
+            len(weight_qtensor.block_size) == 2 and
+            weight_qtensor.block_size[0] == 1 and
+            weight_qtensor.block_size[1] == weight_qtensor.shape[1]
+        ):
+            # TODO: enable mps path as well
+            # per channel int8 weight only quantized mm
+            return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
+        else:
             weight_tensor = weight_qtensor.dequantize()
-            return func(input_tensor, weight_tensor, bias)
+            return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
+    else:
+        # the weight is not an AffineQuantizedTensor here, so only the input can need dequantizing
+        if isinstance(input_tensor, AffineQuantizedTensor):
+            input_tensor = input_tensor.dequantize()
+        return torch.nn.functional.linear(input_tensor, weight_qtensor, bias)
+
+
+@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
+def aten_mm(func, *args, **kwargs):
+    """Fallback for `aten.mm` / `aten.addmm`: dequantize the weight, then run `func`.
+
+    `aten.addmm` is called as `(bias, mat1, mat2)` and `aten.mm` as
+    `(mat1, mat2)`; the dequantized weight takes the place of `mat2`.
+    """
+    if not args[0].is_floating_point():
+        raise NotImplementedError(f"{func} is not implemented for non floating point input")
+
+    if func == aten.addmm.default:
+        assert args[1].shape[-1] == args[2].shape[0], (
+            f"need mat1 shape: {args[1].shape} final dim "
+            f"to match mat2 shape: {args[2].shape} first dim"
+        )
+        bias, input_tensor, weight_qtensor = args[0], args[1], args[2]
+        # addmm's positional order is (bias, mat1, mat2); kwargs carry beta/alpha
+        return func(bias, input_tensor, weight_qtensor.dequantize(), **kwargs)
+
+    assert args[0].shape[-1] == args[1].shape[0], (
+        f"need mat1 shape: {args[0].shape} final dim "
+        f"to match mat2 shape: {args[1].shape} first dim"
+    )
+    input_tensor, weight_qtensor = args[0], args[1]
+    # mm takes exactly two tensors, so no bias slot is passed
+    return func(input_tensor, weight_qtensor.dequantize())
 
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
+@implements_aqt_aten_ops([aten.detach.default])
+def detach(func, *args, **kwargs):
+    """`aten.detach`: apply `torch.detach` to the tensor's data via `_apply_fn_to_data`."""
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
 
-        if func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
 
-        if func is aten.t.default:
-            # TODO: need to implement this
-            # args[0].transposed = not args[0].transposed
-            # new = args[0]._change_shape(args[0].shape[::-1])
-            # return return_and_correct_aliasing(func, args, kwargs, new)
-            raise Exception("transpose not implemented yet")
+@implements_aqt_aten_ops([aten.clone.default])
+def clone(func, *args, **kwargs):
+    """`aten.clone`: apply `torch.clone` to the tensor's data via `_apply_fn_to_data`."""
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
 
-        if func is aten._to_copy.default:
-            return return_and_correct_aliasing(
-                func,
-                args,
-                kwargs,
-                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
-            )
 
-        raise NotImplementedError(
-            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
-        )
+@implements_aqt_aten_ops([aten._to_copy.default])
+def _to_copy(func, *args, **kwargs):
+    """`aten._to_copy`: convert with `.to(...)`, then clone the resulting tensor's data."""
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+@implements_aqt_aten_ops([aten.t.default])
+def t(func, *args, **kwargs):
+    """`aten.t`: transpose is not supported yet, so this always raises."""
+    # TODO: need to implement this
+    # args[0].transposed = not args[0].transposed
+    # new = args[0]._change_shape(args[0].shape[::-1])
+    # return return_and_correct_aliasing(func, args, kwargs, new)
+    raise NotImplementedError("transpose not implemented yet")
 
 
 class LinearActQuantizedTensor(torch.Tensor):