Skip to content

Commit

Permalink
Terminator strings for generate() (#28932)
Browse files Browse the repository at this point in the history
* stash commit (will discard all of this)

* stash commit

* First commit - needs a lot of testing!

* Add a test

* Fix imports and make the tests actually test something

* Tests pass!

* Rearrange test

* Add comments (but it's still a bit confusing)

* Stop storing the tokenizer

* Comment fixup

* Fix for input_ids with a single sequence

* Update tests to test single sequences

* make fixup

* Fix incorrect use of isin()

* Expand tests to catch more cases

* Expand tests to catch more cases

* make fixup

* Fix length calculation and update tests

* Handle Ġ as a space replacement too

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* Add optimizations from Joao's suggestion

* Remove TODO

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* Update tests/generation/test_stopping_criteria.py

Co-authored-by: Joao Gante <[email protected]>

* make fixup

* Rename some variables and remove some debugging clauses for clarity

* Add tests for the sub-methods

* Clarify one test slightly

* Add stop_strings to GenerationConfig

* generate() supports stop_string arg, asks for tokenizer if not provided

* make fixup

* Cleanup code and rename variables for clarity

* Update tokenizer error

* Update tokenizer passing, handle generation on GPU

* Slightly more explanation cleanup

* More comment cleanup

* Factor out the token cleanup so it's more obvious what we're doing, and we can change it later

* Careful with that cleanup!

* Cleanup + optimizations to _get_matching_positions

* More minor performance tweaks

* Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms)

* Remove the pin_memory call

* Parallelize across all stop strings!

* Quick fix for tensor devices

* Update embeddings test for the new format

* Fix test imports

* Manual patching for BERT-like tokenizers

* Return a bool vector instead of a single True/False

* Better comment

* Better comment

* Add tests from @zucchini-nlp

* Amy's list creation nit

* tok_list -> token_list

* Push a big expanded docstring (should we put it somewhere else?)

* Expand docstrings

* Docstring fixups

* Rebase

* make fixup

* Make a properly general method for figuring out token strings

* Fix naming throughout the functions

* Move cache, refactor, fix tests

* Add comment

* Remove finished TODO

* Remove finished TODO

* make fixup

* Update src/transformers/generation/stopping_criteria.py

Co-authored-by: amyeroberts <[email protected]>

* Update and shorten docstring

* Update tests to be shorter/clearer and test specific cases

---------

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
3 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 4005748 commit 55a6ad3
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/transformers/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"StoppingCriteria",
"StoppingCriteriaList",
"validate_stopping_criteria",
"StopStringCriteria",
]
_import_structure["utils"] = [
"GenerationMixin",
Expand Down Expand Up @@ -224,6 +225,7 @@
MaxTimeCriteria,
StoppingCriteria,
StoppingCriteriaList,
StopStringCriteria,
validate_stopping_criteria,
)
from .utils import (
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin):
max_time(`float`, *optional*):
The maximum amount of time you allow the computation to run for in seconds. generation will still finish
the current pass after allocated time has been passed.
stop_strings(`str or List[str]`, *optional*):
A string or a list of strings that should terminate generation if the model outputs them.
> Parameters that control the generation strategy used
Expand Down Expand Up @@ -306,6 +308,7 @@ def __init__(self, **kwargs):
self.min_new_tokens = kwargs.pop("min_new_tokens", None)
self.early_stopping = kwargs.pop("early_stopping", False)
self.max_time = kwargs.pop("max_time", None)
self.stop_strings = kwargs.pop("stop_strings", None)

# Parameters that control the generation strategy used
self.do_sample = kwargs.pop("do_sample", False)
Expand Down
Loading

0 comments on commit 55a6ad3

Please sign in to comment.