From 23db9bf72a544773ec5fcc9f8b2bc36bc4dcc17d Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Fri, 6 Dec 2024 16:48:06 -0800
Subject: [PATCH] Move MarlinQQQTensor out of AQT (#1385)

---
 torchao/dtypes/README.md                      | 19 ++++++
 torchao/dtypes/__init__.py                    |  4 +-
 torchao/dtypes/affine_quantized_tensor.py     | 56 -----------------
 torchao/dtypes/affine_quantized_tensor_ops.py |  2 +-
 torchao/dtypes/uintx/__init__.py              |  6 +-
 ...lin_qqq_layout.py => marlin_qqq_tensor.py} | 62 +++++++++++++++++++
 6 files changed, 89 insertions(+), 60 deletions(-)
 create mode 100644 torchao/dtypes/README.md
 rename torchao/dtypes/uintx/{marlin_qqq_layout.py => marlin_qqq_tensor.py} (79%)

diff --git a/torchao/dtypes/README.md b/torchao/dtypes/README.md
new file mode 100644
index 0000000000..c1124c648f
--- /dev/null
+++ b/torchao/dtypes/README.md
@@ -0,0 +1,19 @@
+# README
+
+## File Structure of the `dtypes` Folder
+
+The `dtypes` folder contains several important files and subfolders that are organized as follows:
+
+- **affine_quantized_tensor.py**: This is the main file, from which the tensor subclasses in the `uintx` and `floatx` subfolders inherit. It contains the base tensor subclass `AffineQuantizedTensor` and the code for layout and TensorImpl registration.
+
+- **affine_quantized_tensor_ops.py**: This file defines all the overridden aten ops and the different dispatch kernels related to affine quantized tensors.
+
+- **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder.
+
+- **nf4tensor.py**: This file contains the NF4 tensor implementation and its layouts.
+
+### Subfolders
+
+- **uintx**: A subfolder that contains layouts and tensor subclasses inheriting from `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors.
+
+- **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that inherit from `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors.
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 00305db348..c7d98cb56e 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -1,14 +1,12 @@
 from . import affine_quantized_tensor_ops
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
-    MarlinQQQTensor,
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     # experimental, will be merged into floatx in the future
     to_affine_quantized_fpx,
     to_affine_quantized_intx,
     to_affine_quantized_intx_static,
-    to_marlinqqq_quantized_intx,
 )
 from .floatx import (
     Float8Layout,
@@ -18,10 +16,12 @@
     BlockSparseLayout,
     Int4CPULayout,
     MarlinQQQLayout,
+    MarlinQQQTensor,
     MarlinSparseLayout,
     SemiSparseLayout,
     TensorCoreTiledLayout,
     UintxLayout,
+    to_marlinqqq_quantized_intx,
 )
 from .utils import (
     Layout,
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 93d2766d1e..7aca25ecc5 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -16,10 +16,8 @@
     choose_qparams_affine,
     choose_qparams_affine_floatx,
     choose_qparams_and_quantize_affine_hqq,
-    choose_qparams_and_quantize_affine_qqq,
     dequantize_affine,
     dequantize_affine_floatx,
-    dequantize_affine_qqq,
     quantize_affine,
     quantize_affine_floatx,
 )
@@ -33,14 +31,12 @@
 __all__ = [
     "AffineQuantizedTensor",
-    "MarlinQQQTensor",
     "register_layout",
     "to_affine_quantized_intx",
     "to_affine_quantized_floatx",
     "to_affine_quantized_intx_static",
     "to_affine_quantized_floatx_static",
     "to_affine_quantized_fpx",
-    "to_marlinqqq_quantized_intx",
 ]
@@ -459,57 +455,6 @@ def _apply_fn_to_data(self, fn):
 # 2 - we're given non-floats - quantizing long to int8 is crazy
-class MarlinQQQTensor(AffineQuantizedTensor):
-    """
-    MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class.
-
-    To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization,
-    please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
-    and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq
-    """
-
-    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-        if output_dtype is None:
-            output_dtype = self.dtype
-
-        int_data, s_group, s_channel = self.tensor_impl.get_plain()
-        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
-        group_size = max(self.block_size)
-        return dequantize_affine_qqq(
-            int_data, s_group, s_channel, nbits, group_size, output_dtype
-        )
-
-    @classmethod
-    def from_hp_to_intx(
-        cls,
-        input_float: torch.Tensor,
-        block_size: Tuple[int, ...],
-        quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
-        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
-        _layout: Optional[Layout] = None,
-    ):
-        original_shape = input_float.shape
-        input_float = _layout.pre_process(input_float)
-        nbits = int(math.log2(quant_max - quant_min + 1))
-        group_size = max(block_size)
-        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
-            input_float, nbits, group_size
-        )
-        data = _layout.post_process(data)
-        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
-        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
-        return cls(
-            tensor_impl,
-            block_size,
-            original_shape,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-            dtype=input_float.dtype,
-        )
-
-
 ######################################################
 # Layout and TensorImpl Subclass Registration        #
 ######################################################
@@ -522,7 +467,6 @@ def from_hp_to_intx(
 to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
 # experimental will be merged in to floatx
 to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
-to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx

 if TORCH_VERSION_AT_LEAST_2_5:
     # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index bd7ff7d333..8938e7472c 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -20,7 +20,7 @@
     _linear_int8_act_int8_weight_block_sparse_check,
     _linear_int8_act_int8_weight_block_sparse_impl,
 )
-from torchao.dtypes.uintx.marlin_qqq_layout import (
+from torchao.dtypes.uintx.marlin_qqq_tensor import (
     _linear_int8_act_int4_weight_marlin_qqq_check,
     _linear_int8_act_int4_weight_marlin_qqq_impl,
 )
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index 8fba2bb678..4b1f3d39c8 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -1,8 +1,10 @@
 from .block_sparse_layout import (
     BlockSparseLayout,
 )
-from .marlin_qqq_layout import (
+from .marlin_qqq_tensor import (
     MarlinQQQLayout,
+    MarlinQQQTensor,
+    to_marlinqqq_quantized_intx,
 )
 from .marlin_sparse_layout import (
     MarlinSparseLayout,
@@ -26,4 +28,6 @@
     "TensorCoreTiledLayout",
     "Int4CPULayout",
     "MarlinQQQLayout",
+    "MarlinQQQTensor",
+    "to_marlinqqq_quantized_intx",
 ]
diff --git a/torchao/dtypes/uintx/marlin_qqq_layout.py b/torchao/dtypes/uintx/marlin_qqq_tensor.py
similarity index 79%
rename from torchao/dtypes/uintx/marlin_qqq_layout.py
rename to torchao/dtypes/uintx/marlin_qqq_tensor.py
index c3b2a78394..b75d959b41 100644
--- a/torchao/dtypes/uintx/marlin_qqq_layout.py
+++ b/torchao/dtypes/uintx/marlin_qqq_tensor.py
@@ -1,5 +1,7 @@
 import logging
+import math
 from dataclasses import dataclass
+from typing import Optional, Tuple

 import torch
 from torch.utils._python_dispatch import (
@@ -8,18 +10,75 @@
 from torchao.dtypes.affine_quantized_tensor import (
     AffineQuantizedTensor,
+    get_tensor_impl_constructor,
     register_layout,
 )
 from torchao.dtypes.uintx.plain_layout import (
     _aqt_is_int8_reduced_range,
 )
 from torchao.dtypes.utils import AQTTensorImpl, Layout
+from torchao.quantization.quant_primitives import (
+    ZeroPointDomain,
+    choose_qparams_and_quantize_affine_qqq,
+    dequantize_affine_qqq,
+)

 logger = logging.getLogger(__name__)

 aten = torch.ops.aten


+class MarlinQQQTensor(AffineQuantizedTensor):
+    """
+    MarlinQQQ quantized tensor subclass, which inherits from the AffineQuantizedTensor class.
+
+    To see what happens during quantization and dequantization for Marlin QQQ quantization, please check out
+    https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py and look at the two quant
+    primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq.
+    """
+
+    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if output_dtype is None:
+            output_dtype = self.dtype
+
+        int_data, s_group, s_channel = self.tensor_impl.get_plain()
+        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
+        group_size = max(self.block_size)
+        return dequantize_affine_qqq(
+            int_data, s_group, s_channel, nbits, group_size, output_dtype
+        )
+
+    @classmethod
+    def from_hp_to_intx(
+        cls,
+        input_float: torch.Tensor,
+        block_size: Tuple[int, ...],
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
+        _layout: Optional[Layout] = None,
+    ):
+        original_shape = input_float.shape
+        input_float = _layout.pre_process(input_float)
+        nbits = int(math.log2(quant_max - quant_min + 1))
+        group_size = max(block_size)
+        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
+            input_float, nbits, group_size
+        )
+        data = _layout.post_process(data)
+        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
+        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
+        return cls(
+            tensor_impl,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
+
 @dataclass(frozen=True)
 class MarlinQQQLayout(Layout):
     pass
@@ -279,3 +338,6 @@ def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bi
     if bias is not None:
         out += bias.to(out.dtype)
     return out
+
+
+to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx
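For reviewers, here is a minimal usage sketch of the relocated entry point after this change. The weight shapes, group size, and int4 quantization range below are illustrative assumptions, not values taken from this patch, and the Marlin QQQ packing path is assumed to require a CUDA build of torchao:

```python
# Hypothetical usage sketch for the relocated MarlinQQQTensor API.
# Assumptions: a CUDA build of torchao with the Marlin QQQ kernels available,
# plus illustrative shapes, group size, and int4 range.
import torch

from torchao.dtypes import MarlinQQQLayout, to_marlinqqq_quantized_intx

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
group_size = 128  # per-group scales along the input dimension

# to_marlinqqq_quantized_intx is MarlinQQQTensor.from_hp_to_intx: it derives
# nbits from the quant range, quantizes with choose_qparams_and_quantize_affine_qqq,
# and packs the result into the Marlin QQQ layout's tensor impl.
qweight = to_marlinqqq_quantized_intx(
    weight,
    block_size=(1, group_size),  # one scale group per `group_size` input channels
    quant_min=-8,  # illustrative int4 range; nbits is computed as
    quant_max=7,   # int(log2(quant_max - quant_min + 1)) = 4
    _layout=MarlinQQQLayout(),
)

# The subclass round-trips back to the original dtype via dequantize_affine_qqq.
w_dq = qweight.dequantize()
```

Note that after this patch the same names also import from `torchao.dtypes.uintx`; only the old module path `torchao.dtypes.uintx.marlin_qqq_layout` goes away, since the file is renamed to `marlin_qqq_tensor.py`.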