Swapping between two different AttentionProcessors #158

Open
Birch-san opened this issue Aug 18, 2024 · 1 comment
@Birch-san

thanks for the great work on stable-fast. it compiles quickly and boosts speed a lot.

is it possible to support two different compilation graphs?

for example, swapping the SDXL UNet's AttentionProcessor to run a different algorithm depending on what kind of request the user sends us.

the current workaround is to "always run all optional functionality" and just pass zeroes for anything we're not using. but this isn't free.

so is there a way to compile the UNet with its default AttentionProcessor, then apply a different AttentionProcessor to it, and compile it again?
and after that, whenever a user sends a request: we'd just apply whichever AttentionProcessor is appropriate, and it'd use whichever of the previously-compiled graphs matches?
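for concreteness, a rough sketch of the workflow I have in mind, using diffusers' set_attn_processor; compile_unet and MyCustomAttnProcessor are placeholders, not real stable-fast/diffusers names:

```python
from diffusers.models.attention_processor import AttnProcessor2_0

# compile once with the stock processor
unet.set_attn_processor(AttnProcessor2_0())       # default algorithm
compile_unet(unet)                                # placeholder for the sfast compile step -> graph 1

# compile again with the alternative processor
unet.set_attn_processor(MyCustomAttnProcessor())  # placeholder custom algorithm
compile_unet(unet)                                # -> graph 2

# per request: set whichever processor applies, and reuse the matching compiled graph
```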

@Birch-san
Author

okay, I was able to get this working by putting my control-flow-dependent operations into a @script_if_tracing subroutine.
https://ppwwyyxx.com/blog/2022/TorchScript-Tracing-vs-Scripting/

this is awesome; I can trace the UNet just once, and have small sections that activate functionality at runtime based on "if my_cool_kwarg is None". this means we don't pay the cost of unused optional functionality (we're not forced to send an all-zeros batch and do lots of no-op maths).
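roughly like this (simplified sketch; maybe_apply_extra and the kwarg name are just illustrative):

```python
import torch
from typing import Optional

# torch.jit.script_if_tracing: the subroutine runs as normal eager code, but gets
# scripted (not traced) the first time it's hit during tracing, so the `if` below
# survives as real control flow in the compiled graph instead of being baked in.
@torch.jit.script_if_tracing
def maybe_apply_extra(hidden_states: torch.Tensor, my_cool_kwarg: Optional[torch.Tensor]) -> torch.Tensor:
    if my_cool_kwarg is None:
        # optional functionality disabled for this request: no extra work at all
        return hidden_states
    # optional functionality enabled: only now do the extra maths
    return hidden_states + my_cool_kwarg
```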

and sfast's register_custom_python_operator was a lifesaver for wrapping operations to survive script-mode JIT.

the other way to approach this would've been "always activate optional functionality, but send it a batch-of-zero". but I wasn't able to try that approach because torch sdpa + Flash Attn currently rejects batch-of-zero.
pytorch/pytorch#133780
