Add TTFT benchmarks + update sparsity benchmarks #1140

Merged: 41 commits, Dec 4, 2024
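Summary: this PR adds a --prefill_size flag to the Llama benchmark scripts so time-to-first-token (TTFT) can be measured with a large prefill, adds a --sparsity flag (currently semi-structured, i.e. 2:4) that is now passed explicitly alongside quantization options such as sparse-marlin and int8dq, reports prefill and decode timings separately, and adds the nm-testing/SparseLlama-3-8B-pruned_50.2of4 checkpoint to the prepare script. A representative TTFT command from the updated benchmarks.sh:

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000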
Commits (41):
f390fd9
Add sparsity flag to benchmark
jcaip Oct 18, 2024
67937a9
update
jcaip Oct 18, 2024
6b62266
update
jcaip Oct 18, 2024
aa4c9df
fp8 testing
jcaip Oct 18, 2024
6b1ede1
fp8 testing
jcaip Oct 18, 2024
3c07c40
wip
jcaip Oct 22, 2024
a6c7de9
update benchmark script
jcaip Oct 22, 2024
3660766
update
jcaip Oct 22, 2024
ddf2e10
wip
jcaip Oct 22, 2024
ad4d3b0
udpate
jcaip Oct 22, 2024
653587e
update
jcaip Oct 22, 2024
c757357
wip
jcaip Oct 22, 2024
f1b0841
wip
jcaip Oct 22, 2024
afeaff5
test
jcaip Oct 22, 2024
c294765
wip
jcaip Oct 22, 2024
803e9b3
update
jcaip Oct 22, 2024
eb18850
fix
jcaip Oct 22, 2024
2642212
wip
jcaip Oct 22, 2024
4eccdb9
move out of aqt
jcaip Oct 22, 2024
13e6fd6
wip
jcaip Oct 22, 2024
608d70c
moved float8+24 to it's own file
jcaip Oct 22, 2024
b1f1796
Merge branch 'main' into jcaip/sparse-benchmarking-updates
jcaip Oct 22, 2024
30a4fac
update
jcaip Oct 23, 2024
6091592
wip
jcaip Oct 23, 2024
17f9121
remove float8 for now
jcaip Oct 23, 2024
75d0a0b
wip
jcaip Oct 23, 2024
b2fba99
fix
jcaip Oct 28, 2024
ba5665d
fix
jcaip Oct 28, 2024
4fdfa7b
time prefill by default
jcaip Dec 2, 2024
111babc
update
jcaip Dec 3, 2024
35f1fc7
merge
jcaip Dec 3, 2024
23f981d
fix merge conflicts
jcaip Dec 3, 2024
74c52ff
update
jcaip Dec 3, 2024
eed072d
update benchmarks
jcaip Dec 3, 2024
67cbcbb
fix ruff check
jcaip Dec 3, 2024
0e579ae
fix ruff v2
jcaip Dec 3, 2024
443db19
undo change
jcaip Dec 3, 2024
054717e
add padding
jcaip Dec 3, 2024
2e5b72a
update import
jcaip Dec 3, 2024
2b81dd6
final commit
jcaip Dec 3, 2024
de2d447
fix script
jcaip Dec 3, 2024
Files changed:
scripts/prepare.sh (4 additions, 0 deletions)
@@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
# neuralmagic doesn't come with tokenizer, so we need to copy it over
mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4
test/prototype/test_sparse_api.py (3 additions, 0 deletions)
@@ -50,6 +50,9 @@ def test_sparse(self):
sparsify_(model, semi_sparse_weight())
sparse_result = model(input)

if compile:
model = torch.compile(model)

torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)


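For context, a minimal sketch of the 2:4 sparsification flow this test exercises, assuming a CUDA GPU with semi-structured sparse kernel support; the toy model and magnitude-based pruning loop are illustrative stand-ins, not the test's actual helpers:

import torch
from torchao.sparsity import sparsify_, semi_sparse_weight

# toy two-layer model; 2:4 kernels want fp16/bf16 weights on CUDA
model = torch.nn.Sequential(torch.nn.Linear(128, 256), torch.nn.Linear(256, 128)).half().cuda()

# semi_sparse_weight() compresses an existing 2:4 pattern, so first prune each
# weight to at most 2 nonzeros per group of 4 (magnitude-based, illustrative only)
for mod in model.modules():
    if isinstance(mod, torch.nn.Linear):
        w = mod.weight.data.view(-1, 4)
        drop = w.abs().topk(2, dim=1, largest=False).indices
        w.scatter_(1, drop, 0)

x = torch.randn(32, 128, dtype=torch.half, device="cuda")
dense_out = model(x)

sparsify_(model, semi_sparse_weight())  # swap pruned weights for compressed 2:4 tensors
model = torch.compile(model)            # optional, mirroring the compile branch in the test
sparse_out = model(x)
torch.testing.assert_close(dense_out, sparse_out, rtol=1e-3, atol=1e-3)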
torchao/_models/llama/benchmarks.sh (19 additions, 2 deletions)
@@ -52,7 +52,7 @@
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

@@ -62,7 +62,7 @@
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

@@ -79,3 +79,20 @@
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128

# TTFT benchmarks
export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured

# 2:4 sparse model
export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
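The TTFT and decode numbers above come from the generate.py changes below, which record timer events around the prefill and decode phases (torch.cuda.Event on GPU, plus a small HostEvent wrapper around time.perf_counter for CPU/MPS). A minimal sketch of the CUDA timing pattern, with a matmul standing in for the prefill work:

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

a = torch.randn(4096, 4096, device="cuda", dtype=torch.half)
start.record()
b = a @ a                      # stand-in for the prefill step being timed
end.record()

torch.cuda.synchronize()       # events are recorded asynchronously; sync before reading
elapsed_s = start.elapsed_time(end) / 1000   # elapsed_time returns milliseconds
print(f"prefill time: {elapsed_s:.4f} s")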
torchao/_models/llama/generate.py (104 additions, 12 deletions)
@@ -17,6 +17,29 @@
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False

class HostEvent:
def __init__(self):
self.event_time = None

def record(self):
self.event_time = time.perf_counter()

def elapsed_time(self, other_event):
if self.event_time is None:
raise ValueError("Event not recorded!")
# return ms to match cuda event
return abs(other_event.event_time - self.event_time) * 1000

def device_timer(device):
if "cuda" in device:
return torch.cuda.Event(enable_timing=True)
elif ("cpu" in device) or ("mps" in device):
return HostEvent()
else:
print(f"device={device} is not yet suppported")

def device_sync(device):
if "cuda" in device:
torch.cuda.synchronize(device)
@@ -98,6 +121,10 @@ def generate(
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool=False,
prefill_start_event: Optional[torch.cuda.Event]=None,
prefill_end_event: Optional[torch.cuda.Event]=None,
decode_start_event: Optional[torch.cuda.Event]=None,
decode_end_event: Optional[torch.cuda.Event]=None,
**sampling_kwargs
) -> torch.Tensor:
"""
@@ -128,12 +155,21 @@
model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# execute prefill
if prefill_start_event is not None:
prefill_start_event.record()
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
seq[:, T] = next_token.squeeze()
if prefill_end_event is not None:
prefill_end_event.record()

# execute token generation
if decode_start_event is not None:
decode_start_event.record()
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
if decode_end_event is not None:
decode_end_event.record()

return seq

@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
B_INST, E_INST = "[INST]", "[/INST]"

def main(
prefill_size: Optional[int] = None,
prompt: str = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
@@ -166,6 +203,7 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
sparsity: Optional[str] = None,
kv_cache_quantization: bool = False,
cache_size: Optional[int] = None,
linear_causal_mask: bool=False,
@@ -181,6 +219,10 @@
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""

if prefill_size is not None and prefill_size > 0:
# create prompt of prefill size
prompt = "prompt " * (int(prefill_size)-3)

torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
@@ -205,6 +247,14 @@

torch.manual_seed(1234)

def ffn_only(mod, fqn):
return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

def not_ffn_only(mod, fqn):
return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)

def ffn_or_attn_only(mod, fqn):
return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn)

if quantization:
from torchao.quantization import (
@@ -228,9 +278,14 @@
apply_spinquant(model)
if "int8wo" in quantization:
quantize_(model, int8_weight_only())
elif "int8dq" in quantization:
quantize_(model, int8_dynamic_activation_int8_weight())
elif "int4wo" in quantization:
if "int8dq" in quantization:
if sparsity and "semi" in sparsity:
from torchao.dtypes import SemiSparseLayout
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
if "hqq" in quantization:
use_hqq=True
else:
@@ -250,9 +305,9 @@
layout=MarlinQQQLayout(),
),
)
else:
elif "semi" in sparsity:
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only)
if "fp6" in quantization:
quantize_(model, fpx_weight_only(3, 2))
elif "embed-int8wo" in quantization:
@@ -426,6 +481,13 @@
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(model)

# standalone sparsity
elif sparsity:
from torchao.sparsity import semi_sparse_weight, sparsify_
if "semi" in sparsity:
#TODO there is a bug here, need to fix
sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if save:
@@ -451,6 +513,9 @@

aggregate_metrics = {
'tokens_per_sec': [],
'time': [],
'decode_tokens_per_sec': [],
'prefill_time': [],
}
start = -1 if compile else 0

@@ -485,6 +550,8 @@ def callback(x):
else:
callback = lambda x : x
t0 = time.perf_counter()
prefill_start_event, prefill_end_event = device_timer(device), device_timer(device)
decode_start_event, decode_end_event = device_timer(device), device_timer(device)
import contextlib
if (i != num_samples - 1 or not profile):
prof = contextlib.nullcontext()
@@ -504,6 +571,10 @@
kv_cache_quantization=kv_cache_quantization,
cache_size=cache_size,
linear_causal_mask=linear_causal_mask,
prefill_start_event=prefill_start_event,
prefill_end_event=prefill_end_event,
decode_start_event=decode_start_event,
decode_end_event=decode_end_event,
)
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -513,7 +584,7 @@
device_sync(device=device) # MKG
t = time.perf_counter() - t0

if not interactive:
if not interactive and prefill_size is None:
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
@@ -523,7 +594,14 @@
tokens_generated = (y.size(-1) - prompt_length)
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
aggregate_metrics['time'].append(t)
decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000
decode_tokens_sec = tokens_generated / decode_time
aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec)
prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000
aggregate_metrics['prefill_time'].append(prefill_time)
print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec",
f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec")
print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")

if memory_profile and i==0:
@@ -544,8 +622,15 @@
break
print("==========")

#ignore first sample for warmup
tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item()
ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item()
decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item()
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
print(f"Average overall tokens/sec: {tokpersec:.2f}")
print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s")
print(f"Average TTFT: {ttft:.04f} s")
if device == "cuda":
mem = torch.cuda.max_memory_reserved() /1e9
elif device == "xpu":
@@ -557,15 +642,17 @@
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
result_txt += f"--sparsity {sparsity} " if sparsity else ""
result_txt += f"--checkpoint_path {checkpoint_path} "
result_txt += f"--device {device} "
result_txt += f"--precision {precision} "
result_txt += f"--compile " if compile else ""
result_txt += f"--compile_prefill " if compile_prefill else ""
result_txt += f"--prefill_size {prefill_size}" if prefill_size else ""
result_txt += f"--profile {profile} " if profile else ""
result_txt += f"--profile {memory_profile} " if memory_profile else ""
result_txt += f"--interactive " if interactive else ""
@@ -587,7 +674,7 @@
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')

parser.add_argument('--prefill_size', type=int, default=0, help='Prefill (prompt) size in tokens; a nonzero value runs the TTFT benchmark mode')
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
@@ -603,6 +690,11 @@
+'embed-int8wo, marlin_qqq'
)
)
parser.add_argument('-s', '--sparsity', type=str,
help=(
'Which sparsity techniques to apply: semi-structured'
)
)
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
@@ -617,6 +709,6 @@

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)
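To summarize the sparsity routing above in one place: with int8dq plus semi-structured sparsity, the feed-forward linears get int8 dynamic quantization with a 2:4 sparse layout while every other linear gets plain int8 dynamic quantization. A condensed, self-contained sketch of that pattern; the Block module, layer sizes, and the inline pruning step are placeholders (the real script loads an already 2:4-pruned checkpoint):

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
from torchao.dtypes import SemiSparseLayout

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = torch.nn.Linear(256, 256)
        self.feed_forward = torch.nn.Linear(256, 256)
    def forward(self, x):
        return self.feed_forward(self.attention(x))

model = Block().half().cuda()

# the 2:4 layout assumes an already-pruned weight, so prune the FFN weight here for illustration
w = model.feed_forward.weight.data.view(-1, 4)
w.scatter_(1, w.abs().topk(2, dim=1, largest=False).indices, 0)

def ffn_only(mod, fqn):
    return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn

def not_ffn_only(mod, fqn):
    return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn)

# 2:4 sparse + int8 dynamic quantization on FFN linears, plain int8 dynamic quantization elsewhere
quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only)
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only)

out = model(torch.randn(16, 256, device="cuda", dtype=torch.half))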