Add gemlite kernel option to autoquant (#1449)
* Add gemlite kernel option to autoquant

Summary:
att

Test Plan:
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant-gemlite-int4 --output_json_path result.json

Reviewers:

Subscribers:

Tasks:

Tags:

* updates
jerryzh168 authored Dec 20, 2024
1 parent 3bac905 commit eab345c
Showing 5 changed files with 85 additions and 17 deletions.
27 changes: 22 additions & 5 deletions torchao/_models/llama/generate.py
@@ -343,7 +343,7 @@ def ffn_or_attn_only(mod, fqn):
             from torchao.prototype.spinquant import apply_spinquant
 
             apply_spinquant(model)
-        if "gemlite" in quantization:
+        if quantization.startswith("gemlite"):
             import os, pwd
             import gemlite
             from gemlite.core import GemLiteLinearTriton, set_autotune
@@ -677,21 +677,38 @@ def ffn_or_attn_only(mod, fqn):
                 qtensor_class_list=torchao.quantization.OTHER_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-fp" == quantization:
+        elif "autoquant-fp" == quantization:
             model = autoquant(
                 model,
                 manual=True,
                 qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-sparse" == quantization:
+        elif "autoquant-sparse" == quantization:
             model = autoquant(
                 model,
                 manual=True,
                 qtensor_class_list = torchao.quantization.DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
                 example_input=inputs,
             )
-        if "autoquant-all" == quantization:
+        elif "autoquant-gemlite-int4" == quantization:
+            import os, pwd
+            from gemlite.core import GemLiteLinearTriton
+            GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            model = autoquant(
+                model,
+                manual=True,
+                qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
+                example_input=inputs,
+            )
+        elif "autoquant-all" == quantization:
+            try:
+                import os, pwd
+                from gemlite.core import GemLiteLinearTriton
+                GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
+            except:
+                pass
+
             model = autoquant(
                 model,
                 manual=True,
@@ -986,7 +1003,7 @@ def callback(x):
         type=str,
         help=(
             "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
-            + "autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+            + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
             + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
         ),
    )
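The new branch above restricts autoquant's search space to the gemlite int4 classes and (like the autoquant-all branch) tries to reuse a previously saved gemlite autotune config. A minimal sketch of the same flow driven directly from Python, with a placeholder model and input shapes chosen only for illustration:

    import torch
    import torchao
    from torchao.quantization import autoquant

    # Placeholder model/input; any CUDA model with nn.Linear layers works.
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
    example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)

    # Limit the autoquant search space to the gemlite int4 options added in this PR.
    model = autoquant(
        model,
        manual=True,
        qtensor_class_list=torchao.quantization.GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
        example_input=example_input,
    )
    model(example_input)        # record shapes and benchmark the candidate classes
    model.finalize_autoquant()  # manual=True: pick the fastest option per layer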
15 changes: 6 additions & 9 deletions torchao/dtypes/uintx/gemlite_layout.py
@@ -46,15 +46,14 @@ def get_gemlite_quant_kwargs(bit_width, group_size):
     return kwargs
 
 
-def apply_gemlite_quant(
+def get_gemlite_aqt_kwargs(
     weight,
     group_size=64,
     bit_width=4,
-    packing_bitwidth=8,
+    packing_bitwidth=32,
     contiguous=None,
     use_hqq=True,
 ):
-    from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_intx
     from torchao.dtypes.uintx.gemlite_layout import GemlitePackedLayout
 
     assert bit_width in [
@@ -86,17 +85,15 @@ def apply_gemlite_quant(
         )
         return weight
 
-    quant_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
-
-    layout = GemlitePackedLayout(
+    aqt_kwargs = get_gemlite_quant_kwargs(bit_width, group_size)
+    aqt_kwargs["_layout"] = GemlitePackedLayout(
         group_size=group_size,
         bit_width=bit_width,
         packing_bitwidth=packing_bitwidth,
         contiguous=contiguous,
     )
-    return to_affine_quantized_intx(
-        weight, **quant_kwargs, _layout=layout, use_hqq=use_hqq
-    )
+    aqt_kwargs["use_hqq"] = use_hqq
+    return aqt_kwargs
 
 
 @dataclass(frozen=True)
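After this refactor, get_gemlite_aqt_kwargs only builds the kwargs (quantization parameters, _layout, use_hqq) and leaves construction of the quantized tensor to the caller. A minimal sketch of how a caller composes the two pieces, using a placeholder weight tensor; this mirrors the updated call site in quant_api.py further down:

    import torch
    from torchao.dtypes.affine_quantized_tensor import to_affine_quantized_intx
    from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs

    # Placeholder weight; in practice this is an nn.Linear weight on a CUDA device.
    weight = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

    # The helper returns plain kwargs (including the GemlitePackedLayout and use_hqq),
    # so the caller decides how and where to build the AffineQuantizedTensor.
    qweight = to_affine_quantized_intx(
        weight,
        **get_gemlite_aqt_kwargs(weight, group_size=64, bit_width=4),
    )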
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -15,6 +15,7 @@
     DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
     DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
     DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST,
+    GEMLITE_INT4_AUTOQUANT_CLASS_LIST,
     OTHER_AUTOQUANT_CLASS_LIST,
     autoquant,
 )
@@ -94,6 +95,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "GEMLITE_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
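With the export above, the new class list is available from the top-level quantization namespace, e.g.:

    from torchao.quantization import GEMLITE_INT4_AUTOQUANT_CLASS_LIST, autoquant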
49 changes: 49 additions & 0 deletions torchao/quantization/autoquant.py
@@ -48,6 +48,7 @@
     "autoquant",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "GEMLITE_INT4_AUTOQUANT_CLASS_LIST",
     "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_SPARSE_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
@@ -700,6 +701,43 @@ class AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight(
     layout: Layout = MarlinSparseLayout()
 
 
+class AQGemliteInt4G32WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    group_size: int = 32
+
+    @classmethod
+    def from_float(cls, weight):
+        from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs
+
+        bit_width = 4
+        packing_bitwidth = 32
+        contiguous = None
+        use_hqq = True
+        aqt_kwargs = get_gemlite_aqt_kwargs(
+            weight, cls.group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+        )
+        return super(
+            AQGemliteInt4G32WeightOnlyQuantizedLinearWeight, cls
+        ).from_hp_to_intx(weight, **aqt_kwargs)
+
+
+class AQGemliteInt4G64WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 64
+
+
+class AQGemliteInt4G128WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 128
+
+
+class AQGemliteInt4G256WeightOnlyQuantizedLinearWeight(
+    AQGemliteInt4G32WeightOnlyQuantizedLinearWeight
+):
+    group_size: int = 256
+
+
 class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -977,6 +1015,12 @@ def get_weight_block_size(x):
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]
 
+GEMLITE_INT4_AUTOQUANT_CLASS_LIST = [
+    AQDefaultLinearWeight,
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+]
+
 DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
     AQFloat32LinearWeight,
     AQBFloat16LinearWeight,
@@ -1002,6 +1046,11 @@ def get_weight_block_size(x):
     + DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST
 )
 
+# add gemlite options
+ALL_AUTOQUANT_CLASS_LIST += [
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+]
+
 if is_sm_at_least_89():
     ALL_AUTOQUANT_CLASS_LIST += [
         AQFloat8WeightOnlyQuantizedLinearWeight,
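The class list added above is an ordinary Python list of AQ* tensor subclasses, so callers can also assemble their own variant, for example swapping in the group-size-128 gemlite class; a small sketch (the list name is arbitrary):

    from torchao.quantization.autoquant import (
        AQDefaultLinearWeight,
        AQGemliteInt4G128WeightOnlyQuantizedLinearWeight,
        AQInt8DynamicallyQuantizedLinearWeight,
    )

    # Same shape as GEMLITE_INT4_AUTOQUANT_CLASS_LIST, but with the group-size-128
    # gemlite variant; pass it to autoquant(...) via qtensor_class_list as in the
    # generate.py example earlier.
    custom_gemlite_class_list = [
        AQDefaultLinearWeight,
        AQInt8DynamicallyQuantizedLinearWeight,
        AQGemliteInt4G128WeightOnlyQuantizedLinearWeight,
    ]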
9 changes: 6 additions & 3 deletions torchao/quantization/quant_api.py
@@ -666,11 +666,14 @@ def gemlite_uintx_weight_only(
         `contiguous`: if set, the weight will be packed as specified. Leaving it as None lets gemlite determine the best choice.
     """
 
-    from torchao.dtypes.uintx.gemlite_layout import apply_gemlite_quant
+    from torchao.dtypes.uintx.gemlite_layout import get_gemlite_aqt_kwargs
 
     use_hqq = True if bit_width == 4 else False
-    apply_fn = lambda weight: apply_gemlite_quant(
-        weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+    apply_fn = lambda weight: to_affine_quantized_intx(
+        weight,
+        **get_gemlite_aqt_kwargs(
+            weight, group_size, bit_width, packing_bitwidth, contiguous, use_hqq
+        ),
     )
     return _get_linear_subclass_inserter(apply_fn)
 
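For the non-autoquant path, gemlite_uintx_weight_only now builds its kwargs through get_gemlite_aqt_kwargs as shown above. A minimal usage sketch with a placeholder model; the keyword names follow the parameters referenced in the function body, so check the signature for the exact defaults:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.quant_api import gemlite_uintx_weight_only

    # Placeholder model; the gemlite kernels target nn.Linear layers on CUDA.
    model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.float16)

    # Weight-only 4-bit quantization with the gemlite packed layout.
    quantize_(
        model,
        gemlite_uintx_weight_only(group_size=64, bit_width=4, packing_bitwidth=32),
    )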
