
Speculative #1308

Merged: 38 commits merged into main on Dec 11, 2023

Conversation

@Narsil (Collaborator) commented Dec 4, 2023

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Narsil requested a review from OlivierDehaene on December 4, 2023 at 14:46
proto/generate.proto (review thread resolved)
Comment on lines 10 to 11

    pub use pb::generate::v1::InfoResponse as ShardInfo;
    pub use pb::generate::v1::{

Member suggested change:

    pub use pb::generate::v2::InfoResponse as ShardInfo;
    pub use pb::generate::v2::{

Collaborator (Author): Done.

@@ -515,6 +515,7 @@ fn send_responses(

let mut stopped = false;

tracing::info!("Generation: {:?}", generation);
Member suggested change: remove this logging line.

Collaborator (Author): Done.

text: generation.token_text,
logprob: generation.token_logprob,
special: generation.token_is_special,
let tokens: Vec<Token> = if let Some(tokens_) = generation.tokens {
Member: Maybe we should .expect() here, as it would be a bug if tokens is empty.

Member: We could also .expect() if the vecs are not the same size.

@@ -97,6 +101,8 @@ def get_model(
else:
raise RuntimeError(f"Unknown dtype {dtype}")

SPECULATE = 2
Member: Remove?

Collaborator (Author): Changed to proper handling.

next_input_ids, next_token_logprobs, logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits

from text_generation_server.models import SPECULATE
Member: Use a get like we do for the cache manager instead.
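
The suggestion mirrors how the cache manager is accessed: rather than importing a module-level constant (whose value is fixed at import time), expose accessor functions and call the getter at use sites. A minimal sketch of that pattern, assuming hypothetical set_speculate/get_speculate helpers rather than the repository's actual names:

    # Sketch of the "use a get" pattern: keep the speculation depth behind
    # accessors instead of importing a module-level constant directly.
    SPECULATE: int = 0

    def set_speculate(speculate: int) -> None:
        """Record the number of speculative tokens configured at startup."""
        global SPECULATE
        SPECULATE = speculate

    def get_speculate() -> int:
        """Read the configured speculation depth wherever it is needed."""
        return SPECULATE

Call sites would then use get_speculate() instead of the imported value, so a depth chosen at startup is visible everywhere without re-importing.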

Comment on lines 942 to 943

    # next_token_ids,
    # next_token_logprobs,

Member suggested change: remove these commented-out lines.

Comment on lines 961 to 962

    # next_token_id,
    # next_token_logprob,

Member suggested change: remove these commented-out lines.

if not stop:
stopped = False
left = 0
for j, next_token_id in enumerate(_next_token_ids):
Member: Maybe this could happen in the same for loop as above?
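
The reviewer's point is that the second pass over _next_token_ids could be folded into the loop above it, so each speculated id is appended and checked against the stopping criteria in one place. A rough, self-contained sketch of that shape, using stand-in data and a stand-in stopping check rather than the actual batch state in flash_causal_lm.py:

    # Stand-ins so the sketch runs on its own; they are not the real batch state.
    _next_token_ids = [11, 42, 7, 99]   # speculated ids for one request
    all_input_ids = [1, 2, 3]           # running sequence for that request
    EOS_TOKEN_ID = 7                    # pretend end-of-sequence id

    def stopping_criteria(token_id: int) -> bool:
        """Stand-in for the batch's StoppingCriteria: stop on the fake EOS id."""
        return token_id == EOS_TOKEN_ID

    left = 0
    stopped = True
    for j, next_token_id in enumerate(_next_token_ids):
        all_input_ids.append(next_token_id)
        if stopping_criteria(next_token_id):
            # Speculated tokens past the stopping point are discarded.
            left = len(_next_token_ids) - 1 - j
            break
    else:
        stopped = False

    print(all_input_ids, stopped, left)  # [1, 2, 3, 11, 42, 7] True 1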

@OlivierDehaene (Member) commented

We are also missing:

  • new prometheus metrics for the number of accepted speculative ids
  • speculative tokens count in next_batch
  • maybe refactor infer to have the same InferResponse as before.
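
For the first item, a hedged sketch of what such a metric could look like with prometheus_client; the metric name, buckets, and placement are illustrative assumptions, and the actual counter may well live in the Rust router rather than the Python server:

    # Illustrative Prometheus metric for accepted speculative ids; the name and
    # buckets are assumptions, not the ones ultimately added to the project.
    from prometheus_client import Histogram

    ACCEPTED_SPECULATIVE_IDS = Histogram(
        "tgi_accepted_speculative_ids",
        "Number of speculative ids accepted per decoding step",
        buckets=[0, 1, 2, 3, 4, 5, 8, 16],
    )

    def record_acceptance(n_accepted_ids: int) -> None:
        """Observe how many speculative tokens were accepted for one request."""
        ACCEPTED_SPECULATIVE_IDS.observe(n_accepted_ids)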

@Narsil (Collaborator, Author) commented Dec 5, 2023

Did all 3 (the accepted-speculative-ids metric, the speculative token count in `next_batch`, and the InferResponse refactor).

Resolved review threads:
  • integration-tests/models/test_flash_llama.py (outdated)
  • server/text_generation_server/cli.py
  • server/text_generation_server/utils/tokens.py (5 outdated threads)
@Narsil (Collaborator, Author) commented Dec 9, 2023

Sorry you reviewed this; I was making adjustments to reduce the overhead. Cleaned it up.

@Narsil (Collaborator, Author) commented Dec 9, 2023

@OlivierDehaene Good for review this time, I'll run a few benches.

OlivierDehaene previously approved these changes Dec 11, 2023

Resolved review threads:
  • server/text_generation_server/utils/layers.py (outdated)
  • server/text_generation_server/models/flash_causal_lm.py (outdated)
OlivierDehaene merged commit 9ecfa16 into main on Dec 11, 2023
8 checks passed
OlivierDehaene deleted the medusa2 branch on December 11, 2023 at 11:46
@shcho1118 commented Dec 12, 2023

I know this PR has been merged, but I have a question.
Where can I find the part about generating masking for tree-based attention, which is one of Medusa's features?
Or does it just use top-1 from each medusa head (in which case it would be the same as causal)?

@DoubleVII commented

As you say, this PR just uses top-1 from each medusa head.
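
Taking top-1 from each medusa head yields a single linear chain of candidate tokens (one per extra position), so plain causal attention is sufficient and no tree mask is needed. A small illustrative sketch with stand-in heads, not the exact Medusa implementation:

    # Illustrative only: stand-in medusa heads, one linear projection per extra
    # speculated position; argmax over each gives a linear chain, not a tree.
    import torch

    batch_size, hidden_size, vocab_size, n_heads = 2, 16, 100, 3
    hidden_states = torch.randn(batch_size, hidden_size)
    medusa_heads = [torch.nn.Linear(hidden_size, vocab_size) for _ in range(n_heads)]

    speculative_ids = torch.stack(
        [head(hidden_states).argmax(dim=-1) for head in medusa_heads], dim=1
    )
    print(speculative_ids.shape)  # torch.Size([2, 3]): one chain of 3 ids per sequence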

kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024