Skip to content

Commit

Permalink
W4A8 based on CUTLASS
Browse files Browse the repository at this point in the history
CUTLASS-based s8s4_linear_cutlass() operator is introduced, performing
linear transformation over quantized 8-bit input and quantized 4-bit
weight tensors, with corresponding floating point scale tensors
attached.

A benchmark script, for comparing performance of MM based on this
linear operator with MM over 16-bit floating point tensors is supplied
in benchmarks/benchmarks/benchmark_s8s4_cutlass.py.

The Llama generator script torchao/_models/llama/generate.py is
changed, to add "int8adq-int4w-symm" quantization as an option, that
will in turn activate s8s4_linear_cutlass() operator.  With this type
of quantization activated, i.e. if generate.py script run as follows:

python generate.py --compile --precision=torch.float16 -q int8adq-int4w-symm

the generator achieves around 133 tok/sec on A100, vs. around 93
tok/sec without quantization, i.e. when generate.py script run as
follows:

python generate.py --compile --precision=torch.float16
  • Loading branch information
alexsamardzic committed Dec 6, 2024
1 parent 8a805d0 commit 10179d2
Show file tree
Hide file tree
Showing 16 changed files with 1,021 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jobs:
python -m pip install --upgrade pip
pip install ${{ matrix.torch-spec }}
pip install -r dev-requirements.txt
git clone https://github.com/NVIDIA/cutlass.git
sed -i '/cutlass_include_dir = os.path.join/a\ \n # FIXME: remove this\n use_cutlass = True\n cutlass_include_dir = "cutlass/include"' setup.py
pip install .
export CONDA=$(dirname $(dirname $(which conda)))
export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH
Expand Down
53 changes: 53 additions & 0 deletions benchmarks/benchmark_s8s4_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import s8s4_linear_cutlass
from tqdm import tqdm


def get_problem(m, n, k):
groupsize = k

dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
B = torch.randint(-128, 127, size=(n, k // 2), dtype=torch.int8, device=dev)
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
C = None

return A_ref, B_ref, A, A_scale, B, B_scale, C


def benchmark(m: int, k: int, n: int):
A_ref, B_ref, A, A_scale, B, B_scale, C = get_problem(m, n, k)

fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
s8s4_linear_cutlass_time = benchmark_torch_function_in_microseconds(
s8s4_linear_cutlass, A, A_scale, B, B_scale, C
)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"s8s4_linear_cutlass latency (ms)": s8s4_linear_cutlass_time,
"speedup (d/s)": fp16_time / s8s4_linear_cutlass_time,
}


if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

df = pd.DataFrame(results)
df.to_csv("s8s4_linear_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))
18 changes: 18 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def get_extensions():
extension = CUDAExtension if use_cuda else CppExtension

if not IS_WINDOWS:
use_cutlass = False
try:
import cutlass
import cutlass_library
from packaging.version import parse
except:
pass
else:
if parse(cutlass.__version__) >= parse("3.6.0"):
use_cutlass = True
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
cutlass_include_dir = os.path.join(cutlass_library_dir, "source", "include")

extra_link_args = []
extra_compile_args = {
"cxx": [
Expand All @@ -76,6 +89,11 @@ def get_extensions():
"-t=0",
]
}
if use_cutlass:
extra_compile_args["nvcc"].extend([
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
])

if debug_mode:
extra_compile_args["cxx"].append("-g")
Expand Down
10 changes: 9 additions & 1 deletion test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
run_tests,
)

from torchao.dtypes import Int4CPULayout, SemiSparseLayout
from torchao.dtypes import Int4CPULayout, Int4PackedLayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
int4_weight_only,
Expand Down Expand Up @@ -38,6 +38,14 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cu
)
else:
base_functions.append(int4_weight_only(group_size=32))
base_functions.append(
int8_dynamic_activation_int4_weight(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
layout=Int4PackedLayout(),
)
)

