Terminator strings for generate() #28932
Conversation
(force-pushed from 716c0ba to 40e4abe)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@Rocketknight1, hey! I liked the feature, a very useful one I think. Just a couple of questions, since I'm not sure what the intended behavior was initially.
```python
# Now we concatenate the match booleans across all strings and check if any are True
string_matches = torch.cat(string_matches, dim=0)
return torch.any(string_matches).item()
```
Just being curious: so generation stops for the whole batch when at least one sequence has a stop string?
Yes, this is correct! I think this is the desired behaviour for all stopping conditions (cc @gante to confirm)
@zucchini-nlp raised a good point about when to trigger the end of generation -- up until now, the stopping criteria behaved equally for all rows in the batch. This is the first one that can be `True` for some rows and `False` for others.

It should behave like the EOS token, where we continue generating until all rows reach some condition to stop. Finished rows keep adding the pad token until all rows are done. Now, here's the catch: this mechanism [per-row tracking of finished sequences] exists for the EOS token, but doesn't exist for the stopping criteria 😬

For now, let's keep it as @Rocketknight1 added it (most users use batch size = 1 anyway). As a follow-up action let's: a) move the EOS logic to a `StoppingCriteria`; b) ensure all `StoppingCriteria` return a boolean array containing `True` in the rows that trigger the condition :D
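(For illustration, a per-row criterion of the kind described might look like the minimal sketch below. This is not the PR's code, just the shape of the interface: one boolean per batch row, with a deliberately dummy stop condition.)

```python
import torch
from transformers import StoppingCriteria

class PerRowCriteria(StoppingCriteria):
    """Sketch: report done-ness per batch row instead of one global bool."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # One boolean per row: True where that sequence should stop.
        # Dummy condition for illustration: stop rows whose last token id is even.
        return input_ids[:, -1] % 2 == 0
```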
(force-pushed from c26d419 to 6a92a31)
(force-pushed from a0a2c23 to 254fa2d)
This should be ready for review now @gante @amyeroberts! The core code is totally incomprehensible tensor operations - don't stress if you can't follow them, because I wrote them in one caffeine-fuelled afternoon and I also forget what they're doing if I look away for more than 20 minutes. We're kind of trusting in the tests. The main problem I encountered is that I don't have a clean way to get the tokenizer's vocabulary - I'm handling the two common cases of replacing spaces with `Ġ` or `▁`.
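(For context on the vocabulary handling mentioned above, a minimal sketch of that cleanup step; the helper name is made up, but `Ġ` and `▁` are the standard GPT-2-BPE and SentencePiece space markers referenced in the commit history:)

```python
def clean_token_strings(vocab: dict[str, int]) -> dict[str, int]:
    """Hypothetical helper: map token strings to the text they actually produce.

    GPT-2-style BPE marks a space with 'Ġ' and SentencePiece with '▁',
    so both are replaced with a plain space before any string matching.
    """
    return {token.replace("Ġ", " ").replace("▁", " "): idx for token, idx in vocab.items()}
```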
Nice! I'll let @gante review first to confirm it's all aligned with the current logits processors. Just skimming, my main comment is that we need tests for the criterion's methods.
Very cool PR 🔥

Extra request: let's add `stop_strings` to the `GenerationConfig` and, if it is set at generation time, add this stopping criteria in `_get_stopping_criteria` (pro tip: we can instantiate the tokenizer from the model repo attribute in the model instance, so we don't need to pass the tokenizer to `generate` :D). That way, users can:
- do `model.generate(..., stop_strings=["foo", "bar"])`, as opposed to `model.generate(..., stopping_criteria=StoppingCriteriaList(...))`, which is more user-friendly
- store their model's `stop_strings` in the generation config 💛
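(For a sense of what this request amounts to, a hedged usage sketch; the model name is just a placeholder, and whether the tokenizer must be passed to `generate` explicitly was still being decided at this point:)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown", return_tensors="pt")
# Generation halts as soon as any stop string appears in the generated text
out = model.generate(**inputs, max_new_tokens=20, stop_strings=["fox", "dog"], tokenizer=tokenizer)
print(tokenizer.decode(out[0]))
```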
```python
stop_strings (`Union[str, List[str]]`):
    A list of strings that should end generation. If a string is passed, it will be treated like a
    list with a single element.
"""
```
After we're done with most of the changes, can we add an example of how to use this `StoppingCriteria` with `generate`? 🙏 (Similar to the examples we have in the logits processor classes.)
I wasn't sure which examples you mean! Can you link me to the code?
Ah, got it! Let me see if I can incorporate this into `generate` + the `generation_config` first, and then write those.
Done! (With apologies for the delay when I got pulled away to work on other stuff.) The example now lives here. I tested it and it works well! You might have to expand `stopping_criteria.py` to see it.
Nice example! Thank you for writing up this detailed docstring ❤️
@amyeroberts those are purely internal methods - maybe I should just mark them as private with a leading `_`?
@Rocketknight1 The request for tests is to verify the logic, rather than because the methods are public or private.
@amyeroberts tests for the sub-methods are in!
(force-pushed from c4b90fa to 455259e)
(force-pushed from cb74b51 to ba8c7d1)
Quick update here: I refactored the initialization and added a small cache in case users repeatedly call `generate()` with the same stop strings. We may still end up going with @amyeroberts' tokenizer-decode solution after profiling, but I wanted to make sure this method didn't have any obvious performance issues first ✊. I'll do some testing on Monday! (Also, I still have a deep phobia of graph breaks from my XLA era.)
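(A sketch of the caching idea, under the assumption that the preprocessing depends only on the vocabulary and the stop strings; the function name and cache key are illustrative, not the PR's actual code:)

```python
from functools import lru_cache

@lru_cache(maxsize=8)
def preprocess_stop_strings(vocab_items: tuple, stop_strings: tuple):
    # Stand-in for the real preprocessing, which builds the embedding
    # matrices describing how each token can overlap each stop string.
    # Hashable tuple arguments make the whole computation memoizable,
    # so repeated generate() calls skip the ~200ms startup cost.
    ...

# e.g. preprocess_stop_strings(tuple(sorted(vocab.items())), tuple(stop_strings))
```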
@Rocketknight1 Awesome! Am I right in saying the profiling shows this is faster now? You know the next question I'm going to ask is whether it correctly handles tokenizers with different splitting behaviour, e.g. prepending with `▁`.
I'm working on that and highly confident¹ that a solution can be found.

¹ Not at all confident
(force-pushed from 3dd5dec to 25ef298)
(force-pushed from 25ef298 to 8c23e39)
@zucchini-nlp The condition now returns a per-sample vector correctly. Can I be lazy and ask you to add the test for it that @amyeroberts was requesting in #29116 here? If you're too busy, don't worry, I'll get to it!
@Rocketknight1 yeah, I could add it, but that PR is already merged. I hope @gante can clarify where the test should go.
@zucchini-nlp you can just add the test to this PR's branch instead!
cc @amyeroberts @gante this PR now tests per-row stopping conditions from #29116, thanks to @zucchini-nlp. Tests are passing, so the feature looks good! I ran the slow tests locally as well.
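(A minimal sketch of what such a per-row test checks, written against the `StopStringCriteria` API as merged; the strings and expected output are illustrative:)

```python
from transformers import AutoTokenizer, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["stop"])

# Row 0 ends in the stop string, row 1 does not
inputs = tokenizer(["this must stop", "keep going on"], return_tensors="pt", padding=True)
is_done = criteria(inputs.input_ids, scores=None)
print(is_done)  # expected: tensor([True, False]) - one boolean per batch row
```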
Co-authored-by: amyeroberts <[email protected]>
(force-pushed from 568e73f to 8b52039)
Thanks for all the work adding and iterating on this!
```python
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
```
Great ❤️
```python
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})

def test_stop_string_embedding_vecs(self):
```
❤️
No, thank you for all the patience fixing my horrifically verbose docstrings and incomprehensible tests, lol
* stash commit (will discard all of this)
* stash commit
* First commit - needs a lot of testing!
* Add a test
* Fix imports and make the tests actually test something
* Tests pass!
* Rearrange test
* Add comments (but it's still a bit confusing)
* Stop storing the tokenizer
* Comment fixup
* Fix for input_ids with a single sequence
* Update tests to test single sequences
* make fixup
* Fix incorrect use of isin()
* Expand tests to catch more cases
* Expand tests to catch more cases
* make fixup
* Fix length calculation and update tests
* Handle Ġ as a space replacement too
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Joao Gante <[email protected]>)
* Add optimizations from Joao's suggestion
* Remove TODO
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: Joao Gante <[email protected]>)
* Update tests/generation/test_stopping_criteria.py (Co-authored-by: Joao Gante <[email protected]>)
* make fixup
* Rename some variables and remove some debugging clauses for clarity
* Add tests for the sub-methods
* Clarify one test slightly
* Add stop_strings to GenerationConfig
* generate() supports stop_string arg, asks for tokenizer if not provided
* make fixup
* Cleanup code and rename variables for clarity
* Update tokenizer error
* Update tokenizer passing, handle generation on GPU
* Slightly more explanation cleanup
* More comment cleanup
* Factor out the token cleanup so it's more obvious what we're doing, and we can change it later
* Careful with that cleanup!
* Cleanup + optimizations to _get_matching_positions
* More minor performance tweaks
* Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms)
* Remove the pin_memory call
* Parallelize across all stop strings!
* Quick fix for tensor devices
* Update embeddings test for the new format
* Fix test imports
* Manual patching for BERT-like tokenizers
* Return a bool vector instead of a single True/False
* Better comment
* Better comment
* Add tests from @zucchini-nlp
* Amy's list creation nit
* tok_list -> token_list
* Push a big expanded docstring (should we put it somewhere else?)
* Expand docstrings
* Docstring fixups
* Rebase
* make fixup
* Make a properly general method for figuring out token strings
* Fix naming throughout the functions
* Move cache, refactor, fix tests
* Add comment
* Remove finished TODO
* Remove finished TODO
* make fixup
* Update src/transformers/generation/stopping_criteria.py (Co-authored-by: amyeroberts <[email protected]>)
* Update and shorten docstring
* Update tests to be shorter/clearer and test specific cases

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
`generate()` stops when it encounters `eos_token_id`, but there are various circumstances when we want it to stop for other tokens too. The ideal situation would be to allow a set of strings that halt generation, and then include this information with the model, so model authors can set e.g. custom tokens like `<|im_end|>` as halting strings, even when those strings don't have a special token.

The problem with stopping for specific strings rather than tokens is that a string can be tokenized in many different ways, and the tokens that contain a string may also have overhangs on either end: `["?><", "|", "im_", "end", "|", ">>"]`. Since we have to check after each token generated by the model, we want to avoid detokenization and string comparisons, as these would cause a lot of slowdown and prevent us from compiling the generation loop.

This PR adds a `StoppingCriteria` for stop strings. It takes some time to preprocess the stop strings and the tokenizer vocabulary together, and it builds an embedding matrix containing the information it needs about which tokens can construct each stop string, but once that's done the entire generation-time check can be performed with only tensor operations and static, known shapes.

Fixes #28801
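To make the preprocessing concrete, here is a small pure-Python sketch (not the PR's vectorized code) of the two per-token quantities the review tests above assert on, assuming the stop string in those tests is "stop": the offsets at which a token can validly end inside the stop string, and the lengths by which a final token's start can overlap the stop string's end.

```python
def token_valid_positions(token: str, stop_string: str) -> list[int]:
    """Offsets from the stop string's end at which `token` can end, with its
    trailing characters matching the stop string (the token may overhang
    past the stop string's start)."""
    positions = []
    for end_offset in range(1, len(stop_string)):
        covered = len(stop_string) - end_offset  # stop-string chars the token covers
        n = min(covered, len(token))
        if token[-n:] == stop_string[covered - n:covered]:
            positions.append(end_offset)
    return positions

def token_end_overlaps(token: str, stop_string: str) -> list[int]:
    """Overlap lengths L for the final token of a match: the first L chars of
    `token` equal the last L chars of `stop_string` (the rest of the token
    overhangs past the stop string's end)."""
    return [L for L in range(1, min(len(token), len(stop_string)) + 1)
            if token[:L] == stop_string[-L:]]

# Reproduces the values asserted in the tests quoted above:
print({t: token_valid_positions(t, "stop") for t in ["s", "last"]})
# {'s': [3], 'last': [2]}
print({t: token_end_overlaps(t, "stop") for t in ["top", "topper", "p"]})
# {'top': [3], 'topper': [3], 'p': [1]}
```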