int8 dynamic prefill weight only decode #1436

Merged
merged 63 commits on Dec 30, 2024
Changes from 1 commit
Commits (63)
f390fd9
Add sparsity flag to benchmark
jcaip Oct 18, 2024
67937a9
update
jcaip Oct 18, 2024
6b62266
update
jcaip Oct 18, 2024
aa4c9df
fp8 testing
jcaip Oct 18, 2024
6b1ede1
fp8 testing
jcaip Oct 18, 2024
3c07c40
wip
jcaip Oct 22, 2024
a6c7de9
update benchmark script
jcaip Oct 22, 2024
3660766
update
jcaip Oct 22, 2024
ddf2e10
wip
jcaip Oct 22, 2024
ad4d3b0
udpate
jcaip Oct 22, 2024
653587e
update
jcaip Oct 22, 2024
c757357
wip
jcaip Oct 22, 2024
f1b0841
wip
jcaip Oct 22, 2024
afeaff5
test
jcaip Oct 22, 2024
c294765
wip
jcaip Oct 22, 2024
803e9b3
update
jcaip Oct 22, 2024
eb18850
fix
jcaip Oct 22, 2024
2642212
wip
jcaip Oct 22, 2024
4eccdb9
move out of aqt
jcaip Oct 22, 2024
13e6fd6
wip
jcaip Oct 22, 2024
608d70c
moved float8+24 to it's own file
jcaip Oct 22, 2024
b1f1796
Merge branch 'main' into jcaip/sparse-benchmarking-updates
jcaip Oct 22, 2024
30a4fac
update
jcaip Oct 23, 2024
6091592
wip
jcaip Oct 23, 2024
17f9121
remove float8 for now
jcaip Oct 23, 2024
75d0a0b
wip
jcaip Oct 23, 2024
b2fba99
fix
jcaip Oct 28, 2024
ba5665d
fix
jcaip Oct 28, 2024
4fdfa7b
time prefill by default
jcaip Dec 2, 2024
111babc
update
jcaip Dec 3, 2024
35f1fc7
merge
jcaip Dec 3, 2024
23f981d
fix merge conflicts
jcaip Dec 3, 2024
74c52ff
update
jcaip Dec 3, 2024
eed072d
update benchmarks
jcaip Dec 3, 2024
67cbcbb
fix ruff check
jcaip Dec 3, 2024
0e579ae
fix ruff v2
jcaip Dec 3, 2024
443db19
undo change
jcaip Dec 3, 2024
054717e
add padding
jcaip Dec 3, 2024
2e5b72a
update import
jcaip Dec 3, 2024
2b81dd6
final commit
jcaip Dec 3, 2024
de2d447
fix script
jcaip Dec 3, 2024
c0fa0da
wip
jcaip Dec 6, 2024
584c013
update
jcaip Dec 6, 2024
38d60c7
update
jcaip Dec 25, 2024
97cca7a
update
jcaip Dec 25, 2024
525053b
merge main
jcaip Dec 25, 2024
4da1b31
fix merge confligt
jcaip Dec 25, 2024
2517406
demo
jcaip Dec 25, 2024
5b8a28c
update
jcaip Dec 30, 2024
e25b30c
update generate
jcaip Dec 30, 2024
a58e0fd
moved summarization to standalone script
jcaip Dec 30, 2024
ea5cb0c
update
jcaip Dec 30, 2024
17a191a
update weight only decode flag
jcaip Dec 30, 2024
8899435
remove prompt.txt
jcaip Dec 30, 2024
a3056ff
cleanup
jcaip Dec 30, 2024
67a1a35
remove moby.txt
jcaip Dec 30, 2024
1554a8c
update
jcaip Dec 30, 2024
5161364
update
jcaip Dec 30, 2024
562191f
update
jcaip Dec 30, 2024
bf18806
update benchmars
jcaip Dec 30, 2024
89f03d8
rename arg
jcaip Dec 30, 2024
ce58e1e
update demo script
jcaip Dec 30, 2024
b144a53
formatting
jcaip Dec 30, 2024
wip
jcaip committed Dec 6, 2024
commit c0fa0dafe7e035333ed20f15f12682fe874ea899
161 changes: 161 additions & 0 deletions torchao/_models/llama/benchmark_results.txt

Large diffs are not rendered by default.

161 changes: 82 additions & 79 deletions torchao/_models/llama/benchmarks.sh

Large diffs are not rendered by default.

11 changes: 10 additions & 1 deletion torchao/_models/llama/eval.py
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ def run_evaluation(
device = "cuda",
precision = torch.bfloat16,
quantization: Optional[str] = None,
sparsity: Optional[str] = None,
compile=False,
max_length=None,
calibration_tasks: Optional[List[str]] = None,
@@ -44,7 +45,7 @@ def run_evaluation(
"""Runs the evaluation of a model using LM Eval."""
print(
f"\nEvaluating model {checkpoint_path} on tasks: {tasks}, limit: {limit}, device: {device}, precision: {precision}, "
+f"quantization: {quantization}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"quantization: {quantization}, sparsity: {sparsity}, compile: {compile}, max_length: {max_length}, calibration_tasks: {calibration_tasks}, "
+f"calibration_seq_length: {calibration_seq_length}, pad_calibration_inputs: {pad_calibration_inputs}\n"
)
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -232,6 +233,13 @@ def run_evaluation(
"float8wo, float8dq, float8saq"
),
)
parser.add_argument(
"--sparsity",
type=str,
help=(
"Which sparsity techniques to apply: semi-structured"
),
)
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
@@ -247,6 +255,7 @@ def run_evaluation(
args.device,
args.precision,
args.quantization,
args.sparsity,
args.compile,
args.max_length,
args.calibration_tasks,
39 changes: 21 additions & 18 deletions torchao/_models/llama/evals.sh
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
# export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround # auto-round w/o quant_lm_head
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
# export MODEL_REPO=meta-llama/Meta-Llama-3-8B
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head

export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row
# export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cpu # auto-round w/o quant_lm_head
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization autoround-cuda-1 # auto-round w/ quant_lm_head
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8wo
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-row

# Testing on additional tasks
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'
# # Testing on additional tasks
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'winogrande' 'arc_challenge'
# python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int8wo --tasks 'mmlu' 'truthfulqa_mc2'

export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization float8dq-tensor --sparsity semi-structured
50 changes: 32 additions & 18 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,8 @@
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

class HostEvent:
def __init__(self):
@@ -142,7 +144,6 @@ def generate(
# format model input
prompt, input_pos = prepare_inputs_for_model(prompt)
prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize

# full prompt+output will be stored in seq
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
seq[:, :T] = prompt
@@ -221,7 +222,7 @@ def main(

if prefill_size is not None and prefill_size > 0:
# create prompt of prefill size
prompt = "prompt " * (int(prefill_size)-3)
prompt = "prompt " * (int(prefill_size)-2)

torchao.quantization.utils.recommended_inductor_config_setter()

@@ -285,6 +286,8 @@ def ffn_or_attn_only(mod, fqn):
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "optim" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight(optimize_prefill=True))
if "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
@@ -293,6 +296,7 @@ def ffn_or_attn_only(mod, fqn):
group_size=int(quantization.split("-")[1])
assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))

if "marlin" in quantization:
if "qqq" in quantization:
from torchao.dtypes import MarlinQQQLayout
@@ -305,6 +309,10 @@ def ffn_or_attn_only(mod, fqn):
layout=MarlinQQQLayout(),
),
)
if "optim" in quantization:
from torchao.quantization.quant_api import int8_dynamic_prefill_int4_weight_only_decode
quantize_(model, int8_dynamic_prefill_int4_weight_only_decode(), filter_fn=ffn_or_attn_only)

elif "semi" in sparsity:
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
@@ -386,7 +394,12 @@ def ffn_or_attn_only(mod, fqn):
granularity = PerRow()
else:
granularity = PerTensor()
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
if sparsity and "semi" in sparsity:
from torchao.experimental.sparse import SemiSparseFloat8Layout
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity, layout=SemiSparseFloat8Layout()))
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity), filter_fn=not_ffn_only)
else:
quantize_(model, float8_dynamic_activation_float8_weight(granularity=granularity))
elif "autoquant_v2" in quantization:
from torchao.prototype.quantization.autoquant_v2 import autoquant_v2
from torchao._models._eval import InputRecorder
@@ -482,11 +495,11 @@ def ffn_or_attn_only(mod, fqn):
unwrap_tensor_subclass(model)

# standalone sparsity
elif sparsity:
from torchao.sparsity import semi_sparse_weight, sparsify_
if "semi" in sparsity:
#TODO there is a bug here, need to fix
sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)
# elif sparsity:
# from torchao.sparsity import semi_sparse_weight, sparsify_
# if "semi" in sparsity:
# #TODO there is a bug here, need to fix
# sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

@@ -501,7 +514,7 @@ def ffn_or_attn_only(mod, fqn):
decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)

if compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
prefill = torch.compile(prefill, mode="max-autotune", fullgraph=True, dynamic=True)

if memory_profile:
if device == "cuda":
@@ -517,7 +530,7 @@ def ffn_or_attn_only(mod, fqn):
'decode_tokens_per_sec': [],
'prefill_time': [],
}
start = -1 if compile else 0
start = -2 if compile else 0

for i in range(start, num_samples):
if i==0:
@@ -576,7 +589,7 @@ def callback(x):
decode_start_event=decode_start_event,
decode_end_event=decode_end_event,
)
if i == -1:
if i < 0:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
if hasattr(prof, "export_chrome_trace"):
@@ -585,12 +598,11 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive and prefill_size is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))
else:
print()
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))

tokens_generated = (y.size(-1) - prompt_length)
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
@@ -624,12 +636,14 @@ def callback(x):

#ignore first sample for warmup
tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
overall_time = torch.mean(torch.tensor(aggregate_metrics['time'])).item()
ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
print(f"Average overall time: {overall_time:.04f} s")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() /1e9
@@ -642,7 +656,7 @@ def callback(x):
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, time={overall_time:5.4f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
6 changes: 5 additions & 1 deletion torchao/dtypes/affine_quantized_tensor_ops.py
Original file line number Diff line number Diff line change
@@ -102,7 +102,7 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias):
# Attach the _quantized_linear_op to the AffineQuantizedTensor class
AffineQuantizedTensor._quantized_linear_op = _quantized_linear_op


from torchao.experimental.sparse.dynamic_wo import _linear_fp_act_int4_weight_sparse_marlin_decode_check, _linear_fp_act_int4_weight_sparse_marlin_decode_impl
# _register_aqt_quantized_linear_dispatches function has a list of (dispatch_condition, implementation) functions, defined in their dtype layout classes, that takes the following args:
# input_tensor: dimension is (M1, M2, ..., in_features)
# weight_tensor: dimension is (out_features, in_features)
@@ -135,6 +135,10 @@ def _register_aqt_quantized_linear_dispatches():
_linear_int8_act_int4_weight_marlin_qqq_check,
_linear_int8_act_int4_weight_marlin_qqq_impl,
),
(
_linear_fp_act_int4_weight_sparse_marlin_decode_check,
_linear_fp_act_int4_weight_sparse_marlin_decode_impl,
)
]:
register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

7 changes: 7 additions & 0 deletions torchao/dtypes/uintx/plain_layout.py
Original file line number Diff line number Diff line change
@@ -230,6 +230,8 @@ def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias):
and isinstance(weight_tensor._layout, PlainLayout)
)

from collections import Counter
seen = Counter()

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
#
@@ -242,6 +244,11 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
# large that y_dot_int32 * a float16 scale is greater than the maximum
# value of a float 16, (which results in a value of inf even if multiplying
# by the other scale would bring it within the expected range)
# global seen
# shape_key = tuple(input_tensor.shape)
# if shape_key not in seen:
# seen[shape_key] += 1
# print(seen)

x_vals_int8 = input_tensor.tensor_impl.int_data
x_scales = input_tensor.tensor_impl.scale
1 change: 1 addition & 0 deletions torchao/dtypes/uintx/semi_sparse_layout.py
Original file line number Diff line number Diff line change
@@ -55,6 +55,7 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
)
print("HERE")
output_dtype = input_tensor.dtype
# TODO: waiting for jesse's test/fix
y = y.to(output_dtype).contiguous()
4 changes: 4 additions & 0 deletions torchao/experimental/sparse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .semi_structured_float8 import SemiSparseFloat8Layout
__all__ = [
'SemiSparseFloat8Layout'
]
328 changes: 328 additions & 0 deletions torchao/experimental/sparse/dynamic_wo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
from dataclasses import dataclass

import torch
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.uintx.tensor_core_tiled_layout import _aqt_is_tensor_core_tile_uint4
from torchao.dtypes.utils import AQTTensorImpl, Layout
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
)

aten = torch.ops.aten

from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)


def _linear_fp_act_int4_weight_sparse_marlin_decode_check(input_tensor, weight_tensor, bias):
return (
isinstance(weight_tensor, AffineQuantizedTensor)
and _aqt_is_tensor_core_tile_uint4(weight_tensor)
and input_tensor.dtype == torch.float16
and len(weight_tensor.shape) == 2
and weight_tensor.zero_point_domain == ZeroPointDomain.INT
and isinstance(weight_tensor._layout, MarlinSparseLayoutDecodeSemiStructuredPrefillLayout)
)

def _linear_fp_act_int4_weight_sparse_marlin_decode_impl(input_tensor, weight_tensor, bias):
from torchao.ops import marlin_24_gemm
from torchao.sparsity.marlin import marlin_24_workspace

if isinstance(input_tensor, AffineQuantizedTensor):
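# prefill path: the activation arrives as a dynamically-quantized int8 AQT, so we use the
# cuSPARSELt-compressed int8 copy of the weight (int_data_prefill) via torch._cslt_sparse_mm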
x_vals_int8 = input_tensor.tensor_impl.int_data
x_scales = input_tensor.tensor_impl.scale
w_vals_int8 = weight_tensor.tensor_impl.int_data_prefill
w_scales = weight_tensor.tensor_impl.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
# must pad
row, col = tmp.shape
from torch.sparse import SparseSemiStructuredTensorCUSPARSELT
tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp)
# we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm
y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm(
w_vals_int8,
tmp_padded.t(),
alpha=w_scales.to(torch.float32),
out_dtype=torch.bfloat16,
).t()[:row, :]
y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape(
*x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1]
)
output_dtype = input_tensor.dtype
# TODO: waiting for jesse's test/fix
y = y.to(output_dtype).contiguous()
if bias is not None:
y += bias
return y

else:
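# decode path: the activation is plain fp16 (the dynamic quant no-ops for single-token
# inputs), so we fall back to the weight-only Marlin 2:4 int4 kernel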
sparse_w_int4 = weight_tensor.tensor_impl.int_data
scale = weight_tensor.tensor_impl.scale
meta = weight_tensor.tensor_impl.meta
original_shape = weight_tensor.tensor_impl.original_shape
num_bits = weight_tensor.tensor_impl.num_bits

# Folds batch dimension into the first dimension
input_2d = input_tensor.view(-1, input_tensor.shape[-1])

size_m = input_2d.shape[0]
size_n = scale.shape[1]
size_k = input_2d.shape[1]
workspace_24 = marlin_24_workspace(original_shape[1])

out = marlin_24_gemm(
input_2d,
sparse_w_int4,
meta,
scale,
workspace_24,
num_bits,
size_m,
size_n,
size_k,
)

# Unfold the batch dimension
out = out.reshape(input_tensor.shape[:-1] + (scale.shape[1],))

if bias is not None:
out += bias.to(out.dtype)
return out



@dataclass(frozen=True)
class MarlinSparseLayoutDecodeSemiStructuredPrefillLayout(Layout):
def pre_process(self, input: torch.Tensor) -> torch.Tensor:
"""Preprocess the input tensor to be in the correct format for the Marlin sparse kernel.
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format
- 2º: tensor is injected with 2:4 sparsity
- 3º: transposes it again because the quantization process will compute the scales for dim=-1
Args:
input (torch.Tensor): the input tensor to preprocess
Returns:
torch.Tensor: the preprocessed tensor
"""
from torchao.sparsity.marlin import inject_24 # avoid circular import

input_t = input.t()
w_24, _ = inject_24(input_t, *input_t.shape)
return w_24.t()


@register_layout(MarlinSparseLayoutDecodeSemiStructuredPrefillLayout)
class MarlinSparseLayoutDecodeSemiStructuredPrefillImpl(AQTTensorImpl):
"""
TensorImpl for sparse_marlin_24 layout for affine quantized tensor.
Can be used with 4 bits and 8 bits quantization.
Original marlin documentation and information:
https://github.com/IST-DASLab/marlin/tree/master
Sparse marlin documentation and information:
https://github.com/IST-DASLab/Sparse-Marlin?tab=readme-ov-file
fields:
original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape
group_size (int): the group size used to pack the tensor
num_bits (int): the number of bits used to quantize the tensor
"""

@staticmethod
def __new__(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
int_data_prefill: torch.Tensor,
meta: torch.Tensor,
_layout: Layout,
original_shape: torch.Size,
group_size: int,
num_bits: int,
):
kwargs = {}
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
kwargs["dtype"] = int_data.dtype
kwargs["requires_grad"] = False
shape = int_data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
int_data_prefill: torch.Tensor,
meta: torch.Tensor,
_layout: Layout,
original_shape: torch.Size,
group_size: int,
num_bits: int,
):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self.meta = meta
self._layout = _layout
self.original_shape = original_shape
self.group_size = group_size
self.num_bits = num_bits
self.int_data_prefill = int_data_prefill

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

raise NotImplementedError(
f"MarlinSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
)

def __tensor_flatten__(self):
return ["int_data", "scale", "zero_point", "int_data_prefill", "meta"], [
self._layout,
self.original_shape,
self.group_size,
self.num_bits,
]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data = tensor_data_dict["int_data"]
scale = tensor_data_dict["scale"]
zero_point = tensor_data_dict["zero_point"]
meta = tensor_data_dict["meta"]
int_data_prefill = tensor_data_dict["int_data_prefill"]
_layout, original_shape, group_size, num_bits = tensor_attributes
return cls(
int_data,
scale,
zero_point,
int_data_prefill,
meta,
_layout,
original_shape,
group_size,
num_bits,
)

def get_plain(self):
from torchao.sparsity.marlin import (
unpack_from_marlin_24,
) # avoid circular import

int_data_expanded, scales_expanded = unpack_from_marlin_24(
self.int_data,
self.scale,
self.meta,
self.original_shape,
self.group_size,
self.num_bits,
)
int_data_expanded_t = int_data_expanded.t()
scales_expanded_t = scales_expanded.t()
return int_data_expanded_t, scales_expanded_t, self.zero_point

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
_layout: Layout,
):
from torchao.sparsity.marlin import (
const,
pack_to_marlin_24,
) # avoid circular import

assert isinstance(_layout, MarlinSparseLayoutDecodeSemiStructuredPrefillLayout)

# Linear layers are (in_features, out_features) but the int_data that is reaching this point
# is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
q_w_24 = int_data.t()
scale_t = scale.t()

if not torch.cuda.get_device_capability()[0] >= 8:
raise ValueError(
f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel."
)

if q_w_24.dtype != torch.int32:
raise ValueError("Only `torch.int32` weights are supported.")

in_features, out_features = q_w_24.shape
if in_features % 128 != 0 or out_features % 256 != 0:
raise ValueError(
"`in_features` must be divisible by 128 and `out_features` by 256."
)

# NOTE: The current marlin 2:4 kernel supports both 4 and 8 bits quantization but fp8
# will require a bit more work to get our current quantization flow to work with it.
# Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main
num_bits = 4 if torch.max(q_w_24) < 16 else -1
if num_bits not in [4]:
raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.")

group_size = in_features // scale_t.shape[0]
if group_size == 0:
group_size = in_features
assert (
group_size <= in_features
), "Group size must be less than or equal to in_features."

if group_size not in const.SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}."
)

# Compress quantized weight to marlin 2:4 format
marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(
q_w_24, scale_t, num_bits, group_size
)
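# also keep a cuSPARSELt-compressed int8 copy of the weight; this is what the int8
# dynamic (prefill) matmul path consumes as int_data_prefill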
int_data_prefill = torch._cslt_compress(int_data.to(torch.int8))

return cls(
marlin_24_q_w_comp,
marlin_24_s,
zero_point,
int_data_prefill,
meta,
_layout,
q_w_24.shape,
group_size,
num_bits,
)

def get_layout(self) -> Layout:
return self._layout

def _apply_fn_to_data(self, fn):
self.int_data = fn(self.int_data)
self.scale = fn(self.scale)
self.zero_point = fn(self.zero_point)
self.int_data_prefill = fn(self.int_data_prefill)
self.meta = fn(self.meta)
return self
113 changes: 113 additions & 0 deletions torchao/experimental/sparse/semi_structured_float8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import logging
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
)

from torchao.dtypes.floatx.float8_layout import (
Float8Layout,
Float8AQTTensorImpl,
_is_rowwise_scaled,  # used by the dispatch check below
)
from torchao.float8.inference import (
Float8MMConfig,
addmm_float8_unwrapped_inference,
)
from torchao.dtypes.utils import Layout, get_out_shape

from torch.sparse import SparseSemiStructuredTensorCUSPARSELT

logger = logging.getLogger(__name__)

@dataclass(frozen=True)
class SemiSparseFloat8Layout(Layout):
mm_config: Optional[Float8MMConfig] = None

@register_layout(SemiSparseFloat8Layout)
class SemiSparseFloat8AQTTensorImpl(Float8AQTTensorImpl):
"""
TensorImpl storage class for semi_sparse_cusparselt layout for affine quantized tensor
"""
def get_plain(self):
# Currently we don't have cuSPARSELt expansion routines, so we matmul by
# the identity matrix to get the original dense matrix. This is slow though.
# cols = self.float8_data.numel() * 16 // (10 * self.shape[0])
# float_data_expanded = torch._cslt_sparse_mm(self.float8_data,
# torch.eye(cols,
# dtype=self.float8_data.dtype,
# device=self.float8_data.device).t())
return self.float8_data, self.scale, None

@classmethod
def from_plain(
cls,
float8_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
assert isinstance(_layout, SemiSparseFloat8Layout)
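# compress the fp8 weight into cuSPARSELt 2:4 form up front; only the compressed
# representation is stored (see get_plain above for the missing expansion routine)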
float8_data_compressed = torch._cslt_compress(float8_data)
output = cls(float8_data_compressed, scale, False, _layout)
return output

def _linear_fp8_act_fp8_weight_semi_structured_check(
input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
bias: Optional[torch.Tensor],
) -> bool:
def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor], layout=Float8Layout) -> bool:
return (
isinstance(aqt, AffineQuantizedTensor) and
isinstance(aqt._layout, layout)
and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
)
return check_aqt(input_tensor) and check_aqt(weight_tensor, layout=SemiSparseFloat8Layout)

def _linear_fp8_act_fp8_weight_semi_structured_impl(
input_tensor: AffineQuantizedTensor,
weight_tensor: AffineQuantizedTensor,
bias: Optional[torch.Tensor],
):
"""Implements matmul between FP8 input and FP8 weight with compute using _cslt_sparse_mm"""
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
scaled_mm_config = weight_tensor._layout.mm_config
w_tensor_impl = weight_tensor.tensor_impl
w_compressed = w_tensor_impl.float8_data
w_scale = w_tensor_impl.scale

# Input tensor preprocessing
inpt_data = input_tensor.tensor_impl.float8_data
input_scale = input_tensor.tensor_impl.scale
# Handle case where input tensor is more than 2D
inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

row, col = inpt_data.shape
inpt_data_padded = torch.sparse.SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(inpt_data)

y_dot_bf16 = torch._cslt_sparse_mm(
w_compressed,
inpt_data_padded.t(),
# out_dtype=torch.float8,
# transpose_result=True,
).t()[:row, :]
y = y_dot_bf16
output_dtype = input_tensor.dtype
y = y.to(output_dtype).contiguous()
return y


register_aqt_quantized_linear_dispatch(
_linear_fp8_act_fp8_weight_semi_structured_check,
_linear_fp8_act_fp8_weight_semi_structured_impl
)
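
As a quick orientation for reviewers: the snippet below is a minimal usage sketch, not part of the diff, showing how this layout is exercised by the generate.py change earlier in the PR. The single-layer toy model is a placeholder and its weight is assumed to already carry a 2:4 sparsity pattern, since from_plain compresses it with torch._cslt_compress; the import paths are the ones introduced in this commit.

import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quant_api import float8_dynamic_activation_float8_weight
from torchao.experimental.sparse import SemiSparseFloat8Layout

# toy stand-in for a transformer linear layer; a real run targets a pruned llama checkpoint
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).to(torch.bfloat16).cuda()

quantize_(
    model,
    float8_dynamic_activation_float8_weight(
        granularity=PerTensor(),          # the new layout path asserts PerTensor for both act and weight
        layout=SemiSparseFloat8Layout(),  # weights are compressed; matmuls go through torch._cslt_sparse_mm
    ),
)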
97 changes: 91 additions & 6 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@
"""

import logging
import types
import warnings
from typing import Callable, Optional, Tuple, Union
@@ -629,6 +630,58 @@ def int8_dynamic_activation_int4_weight(
)


from torchao.experimental.sparse.dynamic_wo import MarlinSparseLayoutDecodeSemiStructuredPrefillLayout

def int8_dynamic_prefill_int4_weight_only_decode(
group_size=128,
):
def apply_int4_weight_only_quant(weight):
if weight.shape[-1] % group_size != 0:
logger.info(
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
)
return weight

in_features = weight.shape[1]
# int8 dynamic quantization only has benefit when in_feature > 16
if in_features <= 16:
logger.info(
f"Skipping applying int8_dynamic_activation_int8_weight to weight of shape {weight.shape}"
f" because `in_feature` is <= 16: {in_features}"
)
return weight

mapping_type = MappingType.SYMMETRIC
block_size = (1, group_size)
target_dtype = torch.int32
quant_min = 0
quant_max = 15
eps = 1e-6
preserve_zero = True
zero_point_dtype = torch.bfloat16
zero_point_domain = ZeroPointDomain.INT

input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
layout = MarlinSparseLayoutDecodeSemiStructuredPrefillLayout()

weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
zero_point_dtype=zero_point_dtype,
preserve_zero=preserve_zero,
zero_point_domain=zero_point_domain,
_layout=layout,
)
weight = to_linear_activation_quantized(weight, input_quant_func)
return weight

return _get_linear_subclass_inserter(apply_int4_weight_only_quant)

def int4_weight_only(
group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
):
@@ -738,8 +791,29 @@ def _int8_symm_per_token_reduced_range_quant(x: torch.Tensor) -> torch.Tensor:
)


def _int8_symm_per_token_reduced_range_quant_noop_decode(x: torch.Tensor) -> torch.Tensor:
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
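# during decode the incoming activation has a single token (x.shape[1] == 1); skip the
# dynamic quantization there so the linear falls through to a weight-only path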
if x.shape[1] == 1:
return x
else:
return to_affine_quantized_intx(
x,
mapping_type,
_get_per_token_block_size(x),
target_dtype,
eps=eps,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=torch.float32 if x.dtype == torch.float16 else None,
)


def int8_dynamic_activation_int8_weight(
layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC
layout=PlainLayout(), act_mapping_type=MappingType.SYMMETRIC, optimize_prefill=False
):
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
@@ -766,11 +840,14 @@ def get_weight_block_size(x):
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
if optimize_prefill:
input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
else:
input_quant_func = _int8_asymm_per_token_quant
# input settings
if act_mapping_type == MappingType.SYMMETRIC:
input_quant_func = _int8_symm_per_token_reduced_range_quant
else:
input_quant_func = _int8_asymm_per_token_quant

block_size = get_weight_block_size(weight)
weight = to_affine_quantized_intx(
@@ -950,7 +1027,9 @@ def float8_dynamic_activation_float8_weight(
Union[_fp8_granularities, Tuple[_fp8_granularities, _fp8_granularities]]
] = None,
mm_config: Optional[Float8MMConfig] = None,
layout=None,
):
from torchao.experimental.sparse import SemiSparseFloat8Layout
"""
Applies float8 dynamic symmetric quantization to both activations and weights of linear layers.
@@ -973,6 +1052,12 @@ def float8_dynamic_activation_float8_weight(

activation_granularity, weight_granularity = _normalize_granularity(granularity)

if layout is None:
layout = Float8Layout(mm_config=mm_config)
elif isinstance(layout, SemiSparseFloat8Layout):
assert activation_granularity == weight_granularity == PerTensor(), "SemiSparseFloat8Layout only supports PerTensor granularity for activations and weights"
layout = SemiSparseFloat8Layout(mm_config=mm_config)

def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
if not _fp8_mm_compat(weight):
return weight
@@ -987,7 +1072,7 @@ def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=mm_config),
_layout=layout,
)

input_quant_func = _input_activation_quant_func_fp8
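
To close the loop, here is a minimal usage sketch, not part of the diff, for the two entry points added in this file, mirroring how generate.py wires them up earlier in the PR. The toy model is a placeholder; the filter function mentioned for option 2 is the ffn_or_attn_only helper defined inside generate.py, not something importable from torchao; and the comments follow the code above, where optimize_prefill=True makes the activation quantization a no-op for single-token (decode) inputs.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int8_weight,
    int8_dynamic_prefill_int4_weight_only_decode,
)

# toy fp16 stand-in; in generate.py these transforms target the llama FFN/attention linears
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).half().cuda()

# option 1: int8 dynamic activation + int8 weight, but the activation quant is skipped
# whenever x.shape[1] == 1, so decode steps effectively run weight-only
quantize_(model, int8_dynamic_activation_int8_weight(optimize_prefill=True))

# option 2 (alternative to option 1): int8-dynamic 2:4-sparse prefill with Marlin 2:4
# int4 weight-only decode; generate.py gates this behind a quantization string containing
# "optim" and restricts it with filter_fn=ffn_or_attn_only
# quantize_(model, int8_dynamic_prefill_int4_weight_only_decode())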