
Move MarlinQQQTensor out of AQT (#1385)
jainapurva authored Dec 7, 2024
1 parent 8a805d0 commit 23db9bf
Showing 6 changed files with 89 additions and 60 deletions.
19 changes: 19 additions & 0 deletions torchao/dtypes/README.md
@@ -0,0 +1,19 @@
# README

## File Structure of the `dtypes` Folder

The `dtypes` folder contains several important files and subfolders, organized as follows (a brief usage sketch follows the list):

- **affine_quantized_tensor.py**: This is the main file, which the tensor subclasses in the `uintx` and `floatx` subfolders build on. It contains the base tensor subclass `AffineQuantizedTensor` and the code for layout and TensorImpl registration.

- **affine_quantized_tensor_ops.py**: This file defines all the overridden aten ops and the different dispatch kernels related to affine quantized tensors.

- **utils.py**: A utility file that provides helper functions and common utilities used across different files in the `dtypes` folder.

- **nf4tensor.py**: This file contains the NF4 tensor implementation and its layouts.

### Subfolders

- **uintx**: A subfolder containing layouts and tensor subclasses that build on `affine_quantized_tensor.py`. It is specialized for handling unsigned integer quantized tensors.

- **floatx**: Similar to `uintx`, this subfolder contains layouts and tensor subclasses that build on `affine_quantized_tensor.py`, but it is focused on floating-point quantized tensors.
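
To make the layering above concrete, here is a minimal usage sketch (an editorial illustration, not part of the committed README); only the import paths are taken from this commit's `__init__.py` files, the rest is illustrative:

```python
# Editorial sketch: exercising the public import surface after this commit.
# The import paths come from the __init__.py files in this diff; the
# issubclass check is only an illustration of how the pieces relate.
from torchao.dtypes import (
    AffineQuantizedTensor,  # base subclass, defined in affine_quantized_tensor.py
    MarlinQQQLayout,  # layout re-exported from the uintx subfolder
    MarlinQQQTensor,  # now defined in torchao/dtypes/uintx/marlin_qqq_tensor.py
    to_marlinqqq_quantized_intx,  # alias for MarlinQQQTensor.from_hp_to_intx
)

# MarlinQQQTensor still builds on the shared AffineQuantizedTensor base class.
assert issubclass(MarlinQQQTensor, AffineQuantizedTensor)
```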
4 changes: 2 additions & 2 deletions torchao/dtypes/__init__.py
@@ -1,14 +1,12 @@
from . import affine_quantized_tensor_ops
from .affine_quantized_tensor import (
    AffineQuantizedTensor,
    MarlinQQQTensor,
    to_affine_quantized_floatx,
    to_affine_quantized_floatx_static,
    # experimental, will be merged into floatx in the future
    to_affine_quantized_fpx,
    to_affine_quantized_intx,
    to_affine_quantized_intx_static,
    to_marlinqqq_quantized_intx,
)
from .floatx import (
    Float8Layout,
@@ -18,10 +16,12 @@
    BlockSparseLayout,
    Int4CPULayout,
    MarlinQQQLayout,
    MarlinQQQTensor,
    MarlinSparseLayout,
    SemiSparseLayout,
    TensorCoreTiledLayout,
    UintxLayout,
    to_marlinqqq_quantized_intx,
)
from .utils import (
    Layout,
56 changes: 0 additions & 56 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -16,10 +16,8 @@
    choose_qparams_affine,
    choose_qparams_affine_floatx,
    choose_qparams_and_quantize_affine_hqq,
    choose_qparams_and_quantize_affine_qqq,
    dequantize_affine,
    dequantize_affine_floatx,
    dequantize_affine_qqq,
    quantize_affine,
    quantize_affine_floatx,
)
@@ -33,14 +31,12 @@

__all__ = [
    "AffineQuantizedTensor",
    "MarlinQQQTensor",
    "register_layout",
    "to_affine_quantized_intx",
    "to_affine_quantized_floatx",
    "to_affine_quantized_intx_static",
    "to_affine_quantized_floatx_static",
    "to_affine_quantized_fpx",
    "to_marlinqqq_quantized_intx",
]


@@ -459,57 +455,6 @@ def _apply_fn_to_data(self, fn):
# 2 - we're given non-floats - quantizing long to int8 is crazy


class MarlinQQQTensor(AffineQuantizedTensor):
    """MarlinQQQ quantized tensor subclass, which inherits from the AffineQuantizedTensor class.

    To see what happens during choose_qparams_and_quantize_affine_qqq, quantization, and dequantization
    for Marlin QQQ quantization, please check out
    https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
    and look at the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq.
    """

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        if output_dtype is None:
            output_dtype = self.dtype

        int_data, s_group, s_channel = self.tensor_impl.get_plain()
        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
        group_size = max(self.block_size)
        return dequantize_affine_qqq(
            int_data, s_group, s_channel, nbits, group_size, output_dtype
        )

    @classmethod
    def from_hp_to_intx(
        cls,
        input_float: torch.Tensor,
        block_size: Tuple[int, ...],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        _layout: Optional[Layout] = None,
    ):
        original_shape = input_float.shape
        input_float = _layout.pre_process(input_float)
        nbits = int(math.log2(quant_max - quant_min + 1))
        group_size = max(block_size)
        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
            input_float, nbits, group_size
        )
        data = _layout.post_process(data)
        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
        return cls(
            tensor_impl,
            block_size,
            original_shape,
            quant_min,
            quant_max,
            zero_point_domain,
            dtype=input_float.dtype,
        )


######################################################
# Layout and TensorImpl Subclass Registration #
######################################################
@@ -522,7 +467,6 @@ def from_hp_to_intx(
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
# experimental, will be merged into floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx
to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx

if TORCH_VERSION_AT_LEAST_2_5:
    # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True`
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor_ops.py
@@ -20,7 +20,7 @@
    _linear_int8_act_int8_weight_block_sparse_check,
    _linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.marlin_qqq_layout import (
from torchao.dtypes.uintx.marlin_qqq_tensor import (
    _linear_int8_act_int4_weight_marlin_qqq_check,
    _linear_int8_act_int4_weight_marlin_qqq_impl,
)
6 changes: 5 additions & 1 deletion torchao/dtypes/uintx/__init__.py
@@ -1,8 +1,10 @@
from .block_sparse_layout import (
    BlockSparseLayout,
)
from .marlin_qqq_layout import (
from .marlin_qqq_tensor import (
    MarlinQQQLayout,
    MarlinQQQTensor,
    to_marlinqqq_quantized_intx,
)
from .marlin_sparse_layout import (
    MarlinSparseLayout,
@@ -26,4 +28,6 @@
    "TensorCoreTiledLayout",
    "Int4CPULayout",
    "MarlinQQQLayout",
    "MarlinQQQTensor",
    "to_marlinqqq_quantized_intx",
]
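
A practical consequence of the module rename shown above (a hedged editorial sketch, not part of the diff): code that imported directly from the old `marlin_qqq_layout` module path must switch to the new `marlin_qqq_tensor` path, while the package-level re-exports continue to work:

```python
# Old module path (removed by the rename in this commit):
#   from torchao.dtypes.uintx.marlin_qqq_layout import MarlinQQQLayout
# New module path:
from torchao.dtypes.uintx.marlin_qqq_tensor import (
    MarlinQQQLayout,
    MarlinQQQTensor,
    to_marlinqqq_quantized_intx,
)

# Package-level imports are unaffected by the rename:
from torchao.dtypes.uintx import MarlinQQQLayout, MarlinQQQTensor
```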
62 changes: 62 additions & 0 deletions torchao/dtypes/uintx/{marlin_qqq_layout.py → marlin_qqq_tensor.py}
@@ -1,5 +1,7 @@
import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import (
@@ -8,18 +10,75 @@

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    get_tensor_impl_constructor,
    register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
    _aqt_is_int8_reduced_range,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    choose_qparams_and_quantize_affine_qqq,
    dequantize_affine_qqq,
)

logger = logging.getLogger(__name__)

aten = torch.ops.aten


class MarlinQQQTensor(AffineQuantizedTensor):
    """MarlinQQQ quantized tensor subclass, which inherits from the AffineQuantizedTensor class.

    To see what happens during choose_qparams_and_quantize_affine_qqq, quantization, and dequantization
    for Marlin QQQ quantization, please check out
    https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py
    and look at the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq.
    """

    def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        if output_dtype is None:
            output_dtype = self.dtype

        int_data, s_group, s_channel = self.tensor_impl.get_plain()
        nbits = int(math.log2(self.quant_max - self.quant_min + 1))
        group_size = max(self.block_size)
        return dequantize_affine_qqq(
            int_data, s_group, s_channel, nbits, group_size, output_dtype
        )

    @classmethod
    def from_hp_to_intx(
        cls,
        input_float: torch.Tensor,
        block_size: Tuple[int, ...],
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
        _layout: Optional[Layout] = None,
    ):
        original_shape = input_float.shape
        input_float = _layout.pre_process(input_float)
        nbits = int(math.log2(quant_max - quant_min + 1))
        group_size = max(block_size)
        data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
            input_float, nbits, group_size
        )
        data = _layout.post_process(data)
        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
        tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout)
        return cls(
            tensor_impl,
            block_size,
            original_shape,
            quant_min,
            quant_max,
            zero_point_domain,
            dtype=input_float.dtype,
        )


@dataclass(frozen=True)
class MarlinQQQLayout(Layout):
    pass
@@ -279,3 +338,6 @@ def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias):
    if bias is not None:
        out += bias.to(out.dtype)
    return out


to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx
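
For context, a hedged usage sketch of the constructor alias defined on the line above; the keyword names follow the `from_hp_to_intx` signature shown in this diff, while the tensor shape, dtype, group size, and CUDA device are illustrative assumptions (Marlin QQQ targets int4 weights, hence `quant_min=-8` and `quant_max=7`):

```python
# Hedged sketch (not from the repository): quantize a linear weight on the
# Marlin QQQ path. Keyword names mirror MarlinQQQTensor.from_hp_to_intx above;
# the shape, dtype, group size, and device are illustrative assumptions.
import torch

from torchao.dtypes import MarlinQQQLayout, to_marlinqqq_quantized_intx

weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

qweight = to_marlinqqq_quantized_intx(
    weight,
    block_size=(1, 128),  # group_size = max(block_size) = 128
    quant_min=-8,  # int4 range, so nbits = log2(7 - (-8) + 1) = 4
    quant_max=7,
    _layout=MarlinQQQLayout(),
)

# Round-trip through the dequantize() override defined above for a sanity check.
weight_dq = qweight.dequantize()
print((weight - weight_dq).abs().mean())
```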
