Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

int8 dynamic prefill weight only decode #1436

Merged
merged 63 commits into from
Dec 30, 2024
Merged
Changes from 1 commit
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
f390fd9
Add sparsity flag to benchmark
jcaip Oct 18, 2024
67937a9
update
jcaip Oct 18, 2024
6b62266
update
jcaip Oct 18, 2024
aa4c9df
fp8 testing
jcaip Oct 18, 2024
6b1ede1
fp8 testing
jcaip Oct 18, 2024
3c07c40
wip
jcaip Oct 22, 2024
a6c7de9
update benchmark script
jcaip Oct 22, 2024
3660766
update
jcaip Oct 22, 2024
ddf2e10
wip
jcaip Oct 22, 2024
ad4d3b0
update
jcaip Oct 22, 2024
653587e
update
jcaip Oct 22, 2024
c757357
wip
jcaip Oct 22, 2024
f1b0841
wip
jcaip Oct 22, 2024
afeaff5
test
jcaip Oct 22, 2024
c294765
wip
jcaip Oct 22, 2024
803e9b3
update
jcaip Oct 22, 2024
eb18850
fix
jcaip Oct 22, 2024
2642212
wip
jcaip Oct 22, 2024
4eccdb9
move out of aqt
jcaip Oct 22, 2024
13e6fd6
wip
jcaip Oct 22, 2024
608d70c
moved float8+24 to its own file
jcaip Oct 22, 2024
b1f1796
Merge branch 'main' into jcaip/sparse-benchmarking-updates
jcaip Oct 22, 2024
30a4fac
update
jcaip Oct 23, 2024
6091592
wip
jcaip Oct 23, 2024
17f9121
remove float8 for now
jcaip Oct 23, 2024
75d0a0b
wip
jcaip Oct 23, 2024
b2fba99
fix
jcaip Oct 28, 2024
ba5665d
fix
jcaip Oct 28, 2024
4fdfa7b
time prefill by default
jcaip Dec 2, 2024
111babc
update
jcaip Dec 3, 2024
35f1fc7
merge
jcaip Dec 3, 2024
23f981d
fix merge conflicts
jcaip Dec 3, 2024
74c52ff
update
jcaip Dec 3, 2024
eed072d
update benchmarks
jcaip Dec 3, 2024
67cbcbb
fix ruff check
jcaip Dec 3, 2024
0e579ae
fix ruff v2
jcaip Dec 3, 2024
443db19
undo change
jcaip Dec 3, 2024
054717e
add padding
jcaip Dec 3, 2024
2e5b72a
update import
jcaip Dec 3, 2024
2b81dd6
final commit
jcaip Dec 3, 2024
de2d447
fix script
jcaip Dec 3, 2024
c0fa0da
wip
jcaip Dec 6, 2024
584c013
update
jcaip Dec 6, 2024
38d60c7
update
jcaip Dec 25, 2024
97cca7a
update
jcaip Dec 25, 2024
525053b
merge main
jcaip Dec 25, 2024
4da1b31
fix merge conflict
jcaip Dec 25, 2024
2517406
demo
jcaip Dec 25, 2024
5b8a28c
update
jcaip Dec 30, 2024
e25b30c
update generate
jcaip Dec 30, 2024
a58e0fd
moved summarization to standalone script
jcaip Dec 30, 2024
ea5cb0c
update
jcaip Dec 30, 2024
17a191a
update weight only decode flag
jcaip Dec 30, 2024
8899435
remove prompt.txt
jcaip Dec 30, 2024
a3056ff
cleanup
jcaip Dec 30, 2024
67a1a35
remove moby.txt
jcaip Dec 30, 2024
1554a8c
update
jcaip Dec 30, 2024
5161364
update
jcaip Dec 30, 2024
562191f
update
jcaip Dec 30, 2024
bf18806
update benchmarks
jcaip Dec 30, 2024
89f03d8
rename arg
jcaip Dec 30, 2024
ce58e1e
update demo script
jcaip Dec 30, 2024
b144a53
formatting
jcaip Dec 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip
jcaip committed Oct 22, 2024
commit ddf2e10cd5e6138fae7368d83b4cf7bf1d5d871e
36 changes: 25 additions & 11 deletions benchmarks/benchmark_gpu_sparsity.py
Original file line number Diff line number Diff line change
@@ -87,6 +87,8 @@ def sparse_func():
sparse_func_c = torch.compile(sparse_func, mode="max-autotune")
sparse_time_c = benchmark_model_with_warmup(sparse_func_c, 'sparse_compile.json.gz')

torch._dynamo.reset()

return {
"test_function": args.eval_fn,
"m": m,
@@ -107,7 +109,8 @@ def sparse_func():
"--mode",
type=str,
choices=[
"llama-3b",
"llama3-8b-a",
"llama3-8b-w",
"vit-mlp",
"nvidia-fixed-k",
"nvidia-fixed-mn",
@@ -157,15 +160,8 @@ def sparse_func():

print(f"Started benchmark: {args}")

if args.mode == "llama-3b-shapes":
bert_shapes = [
(3072, 1024, 16384),
(4096, 1024, 16384),
(1024, 1024, 16384),
(1024, 4096, 16384),
# (16, 4096, 11008),
# (16, 4096, 4096),
# (16, 11008, 4096),
if args.mode == "llama3-8b-a":
mm_shapes = [
(4096, 13312, 16384),
(4096, 16384, 6560),
(4096, 22528, 32768),
@@ -175,12 +171,30 @@ def sparse_func():
]
results = (
run_gpu_sparse_benchmark(m, k, n, args)
for (m, k, n) in tqdm(bert_shapes)
for (m, n, k) in tqdm(mm_shapes)
)
elif args.mode == "llama3-8b-w":
mm_shapes = [
(16, 4096, 11008),
(16, 4096, 4096),
(16, 11008, 4096),
(4096, 4096, 11008),
(4096, 4096, 4096),
(4096, 11008, 4096),
(8192, 4096, 11008),
(8192, 4096, 4096),
(8192, 11008, 4096),
]
results = (
run_gpu_sparse_benchmark(m, k, n, args)
for (m, k, n) in tqdm(mm_shapes)
)
elif args.mode == "vit-mlp":
vit_shapes= [
# vit-base
(768, 3072, 50432),
(3072, 3072, 50432),
# vit-huge
(1280, 5120, 65792),
(5120, 1280, 65792),
]
4 changes: 1 addition & 3 deletions torchao/_models/llama/generate.py
Original file line number Diff line number Diff line change
@@ -15,8 +15,6 @@
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torch.sparse import SparseSemiStructuredTensor
SparseSemiStructuredTensor._FORCE_CUTLASS = False

def device_sync(device):
if "cuda" in device:
@@ -481,7 +479,7 @@ def callback(x):
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')

parser.add_argument('--ttft', type=bool, default=False, help='Whether to run in ttft mode')
parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.')
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')