gemlite integration in torchao (#1034)
* gemlite integration in torchao

Summary:

This PR adds support for gemlite kernels in torchao through a tensor-subclass
integration exposed via the gemlite_uintx_weight_only constructor. This works
for int4 grouped and ungrouped asymmetric weight-only quantization and for
int8 symmetric ungrouped quantization on fp16 models. TP support
through DTensor is included in this PR.

In the process of integrating gemlite into AQT, this PR also fixes a few
quant primitives that are now exercised but previously were not.

Test Plan:

test_integration.py -k "test_gemlite_layout"
test_affine_quantized_tensor_parallel.py -k "test_tp_gemlite"

See benchmarks.sh for gemlite benchmarks as well.
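
A minimal usage sketch of the new constructor (assuming the gemlite package is installed via pip and a CUDA GPU is available; the toy model and shapes are illustrative, and the arguments follow the order used in the tests: group_size, bit_width, packing_bitwidth):

import torch
from torchao.quantization import quantize_, gemlite_uintx_weight_only

# illustrative fp16 model on CUDA; the gemlite path targets fp16 models
model = torch.nn.Sequential(torch.nn.Linear(1024, 512)).cuda().half()

# int4 asymmetric weight-only quantization, group size 64, packed into 32-bit words
quantize_(model, gemlite_uintx_weight_only(64, 4, 32))

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
out = model(x)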

Intermediate commits squashed into this change:

- new gemlite integration using pip install
- tests ran
- fixing gemlite to do int4 matmul instead of fp16 fp16
- running tests
- more testing
- AQT integration wip
- Wip
- testing on gemlite a100_int8_tuning branch
- gemlite subclass testing bitpacking 8 bits
- bug fixing stuff
- hicham fixes
- new benchmarks
- testing gemlite 8 bit
- WIP
- tp support
- wip
- final

* fixing regressions
HDCharles authored Dec 16, 2024
1 parent 200589b commit 603d908
Showing 12 changed files with 577 additions and 13 deletions.
28 changes: 28 additions & 0 deletions test/dtypes/test_affine_quantized_tensor_parallel.py
@@ -19,6 +19,13 @@
from torchao.quantization.quant_api import quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

try:
import gemlite # noqa: F401

has_gemlite = True
except ModuleNotFoundError:
has_gemlite = False


class TestAffineQuantizedTensorParallel(DTensorTestBase):
"""Basic test case for tensor subclasses"""
@@ -139,8 +146,29 @@ def test_tp(self, dtype):
return self._test_tp(dtype)


class TestGemliteLayoutTensorParallel(TestAffineQuantizedTensorParallel):
COMMON_DTYPES = [torch.float16]

@common_utils.parametrize("dtype", COMMON_DTYPES)
@with_comms
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not has_gemlite, "gemlite not available")
def test_tp_gemlite(self, dtype):
from torchao.quantization import gemlite_uintx_weight_only

for packing_bitwidth in [32, 8]:
for bit_width in [4, 8]:
for group_size in [64, 32, None] if bit_width == 4 else [None]:
api = lambda: gemlite_uintx_weight_only(
group_size, bit_width, packing_bitwidth
)
self.QUANT_METHOD_FN = staticmethod(api)
return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)
common_utils.instantiate_parametrized_tests(TestGemliteLayoutTensorParallel)

# Run only on H100
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
34 changes: 34 additions & 0 deletions test/integration/test_integration.py
@@ -96,6 +96,12 @@
)
from torchao.dtypes.utils import is_device

try:
import gemlite
has_gemlite = True
except ModuleNotFoundError:
has_gemlite = False

logger = logging.getLogger("INFO")

torch.manual_seed(0)
@@ -870,6 +876,10 @@ def _test_lin_weight_subclass_api_impl(
ref_f = mod(x)
api(mod)

# test get_plain()
if hasattr(mod[0].weight, "tensor_impl"):
mod[0].weight.tensor_impl.get_plain()

test = mod(x)
self.assertGreater(
SQNR(ref_f, test),
@@ -930,6 +940,30 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):
test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "gemlite tests needs torch 2.5 or greater")
@unittest.skipIf(not has_gemlite, "gemlite not available")
def test_gemlite_layout(self, device, dtype):
if dtype != torch.float16:
self.skipTest("gemlite only works for fp16 dtype")
from torchao.quantization import gemlite_uintx_weight_only
if device == "cpu":
self.skipTest(f"gemlite is for cuda, not {device}")
for packing_bitwidth in [32, 8]:
for bit_width in [4, 8]:
for group_size in [64, 32, None] if bit_width == 4 else [None]:
api = lambda mod: quantize_(mod, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))
for test_shape in [[1, 1024, 512],[16, 256, 1024], [128, 256, 1024]]:
print(packing_bitwidth, bit_width, group_size, test_shape, dtype)
self._test_lin_weight_subclass_api_impl(
api,
device,
15,
test_shape=test_shape,
test_dtype=dtype,
)


@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
15 changes: 15 additions & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -91,6 +91,21 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured

# gemlite benchmarks
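# flag format: gemlite-<packing_bitwidth>-<bit_width>-<group_size>; group_size "None" means ungrouped, and omitting packing_bitwidth defaults it to 32 (see generate.py)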
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-64 --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-64 --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-4-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-4-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-8-8-None --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --quantization gemlite-32-8-None --write_result benchmark_results.txt --batch_size 32

# 2:4 sparse model
export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
42 changes: 39 additions & 3 deletions torchao/_models/llama/generate.py
@@ -171,7 +171,8 @@ def decode_n_tokens(
)
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
new_tokens.append(next_token)
# in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step
new_tokens.append(next_token.clone())
callback(new_tokens[-1])
new_probs.append(next_prob)
cur_token = next_token
@@ -368,6 +369,7 @@ def ffn_or_attn_only(mod, fqn):
int8_weight_only,
quantize_,
uintx_weight_only,
gemlite_uintx_weight_only,
)

from torchao.quantization.granularity import PerRow, PerTensor
@@ -377,6 +379,39 @@ def ffn_or_attn_only(mod, fqn):
from torchao.prototype.spinquant import apply_spinquant

apply_spinquant(model)
if "gemlite" in quantization:
import os, pwd
import gemlite
from gemlite.core import GemLiteLinearTriton, set_autotune
_quant_args = quantization.split("-")
bit_width = int(_quant_args[-2])
group_size = None if _quant_args[-1] == 'None' else int(_quant_args[-1])
try:
packing_bitwidth = int(_quant_args[-3])
except:
# if only 2 inputs found, use default value
packing_bitwidth = 32

quantize_(model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth))

# try to load gemlite kernel config
try:
GemLiteLinearTriton.load_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
print(f"loaded gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
except:
print(f"unable to load gemlite kernel cache /tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")

print("running gemlite warmup")
generate(
model,
encode_tokens(tokenizer, prompt, bos=True, device=device),
max_new_tokens,
batch_size,
interactive=False,
temperature=temperature,
top_k=top_k,
)
GemLiteLinearTriton.cache_config(f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json")
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
if "int8dq" in quantization:
@@ -959,7 +994,7 @@ def callback(x):

parser = argparse.ArgumentParser(description="Your CLI description.")
parser.add_argument(
"--prefill_size", type=int, default=0, help="Whether to run in ttft mode"
"--prefill_size", type=int, default=None, help="Whether to run in ttft mode"
)
parser.add_argument(
"--prompt", type=str, default="Hello, my name is", help="Input prompt."
@@ -993,7 +1028,7 @@ def callback(x):
help=(
"Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
+ "autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
+ "embed-int8wo, marlin_qqq"
+ "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>"
),
)
parser.add_argument(
@@ -1053,6 +1088,7 @@ def callback(x):
)

args = parser.parse_args()
print(args)
main(
args.prefill_size,
args.prompt,
5 changes: 4 additions & 1 deletion torchao/_models/llama/model.py
@@ -170,7 +170,10 @@ def setup_caches(self, max_batch_size, max_seq_length, training: bool=False, kv_
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype = self.output.weight.dtype
dtype = None
# module swaps can cause issues without this
if hasattr(self.output, "weight"):
dtype = self.output.weight.dtype
# For quantized layers, dtype is encoded in scales
if hasattr(self.output, "scales"):
dtype = self.output.scales.dtype
15 changes: 13 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -225,6 +225,8 @@ def from_hp_to_intx(
else input_float.dtype
)
device = input_float.device
from torchao.dtypes.uintx import TensorCoreTiledLayout

data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(
input_float,
nbits=nbits,
@@ -233,7 +235,15 @@
compute_dtype=compute_dtype,
device=device,
verbose=False,
raw_output=False,
raw_output=not isinstance(
_layout, (TensorCoreTiledLayout, PlainLayout)
),
# raw_output=False is basically the 'convert to TensorCoreTiledLayout zero_point version' option (add scale*midpoint)
# note in choose_qparams_affine, preserve_zero = False does this same thing while also controlling whether
# zero is preserved.
# TODO uncouple preserve_zero and conversion of zero_point to TensorCoreTiledLayout version
# TODO move the conversion of zero_point out of quant_primitives and into TensorCoreTiledLayout.from_plain
# TODO change PlainLayout to use raw_output.
)
data = data.to(target_dtype)
else:
@@ -251,7 +261,8 @@
zero_point_domain,
)
# choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None
if zero_point_domain is None:
# TODO should probably consolidate ZeroPointDomain.NONE and None
if zero_point_domain is None or zero_point_domain == ZeroPointDomain.NONE:
zero_point = None
data = quantize_affine(
input_float,
8 changes: 8 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -20,6 +20,10 @@
_linear_int8_act_int8_weight_block_sparse_check,
_linear_int8_act_int8_weight_block_sparse_impl,
)
from torchao.dtypes.uintx.gemlite_layout import (
_linear_fp_act_int4_weight_gemlite_check,
_linear_fp_act_int4_weight_gemlite_impl,
)
from torchao.dtypes.uintx.marlin_qqq_tensor import (
_linear_int8_act_int4_weight_marlin_qqq_check,
_linear_int8_act_int4_weight_marlin_qqq_impl,
@@ -135,6 +139,10 @@ def _register_aqt_quantized_linear_dispatches():
_linear_int8_act_int4_weight_marlin_qqq_check,
_linear_int8_act_int4_weight_marlin_qqq_impl,
),
(
_linear_fp_act_int4_weight_gemlite_check,
_linear_fp_act_int4_weight_gemlite_impl,
),
]:
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

