
Static cache + torch.compile: better documentation for prefill static sequence length #29151

Closed · fxmarty opened this issue Feb 20, 2024 · 8 comments · Fixed by #30788
Labels: Cache, Compilation (issues related to torchdynamo and torchinductor), Generation

fxmarty (Contributor) commented Feb 20, 2024

Feature request

When using torch.compile, the prefill is recompiled for every new sequence length, which is slow. It may be nice to compile only for a fixed set of sequence lengths (1, 2, 4, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, etc.), padding inputs on the fly to the next supported length.
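A minimal sketch of what such length bucketing could look like; the `pad_to_bucket` helper and the bucket sizes are illustrative, not an existing transformers API:

```python
import torch

# Hypothetical bucket sizes: the compiled prefill only ever sees these lengths.
BUCKETS = (1, 2, 4, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096)

def pad_to_bucket(input_ids, attention_mask, pad_token_id):
    # Left-pad input_ids/attention_mask up to the next bucket size so that
    # torch.compile reuses an already-compiled graph instead of recompiling.
    seq_len = input_ids.shape[-1]
    target = next((b for b in BUCKETS if b >= seq_len), seq_len)
    pad_len = target - seq_len
    if pad_len == 0:
        return input_ids, attention_mask
    input_ids = torch.nn.functional.pad(input_ids, (pad_len, 0), value=pad_token_id)
    attention_mask = torch.nn.functional.pad(attention_mask, (pad_len, 0), value=0)
    return input_ids, attention_mask
```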

Motivation

Compilation with torch.compile is prohibitively slow, even with #29114.

If people want to use transformers + static cache + torch.compile, it should be FAST to run generate on new sequence lengths.

Your contribution

None for now

@fxmarty fxmarty changed the title Static cache: support prefill static sequence length Static cache + torch.compile: support prefill static sequence length Feb 20, 2024
amyeroberts (Collaborator) commented:
cc @gante

gante (Member) commented Feb 21, 2024

@fxmarty this is the same problem as we have in TF and Flax. There, we nudged users to use the pad_to_multiple_of argument in the tokenizer, which I believe solves the problem 🤗

How do you suggest we let users know about this feature, other than the docs?
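For reference, a minimal sketch of the pad_to_multiple_of approach combined with a static cache and a compiled forward pass; the checkpoint name, padding multiple, and device handling are placeholder choices, not the documented recipe:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding needs a pad token

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
model.generation_config.cache_implementation = "static"  # fixed-size KV cache

# Compile the forward pass once; the static cache keeps the KV shapes constant.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

# Padding the prompt to a multiple of 64 bounds the number of distinct prefill
# lengths torch.compile can encounter, so recompilations become rare.
inputs = tokenizer(
    "Why is the compiled prefill sensitive to the prompt length?",
    return_tensors="pt",
    padding=True,
    pad_to_multiple_of=64,
).to("cuda")

outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```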

fxmarty (Contributor, Author) commented Feb 22, 2024

@gante Supporting that in the tokenizer is already good, but I am wondering whether it would make sense to support it directly in generate. Have you seen any user requests for that?

gante (Member) commented Feb 26, 2024

@fxmarty I haven't.

I am also not a big fan of it:
a) it pushes the problem from forward to generate (i.e. forward would not see recompilations, but generate will, as it will have an input tensor with arbitrary length)
b) it hides the real behavior (padding) from the user, which may lead to issues due to behavior misunderstandings. An obvious one I can foresee is "my input has X length, I have set max_new_tokens=Y, why isn't the output length X+Y?"

pad_to_multiple_of avoids the problems I mentioned, but it is harder to discover 🤗 Still, I think it is preferable!

fxmarty (Contributor, Author) commented Feb 26, 2024

a) it pushes the problem from forward to generate (i.e. forward would not see recompilations, but generate will, as it will have an input tensor with arbitrary length)

Not really (at least not for torch.compile), as generate is simply not compiled.

b) it hides the real behavior (padding) from the user, which may lead to issues due to behavior misunderstandings. An obvious one I can foresee is "my input has X length, I have set max_new_tokens=Y, why isn't the output length X+Y?"

Fair enough. I think generate could show a warning about the feature (e.g. when the model is an OptimizedModule), and/or we could document the usage with torch.compile.
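A hedged sketch of what such a hint could look like; the function name and message are illustrative, not the actual transformers implementation:

```python
import logging
import torch

logger = logging.getLogger(__name__)

def _maybe_hint_about_padding(model, input_ids):
    # torch.compile wraps modules in OptimizedModule; if the user compiled the
    # model and feeds arbitrary prompt lengths, each new length triggers a
    # recompilation of the prefill graph.
    if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
        logger.warning(
            "The model is compiled with torch.compile; a prompt of length %d will "
            "trigger a recompilation for every new sequence length. Consider "
            "tokenizing with pad_to_multiple_of to reuse compiled graphs.",
            input_ids.shape[-1],
        )
```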

gante (Member) commented Feb 26, 2024

as generate is simply not compiled.

@fxmarty yet ;) Beam search has some heavy tensor operations that should be compiled, some logits processors are heavy, etc.

The difference between passing a flag to generate or to the tokenizer is small, but passing it to generate would restrict our ability to fully compile generate if we decide to go down that path at some point.

fxmarty (Contributor, Author) commented Feb 26, 2024

@gante Agreed, although @torch.compiler.disable is useful for that.
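A small sketch of how torch.compiler.disable can carve a function out of a compiled region; the helpers here are hypothetical, not part of generate:

```python
import torch

@torch.compiler.disable
def pick_next_token(logits):
    # Kept out of the compiled graph: dynamic Python control flow here would
    # otherwise cause graph breaks or recompilations.
    return torch.argmax(logits, dim=-1)

@torch.compile
def compiled_step(hidden, weight):
    # The heavy tensor math is compiled...
    logits = hidden @ weight
    # ...while the disabled helper runs eagerly even when called from here.
    return pick_next_token(logits)
```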

@fxmarty fxmarty changed the title Static cache + torch.compile: support prefill static sequence length Static cache + torch.compile: better documentation for prefill static sequence length Feb 26, 2024
@fxmarty fxmarty added the Compilation Issues related to torchdynamo and torchinductor label Feb 28, 2024
gante (Member) commented May 25, 2024

#30788 -- this PR adds documentation on using pad_to_multiple_of to avoid input-shape-related recompilation.

I'm assuming this issue can be closed after the PR gets merged :) In the generate refactor we will be separating the prefill step, and we can then move/enhance related documentation.
