Stream ModelOutputs #29545

Closed. dmarx wants to merge 49 commits.
Conversation

dmarx (Author) commented Mar 8, 2024

What does this PR do?

(NB: this PR is still a work in progress, but it's close to ready.)

The "streaming" feature for text generation is currently constrained to streaming token ids. If a user requests an output dict, the stream still only yields token ids, without other attributes that may have been requested, such as scores, raw logits, or activations. Only after the stream has been consumed does the function return its final output, which is the originally requested output dict.

This PR aligns the return type of the streamer with the requested return type by encapsulating the logic that determines how the return value is constructed. In doing so, it also lets users stream richer representations than just token ids. EDIT: Additionally, this exposes a useful probe for testing, as demonstrated by the discovery of #29551.
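
For context, here is a minimal sketch of the current behavior (the checkpoint name is just illustrative): even with an output dict requested, the stream only carries tokens, and the richer output is only available from generate()'s return value after streaming has finished.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Hello there,", return_tensors="pt")

streamer = TextIteratorStreamer(tok, skip_prompt=True)
gen_kwargs = dict(
    **inputs,
    max_new_tokens=20,
    streamer=streamer,
    return_dict_in_generate=True,  # an output dict is requested...
    output_scores=True,
)
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread.start()

for text_piece in streamer:
    # ...but the stream itself only carries decoded token text; the scores
    # are only available in generate()'s final return value, which this
    # background thread discards.
    print(text_piece, end="")
thread.join()
```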

Summary of changes:

  • Adds a ._prepare_output() method to GenerationMixin that encapsulates the output-formatting logic
  • ._prepare_output() replaces the nested conditionals that previously built return values (beam samplers excluded)
  • Streamer.put() now receives the result of the ._prepare_output() invocation rather than being restricted to a token-id tensor (see the usage sketch below)
  • Inputs are no longer moved to the CPU before being passed to Streamer.put()
    • When desired, responsibility for moving the tensor is delegated to the provided streamer
  • Adds OutputStreamer and OutputIteratorStreamer classes
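
A rough sketch of the intended usage with the new OutputIteratorStreamer (the import path and the per-step attribute names shown here are assumptions and may change as the PR evolves):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer
# hypothetical import path: the PR adds the class alongside the existing streamers
from transformers.generation.streamers import OutputIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tok("Hello there,", return_tensors="pt")

streamer = OutputIteratorStreamer()
gen_kwargs = dict(
    **inputs,
    max_new_tokens=20,
    streamer=streamer,
    return_dict_in_generate=True,
    output_scores=True,
    output_logits=True,
)
Thread(target=model.generate, kwargs=gen_kwargs).start()

for step in streamer:
    # each streamed step mirrors the requested output dict, so scores and raw
    # logits (attribute names assumed) are available while generation proceeds
    print(step.sequences, step.scores)
```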

Before submitting

  • Did you read the contributor guideline, Pull Request section?
  • Did you write any new necessary tests?
    • OutputIteratorStreamerTester still needs to be modularized with fixtures and DRYed
  • test against hydra-node
  • fix inconsistent list output from OutputStreamer.on_ready
  • flesh out rest of intermediate return values
    • scores
    • logits
    • attention
    • hidden
    • NB: not currently streaming KV cache and encoder states
  • add some kind of "output formatter" factory to ensure outputs are uniform across the class without repeating a ton of logic
    • encapsulate output construction logic in a function
    • refactor to use function
    • add tests validating output type consistency and parity with streamer.put()
  • break out tests into more isolated test cases (rather than monolithic permutation) for better reporting
  • Did you update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
    • Send the CW documentation team a CR
    • Docs pages updated
    • docstrings
    • typehints
    • explicit input args for _prepare_output() (maybe it's fine as is?)
  • fix import
  • rebase onto main
  • fix sigterm on hydra-node
  • revisit ECHO behavior? (see the sketch after this list)
    • on our end, just skip over streamed outputs when output.scores is None
    • should be sufficient on HF's end if we just document that the ECHO response will have null scores/logits/activations
  • reimplement the existing TextIteratorStreamer using the new OutputIteratorStreamer
  • Fix support for assisted decoding
    • I propose making this out of scope for this PR and addressing it in a future PR, similar to how streaming currently doesn't support beam decoding.
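
For the ECHO item above, a minimal sketch of the consumer-side handling we have in mind (the scores attribute follows the usual generate output-dict convention; handle() is a hypothetical downstream consumer):

```python
def consume_stream(streamer, handle):
    """Skip echoed prompt steps, which carry no predictions, and forward the rest."""
    for step in streamer:
        if step.scores is None:  # ECHO steps arrive with null scores/logits/activations
            continue
        handle(step)
```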

Who can review?

@gante

dmarx mentioned this pull request Mar 8, 2024

dmarx (Author) commented Mar 8, 2024

NB: I think I've uncovered an issue in the contrastive decoding implementation. Currently, scores and (raw) logits are the same values, but they should not be. I think this is because the scores are warped by logic encapsulated in the _ranking_fast function. I currently have this PR set up so that the streaming outputs match the baseline, but I achieved that by providing the same values to the streamed logits and scores attributes.
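
To make the expected distinction concrete, here is a toy sketch (not the contrastive-search code itself) of how processed scores normally diverge from raw logits once any warping is applied:

```python
import torch
from transformers import LogitsProcessorList, TemperatureLogitsWarper

raw_logits = torch.tensor([[2.0, 1.0, 0.5]])               # what output_logits should report
warper = LogitsProcessorList([TemperatureLogitsWarper(0.7)])
scores = warper(torch.tensor([[0]]), raw_logits.clone())   # what output_scores should report

print(raw_logits)  # tensor([[2.0000, 1.0000, 0.5000]])
print(scores)      # tensor([[2.8571, 1.4286, 0.7143]]) -- warped, so not equal to the raw logits
```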

gante (Member) commented Mar 12, 2024

Hi @dmarx 👋 Thank you for opening this PR!

We're not accepting complex updates to streaming at the moment, as we have an ongoing redesign that is not far from landing (and it touches other non-streaming goals for generate).

In a nutshell, we are going to add the option to yield stuff from generate :)

dmarx (Author) commented Mar 12, 2024

Nice! Looking forward to the change; fingers crossed that users will be able to yield output dicts with scores and logits in addition to token_ids (i.e. the motivation for this PR) :)

Any chance you could estimate an ETA for the yield from generate change?

dmarx (Author) commented Mar 13, 2024

Actually, better yet: @gante, could you maybe point me to the issue or PR I should follow to keep tabs on the "yield from generate" updates?

gante (Member) commented Mar 14, 2024

@dmarx Here's the roadmap:

  1. Add support for torch.compile (under way; I estimate 1-2 months of additional work). This ensures we have a generate structure that can be optimized, as well as a test suite to prevent regressions.
  2. Generate blockwise refactor. Currently generate is a monolith, which prevents adding new modalities or optionally changing return into yield without massive if/else blocks in many places.
  3. Streaming 2.0. Now that we have optimization features and a lego-style generate function, enable a codepath with yield (see the sketch below).

No trackers for 2. and 3. yet. I'd estimate 3 months super optimistically, 6 months if a flurry of new generation techniques and model modifications come out in the near future :)
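
Purely as an illustration of the idea in step 3 (not the planned API), a greedy decoding loop that yields per-step outputs instead of returning everything at the end might look roughly like this:

```python
import torch

def generate_stream(model, input_ids, max_new_tokens=20):
    """Illustrative only: greedy decoding that yields each step's outputs."""
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(input_ids).logits[:, -1, :]         # raw logits for the last position
            next_token = logits.argmax(dim=-1, keepdim=True)   # greedy pick
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            yield {"token_ids": next_token, "logits": logits}  # consumed step by step by the caller
```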

dmarx (Author) commented Mar 14, 2024

@gante Refactoring messy research code is sorta my jam (just ask poli). Let me know if I could help accelerate some of this roadmap; I'd be happy to help with (2) if you can elaborate on your design plan a bit more.

Also, let me know if it would be helpful to re-articulate this PR's motivations as an issue for your user story tracking. I can just leave this PR open if you prefer.

gante (Member) commented Mar 14, 2024

@dmarx extra hands would indeed be handy (pun intended)!

I'm going to chat with @zucchini-nlp soon (tomorrow?) so we can gather a set of requirements and then publicly share a concrete plan for the factorization of generate. I'm sure some of the tasks can be done in parallel with the torch.compile goal without causing (big) conflicts. In fact, the refactor is already in motion -- this PR was written with the refactor in mind 😉

If I don't reply within a week, please don't hesitate to ping me -- I do want to take you up on your offer!

dmarx (Author) commented Mar 29, 2024

@gante @zucchini-nlp bump re the GenerationMixin.generate refactoring roadmap and design plans

gante (Member) commented Apr 18, 2024

@dmarx Not forgotten; we are finishing up gathering requirements internally across the different teams/repos :)

dmarx (Author) commented May 9, 2024

@gante just checking in.

gante (Member) commented May 16, 2024

@dmarx #30810 :D

dmarx (Author) commented May 20, 2024

I had kept this open for communication purposes; closing in favor of #30810.
