From 603d908dd5aae06f2a433ed9897944b4795e4d79 Mon Sep 17 00:00:00 2001 From: HDCharles <39544797+HDCharles@users.noreply.github.com> Date: Tue, 17 Dec 2024 02:39:45 +0800 Subject: [PATCH] gemlite integration in torchao (#1034) * gemlite integration in torchao Summary: This PR adds support for gemlite kernels in torchao using a subclass integration with the gemlite_uintx_weight_only constructor. This works for int4 grouped and ungrouped assymmetric oeight only quantization and int8 symmetric ungrouped quantization for fp16 models. TP support through DTensor is included in thsi PR in the process of integrating gemlite into AQT i also made some fixes to a few quant primitives that are being used which previously were not. Test Plan: test_integration.py -k "test_gemlite_layout" test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite" see benchmarks.sh for gemlite benchmarks as well. Reviewers: Subscribers: Tasks: Tags: new gemlite integration using pip install Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: tests ran Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: fixing gemlite to do int4 matmul instead of fp16 fp16 Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: running tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: more testing Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: AQT integration wip Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Wip Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: testing on gemlite a100_int8_tuning branch Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: gemlite subclass testing bitpacking 8 bits Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: bug fixing stuff Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: hicham fixes Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: new benchmarks Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: testing gemlite 8 bit Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: WIP Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: tp support Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: wip Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: final Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: * fixing regressions Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- .../test_affine_quantized_tensor_parallel.py | 28 ++ test/integration/test_integration.py | 34 ++ torchao/_models/llama/benchmarks.sh | 15 + torchao/_models/llama/generate.py | 42 +- torchao/_models/llama/model.py | 5 +- torchao/dtypes/affine_quantized_tensor.py | 15 +- torchao/dtypes/affine_quantized_tensor_ops.py | 8 + torchao/dtypes/uintx/gemlite_layout.py | 372 ++++++++++++++++++ torchao/quantization/README.md | 5 +- torchao/quantization/__init__.py | 4 +- torchao/quantization/quant_api.py | 28 ++ torchao/quantization/quant_primitives.py | 34 +- 12 files changed, 577 insertions(+), 13 deletions(-) create mode 100644 torchao/dtypes/uintx/gemlite_layout.py diff --git a/test/dtypes/test_affine_quantized_tensor_parallel.py b/test/dtypes/test_affine_quantized_tensor_parallel.py index da20b930d3..3abb736f92 100644 --- a/test/dtypes/test_affine_quantized_tensor_parallel.py +++ b/test/dtypes/test_affine_quantized_tensor_parallel.py @@ -19,6 +19,13 @@ from torchao.quantization.quant_api import quantize_ from torchao.utils import TORCH_VERSION_AT_LEAST_2_6 +try: + import gemlite # noqa: F401 + + has_gemlite = True +except ModuleNotFoundError: + has_gemlite = False + class TestAffineQuantizedTensorParallel(DTensorTestBase): """Basic test case for tensor subclasses""" @@ -139,8 +146,29 @@ def test_tp(self, dtype): return self._test_tp(dtype) +class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel): + COMMON_DTYPES = [torch.float16] + + @common_utils.parametrize("dtype", COMMON_DTYPES) + @with_comms + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + @unittest.skipIf(not has_gemlite, "gemlite not available") + def test_tp_gemlite(self, dtype): + from torchao.quantization import gemlite_uintx_weight_only + + for packing_bitwidth in [32, 8]: + for bit_width in [4, 8]: + for group_size in [64, 32, None] if bit_width == 4 else [None]: + api = lambda: gemlite_uintx_weight_only( + group_size, bit_width, packing_bitwidth + ) + self.QUANT_METHOD_FN = staticmethod(api) + return self._test_tp(dtype) + + common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel) common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel) +common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel) # Run only on H100 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0): diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 6aae8b2e31..faabf48ab5 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -96,6 +96,12 @@ ) from torchao.dtypes.utils import is_device +try: + import gemlite + has_gemlite = True +except ModuleNotFoundError: + has_gemlite = False + logger = logging.getLogger("INFO") torch.manual_seed(0) @@ -870,6 +876,10 @@ def _test_lin_weight_subclass_api_impl( ref_f = mod(x) api(mod) + # test get_plain() + if hasattr(mod[0].weight, "tensor_impl"): + mod[0].weight.tensor_impl.get_plain() + test = mod(x) self.assertGreater( SQNR(ref_f, test), @@ -930,6 +940,30 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype): test_dtype=dtype ) + @parameterized.expand(COMMON_DEVICE_DTYPE) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater") + @unittest.skipIf(not has_gemlite, "gemlite not available") + def test_gemlite_layout(self, device, dtype): + if dtype!= torch.float16: + self.skipTest(f"gemlite only works for fp16 dtype") + from torchao.quantization import gemlite_uintx_weight_only + if device == "cpu": + self.skipTest(f"gemlite is for cuda, not {device}") + for packing_bitwidth in [32, 8]: + for bit_width in [4,8]: + for group_size in [64, 32, None] if bit_width ==4 else [None]: + api = lambda mod: quantize_(mod, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)) + for test_shape in [[1, 1024, 512],[16, 256, 1024], [128, 256, 1024]]: + print(packing_bitwidth, bit_width, group_size, test_shape, dtype) + self._test_lin_weight_subclass_api_impl( + api, + device, + 15, + test_shape=test_shape, + test_dtype=dtype, + ) + + @parameterized.expand(COMMON_DEVICE_DTYPE) @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.") # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now") diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index c8cd4bf39c..2f50e47dcd 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -91,6 +91,21 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured +# gemlite benchmarks +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt + +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 32 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32 + # 2:4 sparse model export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 8ec6acccc9..9b0208375e 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -171,7 +171,8 @@ def decode_n_tokens( ) next_token, next_prob = next_token.clone(), next_prob.clone() input_pos += 1 - new_tokens.append(next_token) + # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step + new_tokens.append(next_token.clone()) callback(new_tokens[-1]) new_probs.append(next_prob) cur_token = next_token @@ -368,6 +369,7 @@ def ffn_or_attn_only(mod, fqn): int8_weight_only, quantize_, uintx_weight_only, + gemlite_uintx_weight_only, ) from torchao.quantization.granularity import PerRow, PerTensor @@ -377,6 +379,39 @@ def ffn_or_attn_only(mod, fqn): from torchao.prototype.spinquant import apply_spinquant apply_spinquant(model) + if "gemlite" in quantization: + import os, pwd + import gemlite + from gemlite.core import GemLiteLinearTriton, set_autotune + _quant_args = quantization.split("-") + bit_width = int(_quant_args[-2]) + group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1]) + try: + packing_bitwidth = int(_quant_args[-3]) + except: + # if only 2 inputs found, use default value + packing_bitwidth = 32 + + quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)) + + # try to load gemlite kernel config + try: + GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + except: + print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") + + print("running gemlite warmup") + generate( + model, + encode_tokens(tokenizer, prompt, bos=True, device=device), + max_new_tokens, + batch_size, + interactive=False, + temperature=temperature, + top_k=top_k, + ) + GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json") if "int8wo" in quantization: quantize_(model, int8_weight_only()) if "int8dq" in quantization: @@ -959,7 +994,7 @@ def callback(x): parser = argparse.ArgumentParser(description="Your CLI description.") parser.add_argument( - "--prefill_size", type=int, default=0, help="Whether to run in ttft mode" + "--prefill_size", type=int, default=None, help="Whether to run in ttft mode" ) parser.add_argument( "--prompt", type=str, default="Hello, my name is", help="Input prompt." @@ -993,7 +1028,7 @@ def callback(x): help=( "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, " + "autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, " - + "embed-int8wo, marlin_qqq" + + "embed-int8wo, marlin_qqq, gemlite---" ), ) parser.add_argument( @@ -1053,6 +1088,7 @@ def callback(x): ) args = parser.parse_args() + print(args) main( args.prefill_size, args.prompt, diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index 74cad30cbd..87993f6867 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -170,7 +170,10 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_ max_seq_length = find_multiple(max_seq_length, 8) self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size - dtype = self.output.weight.dtype + dtype = None + # module swaps can cause issues without this + if hasattr(self.output, "weight"): + dtype = self.output.weight.dtype # For quantized layers, dtype is encoded in scales if hasattr(self.output, "scales"): dtype = self.output.scales.dtype diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 7aca25ecc5..ba06d877f3 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -225,6 +225,8 @@ def from_hp_to_intx( else input_float.dtype ) device = input_float.device + from torchao.dtypes.uintx import TensorCoreTiledLayout + data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( input_float, nbits=nbits, @@ -233,7 +235,15 @@ def from_hp_to_intx( compute_dtype=compute_dtype, device=device, verbose=False, - raw_output=False, + raw_output=not isinstance( + _layout, (TensorCoreTiledLayout, PlainLayout) + ), + # raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint) + # note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether + # zero is preserved. + # TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version + # TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain + # TODO change PlainLayout to use raw_output. ) data = data.to(target_dtype) else: @@ -251,7 +261,8 @@ def from_hp_to_intx( zero_point_domain, ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None - if zero_point_domain is None: + # TODO should probably consolidate ZeroPointDomain.NONE and None + if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE: zero_point = None data = quantize_affine( input_float, diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8938e7472c..a1667e8fbb 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -20,6 +20,10 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) +from torchao.dtypes.uintx.gemlite_layout import ( + _linear_fp_act_int4_weight_gemlite_check, + _linear_fp_act_int4_weight_gemlite_impl, +) from torchao.dtypes.uintx.marlin_qqq_tensor import ( _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, @@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_marlin_qqq_check, _linear_int8_act_int4_weight_marlin_qqq_impl, ), + ( + _linear_fp_act_int4_weight_gemlite_check, + _linear_fp_act_int4_weight_gemlite_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py new file mode 100644 index 0000000000..969816727e --- /dev/null +++ b/torchao/dtypes/uintx/gemlite_layout.py @@ -0,0 +1,372 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.tensor_core_tiled_layout import TensorCoreTiledAQTTensorImpl +from torchao.dtypes.utils import Layout, is_device +from torchao.quantization.quant_primitives import quantize_affine +from torchao.utils import fill_defaults + +aten = torch.ops.aten + + +def get_gemlite_quant_kwargs(bit_width, group_size): + from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain + + kwargs = {} + if bit_width != 8: + kwargs["mapping_type"] = MappingType.ASYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.uint8 + kwargs["eps"] = 1e-6 + kwargs["quant_min"] = 0 + kwargs["quant_max"] = (2**bit_width) - 1 + kwargs["eps"] = 1e-6 + kwargs["zero_point_dtype"] = torch.float16 + kwargs["zero_point_domain"] = ZeroPointDomain.FLOAT + elif bit_width == 8: + kwargs["mapping_type"] = MappingType.SYMMETRIC + kwargs["block_size"] = (1, group_size) + kwargs["target_dtype"] = torch.int8 + kwargs["quant_min"] = -128 + kwargs["quant_max"] = 127 + kwargs["eps"] = 1e-5 + kwargs["zero_point_dtype"] = None + kwargs["zero_point_domain"] = ZeroPointDomain.NONE + return kwargs + + +def apply_gemlite_quant( + weight, + group_size=64, + bit_width=4, + packing_bitwidth=8, + contiguous=None, + use_hqq=True, +): + from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_intx + from torchao.dtypes.uintx.gemlite_layout import GemlitePackedLayout + + assert bit_width in [ + 4, + 8, + ], f"gemlite only works with bit_width 4,8 but got {bit_width}" + assert packing_bitwidth in [ + 8, + 16, + 32, + ], f"gemlite needs packing_bitwidth in [8, 16, 32] but got {packing_bitwidth}" + assert ( + weight.dtype == torch.float16 + ), f"gemlite only works with dtype torch.float16 but got {weight.dtype}" + assert group_size in [32, 64, 128, 256, 512, 1024, None] + assert ( + group_size is None or bit_width != 8 + ), "gemlite only works with group_size=None for bit_width=8" + + out_features, in_features = weight.shape + group_size = in_features if group_size is None else group_size + + quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size) + + layout = GemlitePackedLayout( + group_size=group_size, + bit_width=bit_width, + packing_bitwidth=packing_bitwidth, + contiguous=contiguous, + ) + return to_affine_quantized_intx( + weight, **quant_kwargs, _layout=layout, use_hqq=use_hqq + ) + + +@dataclass(frozen=True) +class GemlitePackedLayout(Layout): + group_size: Optional[int] = 64 + bit_width: int = 4 + packing_bitwidth: int = 8 + contiguous: bool = None + + +@register_layout(GemlitePackedLayout) +class GemliteAQTTensorImpl(TensorCoreTiledAQTTensorImpl): + def __new__( + cls, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + gemlite_kwargs: Dict, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scale = scale + self.zero_point = zero_point + self.gemlite_kwargs = gemlite_kwargs + self._layout = _layout + torch._dynamo.config.inline_inbuilt_nn_modules = False + + def __tensor_flatten__(self): + return ["packed_weight", "scale", "zero_point"], [ + self._layout, + self.gemlite_kwargs, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scale, zero_point = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + _layout, gemlite_kwargs = tensor_attributes + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + from gemlite.core import DType, GemLiteLinearTriton, set_autotune + + assert isinstance( + _layout, GemlitePackedLayout + ), f"GemliteAQTTensorImpl only works with GemliteLinearTriton but got {_layout}" + group_size, bit_width = _layout.group_size, _layout.bit_width + + torch._dynamo.config.inline_inbuilt_nn_modules = False + set_autotune( + {"GEMV_REVSPLITK": True, "GEMV": True, "GEMM_SPLITK": True, "GEMM": True}, + exhaustive=False, + use_cuda_graph=False, + ) + + out_features, in_features = int_data.shape + input_dtype, output_dtype = DType.FP16, DType.FP16 + gemlite_linear = GemLiteLinearTriton( + bit_width, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=input_dtype, + output_dtype=output_dtype, + ) + gemlite_linear.pack( + int_data, + scale, + zero_point, + bias=None, + fma_mode=False, + packing_bitwidth=_layout.packing_bitwidth, + contiguous=_layout.contiguous, + ) + + gemlite_kwargs = { + "out_features": out_features, + "scale_activations": gemlite_linear.scale_activations, + "meta_args": gemlite_linear.get_meta_args(), + } + + packed_weight, scale, zero_point = gemlite_linear.get_tensor_args() + + return cls(packed_weight, scale, zero_point, gemlite_kwargs, _layout) + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + device = kwargs["device"] + if not is_device("cuda", device): + raise ValueError( + f"GemliteAQTTensorImpl is only available for cuda device, can't convert to {device}" + ) + return self.__class__( + self.packed_weight.to(device), + self.scale.to(device), + self.zero_point.to(device), + self.gemlite_kwargs, + self._layout, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scale), + fn(self.zero_point), + self.gemlite_kwargs, + self._layout, + ) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + dq = ( + _linear_fp_act_int4_weight_gemlite_impl( + torch.eye( + self.scale.shape[0] * self._layout.group_size, + device=self.device, + dtype=self.scale.dtype, + ), + self, + ) + .t() + .contiguous() + ) + + quant_kwargs = get_gemlite_quant_kwargs( + self._layout.bit_width, self._layout.group_size + ) + quant_kwargs["output_dtype"] = quant_kwargs.pop("target_dtype") + for key in ["mapping_type", "eps", "zero_point_dtype"]: + del quant_kwargs[key] + + int_data = quantize_affine( + dq, + scale=self.scale, + zero_point=self.zero_point, + **quant_kwargs, + ) + + return int_data, self.scale, self.zero_point + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + # we don't handle transpose operations and just ignore them. In practice the only + # reason a transpsoe should occur is because the functional linear + # op can decompose into e.g. transpose + addmm so since we want + # to use the gemlite matmul kernel, which expects teh weight to be passed in as is, + # we ignore the transpose + if func is aten.detach.default or func is aten.t.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim == 0: + int_data, scale, zero_point = self.get_plain() + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + int_data = self._layout.post_process(int_data) + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + elif dim == 1: + int_data, scale, zero_point = self.get_plain() + assert step == 1, "Only step == 1 is supported in slicing right now" + data_len = int_data.shape[dim] + # scale and zero_point are transposed compared to int_data + param_dim = 1 - dim + scale_len = scale.shape[param_dim] + ratio = data_len / scale_len + start_scale = int(start / ratio) + end_scale = int(end / ratio) + + int_data = aten.slice.Tensor(int_data, dim, start, end, step) + # this is to handle padding + scale = aten.slice.Tensor( + scale, param_dim, start_scale, end_scale, step + ) + if zero_point is not None and zero_point.numel() > 0: + zero_point = aten.slice.Tensor( + zero_point, param_dim, start_scale, end_scale, step + ) + else: + zero_point = None + # import fbvscode; fbvscode.set_trace() + sliced = self.from_plain(int_data, scale, zero_point, self._layout) + return sliced + else: + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"GemliteAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + def get_layout(self) -> Layout: + return self._layout + + +# logic taken from gemlite's core.py +def _matmul_type_fn(batch_size: int, bit_width: int) -> str: + if batch_size > 64: + return "GEMM" + elif batch_size > 1: + return "GEMM_SPLITK" + elif bit_width < 8: + return "GEMV_REVSPLITK" + else: + return "GEMV_SPLITK" + + +def _linear_fp_act_int4_weight_gemlite_impl(input_tensor, weight_tensor, bias=None): + if hasattr(weight_tensor, "tensor_impl"): + weight_impl = weight_tensor.tensor_impl + else: + weight_impl = weight_tensor + + from gemlite.core import GemLiteLinearTriton + + batch_size = input_tensor.view(-1, input_tensor.shape[-1]).shape[0] + matmul_type = _matmul_type_fn(batch_size, weight_impl._layout.bit_width) + + return GemLiteLinearTriton.forward_functional( + x=input_tensor, + bias=bias, + matmul_type=matmul_type, + **weight_impl.gemlite_kwargs, + tensor_args=( + weight_impl.packed_weight, + weight_impl.scale, + weight_impl.zero_point, + ), + ) + + +def _linear_fp_act_int4_weight_gemlite_check(input_tensor, weight_tensor, bias): + return ( + # input is native fp16 tensor + not is_traceable_wrapper_subclass(input_tensor) + # and input_tensor.dtype == torch.float16 + # weight is gemlite layout + and isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, GemlitePackedLayout) + ) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 3fc2cb5ef0..80f4dd689c 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -335,6 +335,9 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F | | w4a8 | 197.45 | 653.50 | 4.79 | 3.31 | | | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | +### Gemlite Triton +Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. + ### UINTx Quantization We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. @@ -358,7 +361,7 @@ We have kernels that do 8-bit dynamic quantization of activations and uintx grou | | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 | | | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 | -You try can out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. +You can try out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. ### Automatic Inductor Configuration The `quantize_` and `autoquant` apis now automatically use our recommended inductor configuration setings. You can mimic the same configuration settings for your own experiments by using the `torchao.quantization.utils.recommended_inductor_config_setter` to replicate our recommended configuration settings. Alternatively if you wish to disable these recommended settings, you can use the key word argument `set_inductor_config` and set it to false in the `quantize_` or `autoquant` apis to prevent assignment of those configuration settings. You can also overwrite these configuration settings after they are assigned if you so desire, as long as they are overwritten before passing any inputs to the torch.compiled model. This means that previous flows which referenced a variety of inductor configurations that needed to be set are now outdated, though continuing to manually set those same inductor configurations is unlikely to cause any issues. diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 90a2fb7207..a202dfd040 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -48,6 +48,7 @@ float8_static_activation_float8_weight, float8_weight_only, fpx_weight_only, + gemlite_uintx_weight_only, int4_weight_only, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_semi_sparse_weight, @@ -110,8 +111,9 @@ "float8_static_activation_float8_weight", "uintx_weight_only", "fpx_weight_only", - # smooth quant - subject to change + "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", + # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", "SmoothFakeDynamicallyQuantizedLinear", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 6930a2e15f..03fb8812b1 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -108,6 +108,7 @@ "float8_weight_only", "uintx_weight_only", "fpx_weight_only", + "gemlite_uintx_weight_only", "float8_dynamic_activation_float8_weight", "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", @@ -633,6 +634,33 @@ def int8_dynamic_activation_int4_weight( ) +def gemlite_uintx_weight_only( + group_size: Optional[int] = 64, + bit_width: int = 4, + packing_bitwidth: int = 32, + contiguous: Optional[bool] = None, +): + """ + applies weight only 4 or 8 bit integer quantization and utilizes the gemlite triton kernel and its associated weight packing format. + This only works for fp16 models. 8 bit quantization is symmetric, 4 bit quantization is asymmetric. + + Args: + `group_size`: parameter for quantization, controls the granularity of quantization, smaller + size is more fine grained + `bit_width`: bit width of the quantized weight. + `packing_bitwidth`: bit width of the packed weight, should be 8 or 32. Can have performance impacts depending on hardware. + `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice. + """ + + from torchao.dtypes.uintx.gemlite_layout import apply_gemlite_quant + + use_hqq = True if bit_width == 4 else False + apply_fn = lambda weight: apply_gemlite_quant( + weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq + ) + return _get_linear_subclass_inserter(apply_fn) + + def int4_weight_only( group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False ): diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 37aa609b9b..fddd21c43e 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -373,8 +373,13 @@ def _quantize_affine_no_dtype_cast( for i in reduction_dims: shape_after_reduction[i] = 1 scale = scale.view(shape_after_reduction) - if zero_point is not None: + + if zero_point is not None and zero_point.numel() > 0: zero_point = zero_point.view(shape_after_reduction) + else: + # in some cases zero_point being a non-value shows as a tensor + # with numel=0 which we handle by unifying the two + zero_point = None if zero_point_domain == ZeroPointDomain.INT.name: quant = torch.clamp( @@ -889,10 +894,12 @@ def _choose_qparams_affine( ) if ( zero_point_domain is not None - and zero_point_domain != ZeroPointDomain.INT.name + and zero_point_domain == ZeroPointDomain.FLOAT.name ): + # TODO INT should not be a valid ZeroPointDomain for symmetric quantization since + # symmetric quant doesn't have a zero_point raise ValueError( - "zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" + "zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization" ) scale = torch.clamp(scale, min=eps) zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) @@ -911,6 +918,9 @@ def _choose_qparams_affine( zero_point_domain == ZeroPointDomain.FLOAT.name ), "if not preserve_zero, zero_point must be in FLOAT domain" mid_point = (quant_max + quant_min + 1) / 2 + # this is not preserving zero_point, this is converting to TensorCoreTiledFormat + # TODO move the conversion of zero_point out of quant_primitives + # and into TensorCoreTiledLayout.from_plain zero_point = min_val_neg + scale * mid_point if zero_point is not None: @@ -1185,17 +1195,31 @@ def choose_qparams_and_quantize_affine_hqq( verbose=verbose, ) else: + zero = zero.to(compute_dtype) + scale = scale.to(compute_dtype) W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1]) # Store meta-data (we invert the scale for dequantization) scale = 1.0 / scale - # Convert to affienquantized format + # Convert to TensorCoreTiled format + # TODO move the conversion of zero_point out of quant_primitives + # and into TensorCoreTiledLayout.from_plain and rename this + # helper function correctly. if raw_output is False: W_q, scale, zero = _convert_to_affinequantized_format( W_q, scale, zero, nbits, shape ) - + else: + # this path was not used before, the way hqq sets up scale/zero is transposed + # compared to the rest of our utils so we need to reshape them acccordingly. + W_q = W_q.reshape(shape) + if axis == 1: + scale = scale.reshape(shape[0], -1) + zero = zero.reshape(shape[0], -1) + else: + scale = scale.reshape(-1, shape[-1]) + zero = zero.reshape(-1, shape[-1]) # Make sure all the weights are in the right compute_dtype/device W_q = W_q.to(dtype=torch.uint8, device=device) scale = scale.to(dtype=compute_dtype, device=device)