
Terminator strings for generate() #28932

Merged: 68 commits merged into main on Apr 22, 2024
Conversation

Rocketknight1 (Member) commented on Feb 8, 2024:

generate() stops when it encounters eos_token_id, but there are various circumstances where we want it to stop for other tokens too. The ideal solution would be to allow a set of strings that halt generation, and to include this information with the model, so model authors can set custom terminators like <|im_end|> as halting strings, even when those strings don't correspond to a single special token.

The problem with stopping on specific strings rather than tokens is that a string can be tokenized in many different ways, and the tokens that contain it may also have overhangs on either end: for example, the stop string <|im_end|> could be produced by the token sequence ["?><", "|", "im_", "end", "|", ">>"]. Since we have to check after each token the model generates, we want to avoid detokenization and string comparisons, as these would cause a lot of slowdown and prevent us from compiling the generation loop.

This PR adds a StoppingCriteria for stop strings. It takes some time to preprocess the stop strings and the tokenizer vocabulary together, building an embedding matrix containing the information it needs about which tokens can construct each stop string, but once that's done the entire generation-time check can be performed with only tensor operations on static, known shapes.
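As a rough usage sketch of the intended behaviour (the class name StopStringCriteria and its tokenizer=/stop_strings= arguments are taken from this PR's discussion; treat them as assumptions rather than a settled API):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from transformers.generation import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The capital of France is", return_tensors="pt")

# Halt generation as soon as any of these strings appears in the output,
# however the tokenizer happens to split it across tokens.
criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "\n\n"])]
)
outputs = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=40)
print(tokenizer.decode(outputs[0]))
```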

fixes #28801


zucchini-nlp (Member) commented:

Hey @Rocketknight1! I like this feature, it's a very useful one I think. Just a couple of questions, since I'm not sure what the intended behavior was originally.


# Now we concatenate the match booleans across all strings and check if any are True
string_matches = torch.cat(string_matches, dim=0)
return torch.any(string_matches).item()
zucchini-nlp (Member) commented on this code:

Just being curious: does generation stop for the whole batch when at least one row contains a stop string?

Rocketknight1 (Member, Author) replied:

Yes, this is correct! I think this is the desired behaviour for all stopping conditions (cc @gante to confirm)

gante (Member) replied:

@zucchini-nlp raised a good point about when to trigger the end of generation -- up until now, the stopping criteria behaved identically for all rows in the batch. This is the first one that can be True for some rows and False for others.

It should behave like the EOS token, where we continue generating until all rows reach some condition to stop, and finished rows keep appending the pad token until all rows are done. Now, here's the catch: this mechanism [per-row tracking of finished sequences] exists for the EOS token, but doesn't exist for the StoppingCriteria 😬

For now, let's keep it as @Rocketknight1 added it (most users use batch size = 1 anyway). As follow-up actions, let's: a) move the EOS logic to a StoppingCriteria; b) ensure all StoppingCriteria return a boolean array containing True in the rows that trigger the condition (sketched below) :D
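For reference, a minimal sketch of what "a boolean array containing True in the rows that trigger the condition" could look like for a criterion (the class and its EOS-based condition here are illustrative, not this PR's code):

```python
import torch
from transformers import StoppingCriteria

class PerRowStopCriteria(StoppingCriteria):
    """Illustrative only: returns one flag per batch row instead of a single bool."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # True for every row whose most recent token is EOS; the idea is that
        # generate() would keep padding finished rows until every row is True.
        return input_ids[:, -1] == self.eos_token_id
```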

Rocketknight1 (Member, Author) commented:

This should be ready for review now @gante @amyeroberts! The core code is a mass of totally incomprehensible tensor operations - don't stress if you can't follow them, because I wrote them in one caffeine-fuelled afternoon and I also forget what they're doing if I look away for more than 20 minutes. We're largely trusting in the tests here.

The main problem I encountered is that I don't have a clean way to get the string form of the tokenizer's vocabulary - I'm handling the two common cases of replacing Ġ and ▁ with spaces, but there doesn't seem to be any universal method to get the actual string each token will yield. This will probably work for most tokenizers, though, and most stop strings don't contain spaces anyway.
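A minimal sketch of the kind of cleanup being described, treating Ġ (byte-level BPE) and ▁ (SentencePiece) as word-boundary spaces; this is an approximation rather than a universal mapping from tokens to strings:

```python
from transformers import AutoTokenizer

def token_to_clean_string(token: str) -> str:
    # Approximate cleanup only: byte-level BPE tokenizers mark a leading space
    # with "Ġ" and SentencePiece tokenizers with "▁". Neither replacement is
    # guaranteed to be exact for every tokenizer, which is the limitation above.
    return token.replace("Ġ", " ").replace("▁", " ")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab_strings = {tok: token_to_clean_string(tok) for tok in tokenizer.get_vocab()}
print(repr(vocab_strings["Ġthe"]))  # ' the'
```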

amyeroberts (Collaborator) commented:

Nice! I'll let @gante review first to confirm it's all aligned with the current logits processors.

Just skimming, my main comment is that we need tests for the criterion's methods, in particular get_matching_positions.

gante (Member) left a review:

Very cool PR 🔥

Extra request: let's add stop_strings to the GenerationConfig and, if it is set at generation time, add this stopping criterion in _get_stopping_criteria (pro tip: we can instantiate the tokenizer from the model repo attribute on the model instance, so we don't need to pass the tokenizer to generate :D). That way, users can:

  1. do model.generate(..., stop_strings=["foo", "bar"]), as opposed to model.generate(..., stopping_criteria=StoppingCriteriaList(...)), which is more user-friendly (see the sketch after this list)
  2. store their model's stop_strings in the generation config 💛
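What this interface could look like from the user's side, as a hedged sketch (the exact kwargs, in particular passing tokenizer= to generate(), are assumptions about how the PR lands rather than a confirmed API):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# 1. Pass stop strings straight to generate(); the tokenizer is needed so the
#    criterion can map the strings back onto the model's vocabulary.
out = model.generate(
    **inputs, stop_strings=["foo", "bar"], tokenizer=tokenizer, max_new_tokens=20
)

# 2. Or store them on the generation config, so downstream users get the
#    behaviour without extra arguments (beyond the tokenizer itself).
model.generation_config.stop_strings = ["<|im_end|>"]
out = model.generate(**inputs, tokenizer=tokenizer, max_new_tokens=20)
print(tokenizer.decode(out[0]))
```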



stop_strings (`Union[str, List[str]]`):
A list of strings that should end generation. If a string is passed, it will be treated like a
list with a single element.
"""
gante (Member) commented on this docstring:

After we're done with most of the changes, can we add an example of how to use this StoppingCriteria with generate? 🙏 (Similar to the examples we have in the logits processor classes)

Rocketknight1 (Member, Author) replied:

I'm not sure which examples you mean! Can you link me to the code?

gante (Member) replied on Feb 15, 2024.

Rocketknight1 (Member, Author) replied:

Ah, got it! Let me see if I can incorporate this into generate + the generation_config first, and then write those examples.

Rocketknight1 (Member, Author) replied on Mar 22, 2024:

Done! (With apologies for the delay - I got pulled away to work on other things.)

The example now lives here. I tested it and it works well! You might have to expand stopping_criteria.py to see it.

amyeroberts (Collaborator) replied:

Nice example! Thank you for writing up this detailed docstring ❤️

Rocketknight1 (Member, Author) commented:

@amyeroberts those are purely internal methods - maybe I should just mark them as private with a leading _ instead?

amyeroberts (Collaborator) replied:

@Rocketknight1 The request for tests is to verify the logic, rather than about the methods being public or private. test_stop_string_criteria is good, but the logic in get_matching_positions is quite complex. I'd like for this to be properly covered so that:

  1. we can be certain this method is doing the right thing - not just all the pieces as a whole
  2. we can modify it safely if needed.

Rocketknight1 (Member, Author) replied:

@amyeroberts tests for the sub-methods are in!

Rocketknight1 (Member, Author) commented:

Quick update here: I refactored the initialization and added a small cache in case users repeatedly call generate(). Initialization time is down from 200ms to ~10ms in local testing. I also rewrote the core call so that all stop strings are tested in parallel - this gives a big speedup when we're testing more than one stop string.

We may still end up going with @amyeroberts' tokenizer-decode solution after profiling, but I wanted to make sure this method didn't have any obvious performance issues first ✊. I'll do some testing on Monday!

(also I still have a deep phobia of graph breaks from my XLA era)
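The caching mentioned above could look roughly like the following (an illustrative sketch, not the PR's actual code: the point is that freezing the vocabulary and stop strings into hashable tuples lets functools.lru_cache skip the expensive preprocessing on repeated generate() calls):

```python
from functools import lru_cache
from transformers import AutoTokenizer

@lru_cache(maxsize=8)
def _preprocess_stop_strings(vocab_items: tuple, stop_strings: tuple) -> dict:
    # Stand-in for the expensive preprocessing step: for each vocabulary token,
    # record whether it could form the end of any stop string. Because both
    # arguments are hashable tuples, repeated calls with the same tokenizer and
    # stop strings hit the cache instead of recomputing.
    return {
        token: any(stop.endswith(token) or token.endswith(stop) for stop in stop_strings)
        for token, _ in vocab_items
    }

tokenizer = AutoTokenizer.from_pretrained("gpt2")
vocab_items = tuple(sorted(tokenizer.get_vocab().items()))
info = _preprocess_stop_strings(vocab_items, ("<|im_end|>",))
```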

Rocketknight1 (Member, Author) commented:

😌
[profiling screenshot]

amyeroberts (Collaborator) replied:

@Rocketknight1 Awesome! Am I right in saying the profiling shows this is faster now?

You know the next question I'm going to ask: does it correctly handle tokenizers with different splitting behaviour, e.g. prepending with ##? :)

Rocketknight1 (Member, Author) commented on Feb 26, 2024:

I'm working on that and highly confident¹ that a solution can be found.

¹ Not at all confident
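For context on the ## case: WordPiece-style vocabularies (e.g. BERT) invert the convention that the Ġ/▁ cleanup above assumes, marking word continuations rather than word starts. One hypothetical way to normalise them:

```python
def wordpiece_token_to_string(token: str) -> str:
    # Hypothetical handling for WordPiece vocabularies: continuations carry a
    # leading "##" and word-start tokens carry no marker, so strip the "##"
    # from continuations and prepend a space to word-start tokens.
    if token.startswith("##"):
        return token[2:]
    return " " + token

print(wordpiece_token_to_string("##ing"))  # 'ing'
print(wordpiece_token_to_string("play"))   # ' play'
```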

Rocketknight1 (Member, Author) commented:

Quick update - tokenizer.convert_tokens_to_string() is actually a general method that does what we want, but it has the problem that it strips prefix spaces in tokenizer classes where prepending a prefix space is the default, and I can't figure out how to turn that off! I think we can still use it, but I might need to always prepend a fixed token before calling it, or something like that, to defeat that behaviour.
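The workaround being floated might look something like this (a hedged sketch: convert_tokens_to_string() is a real tokenizer method, but the sentinel trick and the example repo name are assumptions, not code from this PR):

```python
from transformers import AutoTokenizer, PreTrainedTokenizerBase

def token_to_string(tokenizer: PreTrainedTokenizerBase, token: str) -> str:
    # Hypothetical workaround: convert_tokens_to_string() drops the prefix space
    # for tokenizer classes that prepend one by default, so prepend a fixed
    # sentinel token and slice its text back off to recover any leading space.
    sentinel = "a"
    text = tokenizer.convert_tokens_to_string([sentinel, token])
    return text[len(sentinel):]

# Example with a SentencePiece-style tokenizer (repo name is illustrative):
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
print(repr(token_to_string(tokenizer, "▁Hello")))  # expected: ' Hello'
```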

Rocketknight1 (Member, Author) commented:

@zucchini-nlp The condition now correctly returns a per-sample vector. Can I be lazy and ask you to add the test for it that @amyeroberts requested in #29116 here? If you're too busy, don't worry, I'll get to it!

zucchini-nlp (Member) replied:

@Rocketknight1 Yeah, I could add it, but that PR is already merged. I hope @gante can clarify where the test should go.

Rocketknight1 (Member, Author) replied:

@zucchini-nlp you can just add the test to this PR's branch instead!

Rocketknight1 (Member, Author) commented on Feb 27, 2024:

cc @amyeroberts @gante this PR now tests the per-row stopping conditions from #29116, thanks to @zucchini-nlp. Tests are passing, so the feature looks good! I ran the slow tests locally as well.

amyeroberts (Collaborator) left a review:

Thanks for all the work adding and iterating on this!

Comment on lines +187 to +189
end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()}
self.assertEqual(valid_positions, {"s": [3], "last": [2]})
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})
amyeroberts (Collaborator) commented on this code:

Great ❤️

self.assertEqual(valid_positions, {"s": [3], "last": [2]})
self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]})

def test_stop_string_embedding_vecs(self):
amyeroberts (Collaborator) commented on this code:

❤️

Rocketknight1 merged commit 0d84901 into main on Apr 22, 2024 with 22 checks passed, and deleted the terminator_strings_for_generate branch on Apr 22, 2024 at 13:13.
Rocketknight1 (Member, Author) replied:

No, thank you for all the patience fixing my horrifically verbose docstrings and incomprehensible tests, lol

zafstojano pushed a commit to zafstojano/transformers that referenced this pull request Apr 22, 2024
* stash commit (will discard all of this)

* stash commit

* First commit - needs a lot of testing!

* Add a test

* Fix imports and make the tests actually test something

* Tests pass!

* Rearrange test

* Add comments (but it's still a bit confusing)

* Stop storing the tokenizer

* Comment fixup

* Fix for input_ids with a single sequence

* Update tests to test single sequences

* make fixup

* Fix incorrect use of isin()

* Expand tests to catch more cases

* Expand tests to catch more cases

* make fixup

* Fix length calculation and update tests

* Handle Ġ as a space replacement too

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* Add optimizations from Joao's suggestion

* Remove TODO

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* make fixup

* Rename some variables and remove some debugging clauses for clarity

* Add tests for the sub-methods

* Clarify one test slightly

* Add stop_strings to GenerationConfig

* generate() supports stop_string arg, asks for tokenizer if not provided

* make fixup

* Cleanup code and rename variables for clarity

* Update tokenizer error

* Update tokenizer passing, handle generation on GPU

* Slightly more explanation cleanup

* More comment cleanup

* Factor out the token cleanup so it's more obvious what we're doing, and we can change it later

* Careful with that cleanup!

* Cleanup + optimizations to _get_matching_positions

* More minor performance tweaks

* Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms)

* Remove the pin_memory call

* Parallelize across all stop strings!

* Quick fix for tensor devices

* Update embeddings test for the new format

* Fix test imports

* Manual patching for BERT-like tokenizers

* Return a bool vector instead of a single True/False

* Better comment

* Better comment

* Add tests from @zucchini-nlp

* Amy's list creation nit

* tok_list -> token_list

* Push a big expanded docstring (should we put it somewhere else?)

* Expand docstrings

* Docstring fixups

* Rebase

* make fixup

* Make a properly general method for figuring out token strings

* Fix naming throughout the functions

* Move cache, refactor, fix tests

* Add comment

* Remove finished TODO

* Remove finished TODO

* make fixup

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: amyeroberts <[email protected]>

* Update and shorten docstring

* Update tests to be shorter/clearer and test specific cases

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
itazap pushed a commit that referenced this pull request May 14, 2024

Successfully merging this pull request may close these issues.

Conversational Pipeline returns <|im_end|> in the assistant's output.
5 participants