
Add TTFT benchmarks + update sparsity benchmarks #1140

Merged 41 commits into main on Dec 4, 2024

Conversation

@jcaip (Contributor) commented Oct 22, 2024

This PR adds TTFT (time-to-first-token) benchmarks to torchao, and updates the benchmarking script to handle sparsity more cleanly and to use the 2:4 sparse checkpoints that are available.


It also adds padding support for int8 dynamic quantization + 2:4 sparsity, which was missing before.
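For context, 2:4 sparse matmul kernels typically require the reduction dimension to be divisible by a fixed multiple, so inputs with odd shapes have to be zero-padded first. A minimal sketch of what such padding can look like (the function name and the multiple of 8 are illustrative assumptions, not the PR's actual implementation):

```python
import torch

def pad_dense_input(x: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    """Zero-pad the last dimension of `x` up to the next multiple of `multiple`.

    Sparse matmul kernels often require the inner (reduction) dimension to be
    divisible by a fixed multiple; zero-padding contributes nothing to the dot
    products, so the result for the original columns is unchanged.
    """
    remainder = x.shape[-1] % multiple
    if remainder == 0:
        return x
    pad = multiple - remainder
    # F.pad pads the last dimension with (left, right) amounts
    return torch.nn.functional.pad(x, (0, pad))
```

Because the padding is zeros, it can be applied transparently right before the sparse kernel call, provided the weight is padded consistently along the same dimension.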

jcaip added 6 commits October 18, 2024 11:05
Summary:

This PR adds a sparsity option to the LLaMa benchmarks.


pytorch-bot bot commented Oct 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1140

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit de2d447 with merge base 2f97b09 (image):

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 22, 2024
@jcaip jcaip marked this pull request as ready for review October 28, 2024 18:27
@jcaip jcaip changed the title [wip] add ttft benchmarks + update sparsity benchmarks Add TTFT benchmarks + update sparsity benchmarks Oct 28, 2024
@jcaip jcaip requested a review from HDCharles October 28, 2024 18:28
```python
from torchao.dtypes import MarlinSparseLayout

if sparsity and "semi" in sparsity:
    quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
```

This isn't using any of the derived variables. It should either use the derived ones or be in a separate section.

HDCharles (Contributor) left a comment:

LGTM if you move the Marlin stuff so it's clearer what derived variables it actually uses.

```python
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
if write_result:
    result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, time={t:5.4f} sec, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
    result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
```
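As a self-contained illustration of how the added `time` field renders in that result line (the numeric values below are made up purely to demonstrate the format specifiers):

```python
from datetime import datetime

# Made-up example values, only to show the formatting of the benchmark line.
tokpersec, bandwidth, t, mem, model_size = 105.27, 1402.5, 0.0421, 13.84, 13.21
result_txt = (
    f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, "
    f"mem/s={bandwidth:7.2f} GB/s, time={t:5.4f} sec, "
    f"peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
)
print(result_txt)
```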

"time" is a really generic term: is this TTFT or the overall run? The tok/s info is already the non-prefill indicator, so TTFT (or time to do prefill) is probably more valuable.

jcaip (Contributor, Author):

It's overall time, but I limit the number of generated tokens to 1. I can make this a bit clearer though, maybe with a --ttft flag that forces num_tokens_generated to be 1.
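A minimal sketch of what such a flag could look like (the flag and variable names here follow the comment above and are hypothetical, not the script's actual code):

```python
import argparse

# Hypothetical --ttft flag: force generation of a single token so the measured
# wall-clock time is the time to first token (prefill + one decode step).
parser = argparse.ArgumentParser(description="generation benchmark (sketch)")
parser.add_argument("--max_new_tokens", type=int, default=200)
parser.add_argument("--ttft", action="store_true",
                    help="measure time-to-first-token by generating only 1 token")

args = parser.parse_args(["--ttft"])
max_new_tokens = 1 if args.ttft else args.max_new_tokens
```

With the flag set, the generation loop runs for exactly one token, so the existing `time` measurement becomes an unambiguous TTFT number.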

@jcaip jcaip force-pushed the jcaip/sparse-benchmarking-updates branch from 6858180 to 4fdfa7b Compare December 2, 2024 21:18
@jcaip jcaip added the benchmark, topic: new feature, and topic: performance labels Dec 3, 2024
@jcaip jcaip merged commit 1a0dbf1 into main Dec 4, 2024
17 of 21 checks passed
vkuzo pushed a commit that referenced this pull request Dec 5, 2024
zhyncs commented Dec 8, 2024

Hi @vkuzo, thanks for the great work!
I see that GitHub has already released 0.7.0, but PyPI is still at 0.6.1. When is the update for PyPI expected? Thanks!
https://github.com/pytorch/ao/releases/tag/v0.7.0

yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
* Torchchat CLI pipeline for Multimodal Models

* Remove torchaudio check; we don't use it

* Flip the imports back for ET

---------

Co-authored-by: vmpuri <[email protected]>
Co-authored-by: Jack-Khuu <[email protected]>

vkuzo commented Dec 11, 2024

> Hi @vkuzo Thanks for the great work!
> I see that GitHub has already released 0.7.0, but PyPI is still at 0.6.1. When is the update for PyPI expected? Thanks!
> https://github.com/pytorch/ao/releases/tag/v0.7.0

It's available as of last night!


zhyncs commented Dec 11, 2024

Thanks!
