From eab345c2268a7506355d506ebfc27b5d28e5e7d0 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 20 Dec 2024 14:37:53 -0800
Subject: [PATCH] Add gemlite kernel option to autoquant (#1449)

* Add gemlite kernel option to autoquant

Summary:
att

Test Plan:
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-gemlite-int4 --output_json_path result.json

Reviewers:

Subscribers:

Tasks:

Tags:

* updates
---
 torchao/_models/llama/generate.py      | 27 +++++++++++---
 torchao/dtypes/uintx/gemlite_layout.py | 15 ++++----
 torchao/quantization/__init__.py       |  2 ++
 torchao/quantization/autoquant.py      | 49 ++++++++++++++++++++++++++
 torchao/quantization/quant_api.py      |  9 +++--
 5 files changed, 85 insertions(+), 17 deletions(-)

diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 231133c2c..6e2e4f713 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -343,7 +343,7 @@ def ffn_or_attn_only(mod, fqn):
             from torchao.prototype.spinquant import apply_spinquant
 
             apply_spinquant(model)
-        if "gemlite" in quantization:
+        if quantization.startswith("gemlite"):
             import os, pwd
             import gemlite
             from gemlite.core import GemLiteLinearTriton, set_autotune
@@ -677,21 +677,38 @@ def ffn_or_attn_only(mod, fqn):
                 qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-fp" == quantization:
+        elif "autoquant-fp" == quantization:
             model = autoquant(
                 model,
                 manual=True,
                 qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-sparse" == quantization:
+        elif "autoquant-sparse" == quantization:
             model = autoquant(
                 model,
                 manual=True,
                 qtensor_class_list = torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-all" == quantization:
+        elif "autoquant-gemlite-int4" == quantization:
+            import os, pwd
+            from gemlite.core import GemLiteLinearTriton
+            GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            model = autoquant(
+                model,
+                manual=True,
+                qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
+                example_input=inputs,
+            )
+        elif "autoquant-all" == quantization:
+            try:
+                import os, pwd
+                from gemlite.core import GemLiteLinearTriton
+                GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            except:
+                pass
+
             model = autoquant(
                 model,
                 manual=True,
@@ -986,7 +1003,7 @@ def callback(x):
         type=str,
         help=(
             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
-            + "autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
             + "embed-int8wo, marlin_qqq, gemlite-<packing_bitwidth>-<nbits>-<groupsize>"
         ),
     )
diff --git a/torchao/dtypes/uintx/gemlite_layout.py b/torchao/dtypes/uintx/gemlite_layout.py
index 1adee2a55..9233ce6ee 100644
--- a/torchao/dtypes/uintx/gemlite_layout.py
+++ b/torchao/dtypes/uintx/gemlite_layout.py
@@ -46,15 +46,14 @@ def get_gemlite_quant_kwargs(bit_width, group_size):
     return kwargs
 
 
-def apply_gemlite_quant(
+def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
-    packing_bitwidth=8,
+    packing_bitwidth=32,
     contiguous=None,
     use_hqq=True,
 ):
-    from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_intx
     from torchao.dtypes.uintx.gemlite_layout import GemlitePackedLayout
 
     assert bit_width in [
@@ -86,17 +85,15 @@ def apply_gemlite_quant(
         )
         return weight
 
-    quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
-
-    layout = GemlitePackedLayout(
+    aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
+    aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
         packing_bitwidth=packing_bitwidth,
         contiguous=contiguous,
     )
-    return to_affine_quantized_intx(
-        weight, **quant_kwargs, _layout=layout, use_hqq=use_hqq
-    )
+    aqt_kwargs["use_hqq"] = use_hqq
+    return aqt_kwargs
 
 
 @dataclass(frozen=True)
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index a202dfd04..d0d29cf4b 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -15,6 +15,7 @@
     DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
     DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
+    GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
@@ -94,6 +95,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "GEMLITE_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 2a0c80c86..b13f1d16a 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -48,6 +48,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "GEMLITE_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
@@ -700,6 +701,43 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
     layout: Layout = MarlinSparseLayout()
 
 
+class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    group_size: int = 32
+
+    @classmethod
+    def from_float(cls, weight):
+        from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs
+
+        bit_width = 4
+        packing_bitwidth = 32
+        contiguous = None
+        use_hqq = True
+        aqt_kwargs = get_gemlite_aqt_kwargs(
+            weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+        )
+        return super(
+            AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls
+        ).from_hp_to_intx(weight, **aqt_kwargs)
+
+
+class AQGemliteInt4G64WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 64
+
+
+class AQGemliteInt4G128WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 128
+
+
+class AQGemliteInt4G256WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 256
+
+
 class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -977,6 +1015,12 @@ def get_weight_block_size(x):
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
+GEMLITE_INT4_AUTOQUANT_CLASS_LIST = [
+    AQDefaultLinearWeight,
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+]
+
 DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
     AQFloat32LinearWeight,
     AQBFloat16LinearWeight,
@@ -1002,6 +1046,11 @@ def get_weight_block_size(x):
     + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
 )
 
+# add gemlite options
+ALL_AUTOQUANT_CLASS_LIST += [
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+]
+
 if is_sm_at_least_89():
     ALL_AUTOQUANT_CLASS_LIST += [
         AQFloat8WeightOnlyQuantizedLinearWeight,
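
Note on the autoquant.py hunk above: the G64/G128/G256 variants only override `group_size`, and `from_float` reads `cls.group_size`, so a custom search space can be composed from any of these candidates. A minimal sketch, assuming this patch is applied; the particular selection below is illustrative rather than a torchao default:

    from torchao.quantization.autoquant import (
        AQDefaultLinearWeight,
        AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
        AQGemliteInt4G128WeightOnlyQuantizedLinearWeight,
    )

    # Candidate list to pass as `qtensor_class_list=` to autoquant(), analogous to
    # GEMLITE_INT4_AUTOQUANT_CLASS_LIST but trying two gemlite group sizes.
    my_gemlite_classes = [
        AQDefaultLinearWeight,
        AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
        AQGemliteInt4G128WeightOnlyQuantizedLinearWeight,
    ]
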
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 5161f1657..af950cb79 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -666,11 +666,14 @@ def gemlite_uintx_weight_only(
         `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """
 
-    from torchao.dtypes.uintx.gemlite_layout import apply_gemlite_quant
+    from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs
 
     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: apply_gemlite_quant(
-        weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+    apply_fn = lambda weight: to_affine_quantized_intx(
+        weight,
+        **get_gemlite_aqt_kwargs(
+            weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+        ),
     )
     return _get_linear_subclass_inserter(apply_fn)
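
Usage sketch for the new `autoquant-gemlite-int4` path outside of generate.py, assuming this patch plus the gemlite package are installed and a CUDA device is available; the toy model and shapes are placeholders:

    import torch
    import torchao
    from torchao.quantization import GEMLITE_INT4_AUTOQUANT_CLASS_LIST

    # Toy stand-in for a real model; the gemlite int4 kernels target fp16 CUDA linears.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().half()

    # Restrict autoquant's search space to the gemlite int4 candidates added here,
    # mirroring the `--quantization autoquant-gemlite-int4` branch in generate.py.
    model = torchao.autoquant(
        torch.compile(model, mode="max-autotune"),
        manual=True,
        qtensor_class_list=GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
    )

    example_input = torch.randn(2, 1024, dtype=torch.float16, device="cuda")
    model(example_input)        # record shapes / benchmark the candidate kernels
    model.finalize_autoquant()  # pick the best candidate per linear layer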