[TPU] Implement prefix caching for TPUs #10307

Merged: 2 commits merged from tpu-prefix-caching into main on Nov 20, 2024
Conversation

WoosukKwon (Collaborator) commented Nov 13, 2024

This PR implements the prefix caching support for the TPU backend.
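As a minimal usage sketch (assuming a TPU host with vLLM's TPU backend installed; the prompts below are placeholders, not taken from this PR), prefix caching is enabled through vLLM's standard engine flag:

```python
# Minimal sketch, assuming a TPU host with vLLM's TPU backend installed.
# The prompts below are illustrative placeholders.
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-2b", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Requests that share a long common prefix can reuse its cached KV blocks.
shared_prefix = "You are a concise assistant. Answer in one sentence.\n\n"
prompts = [
    shared_prefix + "Q: What is a TPU?\nA:",
    shared_prefix + "Q: What is paged attention?\nA:",
]

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```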

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which covers a small but essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

mergify bot added the ci/build label Nov 13, 2024
WoosukKwon added the tpu (Related to Google TPUs) label Nov 13, 2024
robertgshaw2-neuralmagic (Collaborator) commented Nov 13, 2024

Nice work!

    output = output.permute(0, 2, 1, 3)
else:
    # Prefill with paged KV cache.
    # TODO(woosuk): Tune the below knobs.


Thanks Woosuk for writing the PR.

I'm benchmarking the kernel, so I'll likely have some recommended num_kv_pages_per_compute_block/num_queries_per_compute_block values to share.

Also, the revised paged attention kernel is in the torch_xla nightly. Could you try again? I pulled your PR and it seems it needs additional work to get the effective_q_lens and plumb them to the kernel.

cc: @WoosukKwon
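For intuition, here is a small standalone illustration in plain PyTorch (not the Pallas kernel itself, and not code from this PR) of why the kernel needs per-sequence effective query lengths: with a cached prefix in the paged KV cache, query token i sits at absolute position kv_len - effective_q_len + i, and the causal mask must be built with that offset.

```python
# Standalone illustration only; not code from this PR or from torch_xla.
import torch

def prefix_aware_causal_mask(q_len: int, kv_len: int,
                             effective_q_len: int) -> torch.Tensor:
    """Causal mask for a (padded) query block attending to a KV cache that
    already holds a cached prefix. True means "may attend"."""
    q_pos = torch.arange(q_len).unsqueeze(1)    # [q_len, 1]
    kv_pos = torch.arange(kv_len).unsqueeze(0)  # [1, kv_len]
    # Query token i corresponds to absolute position kv_len - effective_q_len + i.
    offset = kv_len - effective_q_len
    mask = kv_pos <= (q_pos + offset)
    mask[effective_q_len:, :] = False           # fully mask padded query rows
    return mask

# Example: 2 valid query tokens (padded to 4) over 8 cached + new KV positions.
print(prefix_aware_causal_mask(q_len=4, kv_len=8, effective_q_len=2))
```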

WoosukKwon (Collaborator Author) replied:

@vanbasten23 Is the fixed kernel available in today's nightly?

WoosukKwon (Collaborator Author) replied:

@vanbasten23 After the kernel fix, the model generates correct outputs with prefix caching 🎉


Awesome. Thanks for confirming!

outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
for output in outputs:


I wonder if you need a test for prefix caching.
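A hypothetical test along those lines could compare greedy generations with and without prefix caching (the model, prompts, and helper names below are placeholders, not the test ultimately added to vLLM):

```python
# Hypothetical sketch of a prefix-caching correctness test; in a real CI test,
# each engine would likely run in its own process.
from vllm import LLM, SamplingParams

def _greedy_generate(enable_prefix_caching: bool) -> list[str]:
    llm = LLM(model="google/gemma-2b",
              enable_prefix_caching=enable_prefix_caching)
    prefix = "vLLM is a fast and easy-to-use library for LLM inference. " * 8
    prompts = [prefix + q for q in ("What is vLLM?", "What is a TPU?")]
    params = SamplingParams(temperature=0.0, max_tokens=16)
    return [out.outputs[0].text for out in llm.generate(prompts, params)]

def test_prefix_caching_matches_no_caching():
    # Greedy sampling, so outputs should match exactly with and without caching.
    assert _greedy_generate(True) == _greedy_generate(False)
```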

vanbasten23 commented:

Btw, which command did you use to run examples/offline_inference_tpu.py? I used $ python vllm/examples/offline_inference_tpu.py but it fails. Do you need to use a model other than "google/gemma-2b"?

num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
robertgshaw2-neuralmagic (Collaborator) commented:

@vanbasten23 - does this new kernel have the same SMEM requirements as the original paged_attention where the entire block table is stored in SMEM?

E.g. for the decoding run (see below), we split the batch dimension into smaller chunks and run the kernel multiple times


Hey @robertgshaw2-neuralmagic, yes, this new kernel has the same SMEM requirements. I am aware of the SMEM OOM issue you mentioned and we plan to address it.
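As a rough sketch of the batch-splitting workaround robertgshaw2-neuralmagic refers to (the helper name, argument layout, and chunk size are assumptions, not vLLM's actual implementation): splitting the batch means each kernel invocation only needs to hold its chunk's block table in SMEM.

```python
# Rough sketch of batch-chunked paged attention; the kernel is passed in as a
# callable and the chunk size is an arbitrary placeholder, not a tuned value.
from typing import Callable

import torch

def chunked_paged_attention(
    kernel: Callable[..., torch.Tensor],
    query: torch.Tensor,          # [batch, ...]
    block_tables: torch.Tensor,   # [batch, max_num_blocks]
    context_lens: torch.Tensor,   # [batch]
    *extra_args,
    max_batch_per_call: int = 256,
) -> torch.Tensor:
    outputs = []
    for start in range(0, query.shape[0], max_batch_per_call):
        end = start + max_batch_per_call
        # Each call sees only a slice of the block table, bounding SMEM usage.
        outputs.append(kernel(query[start:end],
                              block_tables[start:end],
                              context_lens[start:end],
                              *extra_args))
    return torch.cat(outputs, dim=0)
```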

mergify bot commented Nov 17, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify bot added the needs-rebase label Nov 17, 2024
WoosukKwon (Collaborator Author) commented:

@vanbasten23

> Btw, which command did you use to run examples/offline_inference_tpu.py? I used $ python vllm/examples/offline_inference_tpu.py but it fails. Do you need to use a model other than "google/gemma-2b"?

This is weird. Which version & TPU are you using?

WoosukKwon (Collaborator Author) commented:

I will double check, update this PR, and merge it tonight.

vanbasten23 commented:

> @vanbasten23
>
> > Btw, which command did you use to run examples/offline_inference_tpu.py? I used $ python vllm/examples/offline_inference_tpu.py but it fails. Do you need to use a model other than "google/gemma-2b"?
>
> This is weird. Which version & TPU are you using?

I'm using TPU v5e but I'm not sure if it depends on a specific TPU version.

mergify bot removed the needs-rebase label Nov 19, 2024
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon marked this pull request as ready for review November 20, 2024 21:43
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon merged commit 2f77b6c into main Nov 20, 2024
13 of 16 checks passed
WoosukKwon deleted the tpu-prefix-caching branch November 20, 2024 21:54
tlrmchlsmth pushed a commit to neuralmagic/vllm that referenced this pull request Nov 23, 2024
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
ccs96307 pushed a commit to ccs96307/vllm that referenced this pull request Nov 25, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 28, 2024
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
Labels: ci/build, tpu (Related to Google TPUs)