if do_sparse:
base_functions.append(
Expand Down
79 changes: 79 additions & 0 deletions test/test_s8s4_linear_cutlass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import itertools

import torch

import torchao
from torchao.quantization.utils import group_quantize_tensor_symmetric
from torchao.utils import compute_max_diff

import pytest


S8S4_LINEAR_CUTLASS_DTYPE = [torch.float16, torch.bfloat16]
S8S4_LINEAR_CUTLASS_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
S8S4_LINEAR_CUTLASS_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]
S8S4_LINEAR_CUTLASS_USE_BIAS = [False, True]
S8S4_LINEAR_CUTLASS_TEST_PARAMS = list(
itertools.product(
S8S4_LINEAR_CUTLASS_DTYPE,
S8S4_LINEAR_CUTLASS_BATCH_SIZE,
S8S4_LINEAR_CUTLASS_SIZE_MNK,
S8S4_LINEAR_CUTLASS_USE_BIAS,
)
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"dtype, batch_size, size_mnk, use_bias", S8S4_LINEAR_CUTLASS_TEST_PARAMS
)
def test_s8s4_linear_cutlass(dtype, batch_size, size_mnk, use_bias):
size_m, size_n, size_k = size_mnk

input = torch.randn((batch_size, size_m, size_k), dtype=dtype, device="cuda")
weight = torch.rand((size_n, size_k), dtype=dtype, device="cuda")
bias = torch.rand((size_n,), dtype=dtype, device="cuda") if use_bias else None

input_2d = input.view(-1, input.shape[-1])
input_2d_s8, input_2d_scales, input_2d_zeros = group_quantize_tensor_symmetric(
input_2d, 8, size_k, dtype
)
assert torch.all(input_2d_zeros == 0)
input_s8 = input_2d_s8.reshape(input.shape)
input_scales = input_2d_scales.reshape(input.shape[:-1])

weight_s8, weight_scales, weight_zeros = group_quantize_tensor_symmetric(
weight, 4, size_n, dtype
)
assert torch.all(weight_zeros == 0)
weight_s4 = ((weight_s8[:, 1::2] & 0xF) << 4) | (weight_s8[:, 0::2] & 0xF)

# If torch.nn.functional.linear(input, weight, bias) used as
# reference, the error would be too big. The calculation below is
# approximately what s8s4_linear_cutlass kernel is doing (except
# that matrrix multiplication is over integers there)).
size_m_2d = input_2d.shape[0]
output_ref = (
(input_2d_s8.to(dtype) @ weight_s8.to(dtype).T)
* input_2d_scales.view(size_m_2d, 1)
* weight_scales.view(1, size_n)
)
if bias is not None:
output_ref += bias
output_ref = output_ref.reshape(input.shape[:-1] + (size_n,))

fn_inputs = (input_s8, input_scales, weight_s4, weight_scales, bias)
try:
output = torchao.ops.s8s4_linear_cutlass(*fn_inputs)
except NotImplementedError as e:
pytest.xfail("torchao.ops.s8s4_linear_cutlass() op not implemented")

max_diff = compute_max_diff(output, output_ref)
assert max_diff < 5e-3
14 changes: 13 additions & 1 deletion torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torchao
import torch._dynamo.config
import torch._inductor.config
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
Expand Down Expand Up @@ -293,6 +294,17 @@ def ffn_or_attn_only(mod, fqn):
group_size=int(quantization.split("-")[1])
assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))
elif "int8adq-int4w-symm" in quantization:
from torchao.dtypes import Int4PackedLayout
quantize_(
model,
int8_dynamic_activation_int4_weight(
group_size=None,
mapping_type=MappingType.SYMMETRIC,
act_mapping_type=MappingType.SYMMETRIC,
layout=Int4PackedLayout(),
)
)
if "marlin" in quantization:
if "qqq" in quantization:
from torchao.dtypes import MarlinQQQLayout
Expand Down Expand Up @@ -699,7 +711,7 @@ def callback(x):
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
'Which quantization techniques to apply: int8dq, int8adq-int4w-symm, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+'embed-int8wo, marlin_qqq'
)
Expand Down
Loading

0 comments on commit 10179d2

Please sign in to comment.