diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 6653f3c8d123e9..a669d6ed0659cf 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -86,6 +86,7 @@ "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", + "StopStringCriteria", ] _import_structure["utils"] = [ "GenerationMixin", @@ -224,6 +225,7 @@ MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, + StopStringCriteria, validate_stopping_criteria, ) from .utils import ( diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index f40960c213ea67..4504b8b183941b 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin): max_time(`float`, *optional*): The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed. + stop_strings(`str or List[str]`, *optional*): + A string or a list of strings that should terminate generation if the model outputs them. > Parameters that control the generation strategy used @@ -306,6 +308,7 @@ def __init__(self, **kwargs): self.min_new_tokens = kwargs.pop("min_new_tokens", None) self.early_stopping = kwargs.pop("early_stopping", False) self.max_time = kwargs.pop("max_time", None) + self.stop_strings = kwargs.pop("stop_strings", None) # Parameters that control the generation strategy used self.do_sample = kwargs.pop("do_sample", False) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index bac537b71b96ec..5a42f474be2692 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -1,15 +1,22 @@ import time import warnings from abc import ABC +from collections import OrderedDict from copy import deepcopy -from typing import List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union +import numpy as np import torch +from torch.nn import functional as F +from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import add_start_docstrings, logging logger = logging.get_logger(__name__) +# We maintain a module-level cache of the embedding vectors for the stop string criterion +# because they are slow to compute +STOP_STRING_EMBEDDING_CACHE = OrderedDict() STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" @@ -129,6 +136,334 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) +class StopStringCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever specific string sequences are generated. It preprocesses + the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings. + + Generation is stopped as soon as a token is generated that completes any of the stop strings. + We want to catch any instance in which the stop string would be present in the decoded output, which means + we must also catch cases with "overhangs" off one or both ends. 
To make this more concrete, for the stop string + "stop", any of the following token sequences would trigger the match: + + - ["st", "op"] + - ["stop"] + - ["st", "opera"] + - ["sto", "pper"] + - ["las", "topper"] + - ["s", "to", "pped"] + + Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other + words, these sequences will not trigger a match: + + - ["stop", "at"] + - ["st", "op", "at"] + - ["st", "opera", "tion"] + + The reason these are not a match is that the stop string does not overlap with the final token. If you can remove + one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not + match that stop string. This is by design; because this check is run after each token is generated, we can't miss a + valid stop string if one is generated, but we don't want to halt generation just because the stop string exists + somewhere in the past input_ids. + + How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match + process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible, + with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use + with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations. + + The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at + the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of + the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for + some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this + property: + + - ["st", "op"] (overlap is "op", overlap length == 2) + - ["stop"] (overlap is "stop", overlap length == 4) + - ["st", "opera"] (overlap is "op", overlap length == 2) + - ["sto", "pper"] (overlap is "p", overlap length == 1) + - ["las", "topper"] (overlap is "top", overlap length == 3) + - ["s", "to", "pped"] (overlap is "p", overlap length == 1) + + It's impossible to construct a matching sequence that does not have this property (feel free to verify this + yourself). However, although this overlap between the start of the final token and the end of the stop string is + necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is + consistent with the stop string. + + How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an + overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already + matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to" + matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the + remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so + we have matched the entire stop string. + + How does it work when the tokens run off the start of the stop string, though? Let's consider the example of + ["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore, + the remaining stop string to match is "s". 
We go back to the previous token, "las". Because the remainder to + match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This + matches the stop string, and so the entire string is matched. + + How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary + information for all tokens! For every token, we compute: + - Its overlap with the end of the stop string, if any + - The positions inside the stop string where the token matches, including matches that run off the start. + - The total length of the token + + For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions, + and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position + of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap, + a single internal matching position of 3 (again counting from the end) and a length of 1. + + As long as we have this information, we can execute the algorithm above without any string comparison + operations. We simply perform the following steps: + - Check if the final token has an end-overlap with the start string + - Continue backwards, keeping track of how much of the stop string we've matched so far + - At each point, check if the next token has the current position as one of its valid positions + - Continue until either a match fails, or we completely match the whole stop string + + Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match. + We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again, + counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched + 3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length + to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the + entire stop string. + + In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have + matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we + allow tokens to match positions that run off the start of the stop string. We add its length to the position + tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though - + this also counts as a match of the stop string. We have matched the entire stop string. + + + Args: + tokenizer (`PreTrainedTokenizer`): + The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) + stop_strings (`Union[str, List[str]]`): + A list of strings that should end generation. If a string is passed, it will be treated like a + list with a single element. 
+ + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") + >>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt") + + >>> gen_out = model.generate(**inputs) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + The biggest states in the USA by land area: + - Alaska + - Texas + - California + + >>> # Passing one or more stop strings will halt generation after those strings are emitted + >>> # Note that generating with stop strings requires you to pass the tokenizer too + >>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + The biggest states in the USA by land area: + - Alaska + - Texas + ``` + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): + if isinstance(stop_strings, str): + stop_strings = [stop_strings] + self.stop_strings: Tuple[str, ...] = tuple(stop_strings) + vocab = tokenizer.get_vocab() + token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values()) + self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache( + token_list, token_indices, self.stop_strings, tokenizer + ) + + self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) + self.num_stop_strings = len(self.stop_strings) + self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) + + def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer): + # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality + if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE: + embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[ + (token_list, token_indices, self.stop_strings) + ] + STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings)) + else: + clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer) + embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec( + clean_token_list, clean_token_indices, stop_strings + ) + STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = ( + embedding_vec, + max_valid_positions, + max_valid_end_lens, + ) + if len(STOP_STRING_EMBEDDING_CACHE) > 8: + STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item + return embedding_vec, max_valid_positions, max_valid_end_lens + + @staticmethod + def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): + """ + This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string + it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method + tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix + space addition/removal. To work around this, we add a static prefix to the start of the token, then remove + it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). 
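+        For example, with a byte-level BPE tokenizer such as GPT-2's, a vocab entry like "Ġstop" is cleaned to
+        " stop", keeping the leading-space information without the "Ġ" marker itself.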
+ """ + vocab = tokenizer.get_vocab() + clean_token_list = [] + clean_token_indices = [] + sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"] + tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base] + for token, token_idx in vocab.items(): + token_string = tokenizer.convert_tokens_to_string(tokens_base + [token]) + token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :] + clean_token_list.append(token_string) + clean_token_indices.append(token_idx) + return tuple(clean_token_list), tuple(clean_token_indices) + + @staticmethod + def _stop_string_get_matching_positions( + token_list, token_indices, stop_strings + ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: + """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can + validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the + token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters + from the end of the stop string that overlap with the start of the token, which can have more than one value. + + The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full + explanation of what these values are for!""" + + token_valid_positions = {} + token_end_overlaps = {} + for stop_string in stop_strings: + reversed_stop_string = stop_string[::-1] + token_valid_positions[stop_string] = {} + token_end_overlaps[stop_string] = {} + for token, tok_idx in zip(token_list, token_indices): + reversed_token = token[::-1] + matching_positions = [] + possible_end_lengths = [] + for i in range(1 - len(token), len(stop_string)): + if i < 0: + tok = reversed_token[-i:] + i = 0 + else: + tok = reversed_token + stop = reversed_stop_string[i : i + len(tok)] + if tok.startswith(stop): + if i == 0: + possible_end_lengths.append(min(len(tok), len(stop))) + else: + matching_positions.append(i) + + if matching_positions: + token_valid_positions[stop_string][tok_idx] = matching_positions + if possible_end_lengths: + token_end_overlaps[stop_string][tok_idx] = possible_end_lengths + return token_valid_positions, token_end_overlaps + + @staticmethod + def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]: + """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs + them into an embedding tensor that can be accessed with pure tensor operations. 
For the specifics of the values + that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" + token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions( + token_list, token_indices, stop_strings + ) + + max_valid_positions = max( + len(val) for positions in token_valid_positions.values() for val in positions.values() + ) + max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) + vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 + gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) + + for i, stop_string in enumerate(stop_strings): + positions = token_valid_positions[stop_string] + end_lens = token_end_overlaps[stop_string] + + # Since this is lots of very small assignments of lists, we build it with numpy rather + # than torch for speed + simplicity, then convert to torch at the end + for token_idx, valid_positions in positions.items(): + gather_vec[ + token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions) + ] = valid_positions + for token_idx, possible_end_lens in end_lens.items(): + gather_vec[ + token_idx, + max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions + * len(stop_strings) + + max_valid_end_lens * i + + len(possible_end_lens), + ] = possible_end_lens + for token, token_idx in zip(token_list, token_indices): + gather_vec[token_idx, -1] = len(token) + + gather_vec = torch.tensor(gather_vec, dtype=torch.int32) + + return gather_vec, max_valid_positions, max_valid_end_lens + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: + self.embedding_vec = self.embedding_vec.to(input_ids.device) + self.target_lens = self.target_lens.to(input_ids.device) + # The maximum length we need to consider is 1 token per character. Note that input_ids can also be + # *shorter* than the global max, and the code below should be ready for that + input_ids = input_ids[:, -self.maximum_token_len :] + + # Flip input_ids because we're only matching strings at the end of the generated sequence + flipped_ids = torch.flip(input_ids, (1,)) + + # Size of the vector of positions a single token can match + max_valid_positions = self.max_valid_positions + + # The embedding vec contains the valid positions, end_lengths and total lengths for each token + embedded = F.embedding(flipped_ids, self.embedding_vec) + + # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit + valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten( + -1, (self.num_stop_strings, -1) + ) + # end_lengths is the number of characters from the string, counting from the end, that the token + # contains. It can have multiple values if the same token can overlap different end lengths + end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten( + -1, (self.num_stop_strings, -1) + ) + # Lengths is the total length of each token. 
Unlike the others, it always has a single value + lengths = embedded[:, 1:, None, -1:] # Insert a dummy dimension for stop_strings even though lengths are const + + # Concatenate lengths onto each possible end_lengths value + lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1])) + lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) + + # cumsum() to get the number of matched characters in the stop string after each token + cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x num_stop_strings x max_valid_end_lens + + # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not. + # First, tokens match the start of the string if they have a positive value in the end_lengths vector + initial_match = end_lengths > 0 + + # Tokens continue the string if the cumsum() so far is one of the valid positions for that token + # Note that we're actually tracking one cumsum() for for each possible end_length + later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2) + + # The match vector is a boolean vector that indicates which positions have valid tokens + match = torch.cat([initial_match, later_match], dim=1) + + # Once a single position does not match, all positions following that position are masked + mask = (~match).cumsum(dim=1, dtype=torch.int32) + mask = mask == 0 + + # The string is matched if we reached a cumsum equal to or greater than the length of the string + # before hitting the mask + string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] + + # We return a per-sample vector that is True if any stop string is matched for that sample + return torch.any(string_matches, dim=-1) + + class EosTokenCriteria(StoppingCriteria): """ This class can be used to stop generation whenever the "end-of-sequence" token is generated. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 36e62794a435a1..2a046196d85046 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -80,12 +80,14 @@ MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, + StopStringCriteria, validate_stopping_criteria, ) if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel + from ..tokenization_utils_base import PreTrainedTokenizerBase from .streamers import BaseStreamer logger = logging.get_logger(__name__) @@ -881,7 +883,11 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + self, + generation_config: GenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + **kwargs, ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -894,6 +900,14 @@ def _get_stopping_criteria( ) if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if generation_config.stop_strings is not None: + if tokenizer is None: + raise ValueError( + "There are one or more stop strings, either in the arguments to `generate` or in the " + "model's generation config, but we could not locate a tokenizer. When generating with " + "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." 
+ ) + criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) if generation_config.eos_token_id is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) @@ -1376,6 +1390,7 @@ def generate( """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) @@ -1385,6 +1400,7 @@ def generate( synced_gpus = True else: synced_gpus = False + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -1527,7 +1543,7 @@ def generate( # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ) # 10. go into different generation modes if generation_mode == GenerationMode.ASSISTED_GENERATION: diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 0c770972a7fdff..1a22491b9aa0f6 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -16,7 +16,7 @@ import time import unittest -from transformers import is_torch_available +from transformers import AutoTokenizer, is_torch_available from transformers.testing_utils import require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -31,6 +31,7 @@ MaxNewTokensCriteria, MaxTimeCriteria, StoppingCriteriaList, + StopStringCriteria, validate_stopping_criteria, ) @@ -124,3 +125,134 @@ def test_validate_stopping_criteria(self): stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11) self.assertEqual(len(stopping_criteria), 1) + + def test_stop_string_criteria(self): + true_strings = [ + "<|im_start|><|im_end|>", + "<|im_start|><|im_end|<|im_end|>", + ">><|im_start|>>stop", + "stop", + "e nd", + ] + false_strings = [ + "<|im_start|><|im_end|", + "<|im_start|><|im_end|<|im_end|", + "<|im_end|><|im_start|>", + "<|im_end|<>stop<|im_end|", + "end", + "en d", + "eNd", + "<|im_end|", + "|im_end|>", + "s", + ] + stop_strings = ["<|im_end|>", "stop", "e nd"] + + # Use a tokenizer that won't actually have special tokens for these + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "left" + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + + scores = None + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + for i in range(len(true_strings)): + self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) + for i in range(len(false_strings)): + self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) + + # Now try it with a tokenizer where those are actually special tokens + tokenizer = 
AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") + tokenizer.padding_side = "left" + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + for i in range(len(true_strings)): + self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) + for i in range(len(false_strings)): + self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) + + def test_stop_string_matching_positions(self): + stop_string = "stop" + token_list = ["last", "top", "topper", "s", "p"] + token_indices = list(range(len(token_list))) + all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions( + token_list=token_list, token_indices=token_indices, stop_strings=[stop_string] + ) + valid_positions = { + token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items() + } + end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()} + self.assertEqual(valid_positions, {"s": [3], "last": [2]}) + self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]}) + + def test_stop_string_embedding_vecs(self): + stop_string = "stop" + token_list = ["last", "top", "topper", "s", "p"] + token_indices = list(range(len(token_list))) + embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec( + token_list=token_list, token_indices=token_indices, stop_strings=[stop_string] + ) + + # Positions inside the stop string where the token matches (excluding end overlaps) + valid_positions = embedding_vec[:, 0].tolist() + self.assertEqual(valid_positions, [2, -1, -1, 3, -1]) + + # Overlap lengths between end of stop string and start of token + end_overlaps = embedding_vec[:, 1].tolist() + self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1]) + + # Length of each token + token_lengths = embedding_vec[:, 2].tolist() + self.assertEqual(token_lengths, [len(token) for token in token_list]) + + def test_criterias_per_row(self): + text = "They completed the challenging puzzle, revealing the hidden image at the end" + stop_strings = ["end"] + + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False) + + scores = None + criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=20), + StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings), + ] + ) + + # trigger stopping when at leat one criteria is satisfied, one value per batch + self.assertTrue(criteria(inputs["input_ids"], scores)) + + # return False when neither is satisfied + self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores)) + + def test_criterias_per_row_batched(self): + text = [ + "They completed the challenging puzzle, revealing the hidden image at the end", + "Today a dragon flew over France", + "The aroma of freshly baked pizza filled the kitchen", + ] + stop_strings = ["end"] + + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "left" + inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False) + + scores = None + 
criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=20), + StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings), + ] + ) + + # trigger stopping when at leat one criteria is satisfied + self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False]) + + # False when neither is satisfied + self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False]) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d6b4840c4910c9..37b9d3dba5ec0b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2355,6 +2355,43 @@ def test_constrained_beam_search_example_integration(self): self.assertListEqual(outputs, ["Wie alt sind Sie?"]) + @slow + def test_per_row_stopping_criteria(self): + text = [ + "They completed the challenging puzzle, revealing the hidden", + "Today a dragon flew over France", + "The aroma of freshly baked pizza filled the kitchen", + ] + stop_strings = ["secrets"] + + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.padding_side = "left" + tokenizer.pad_token_id = tokenizer.eos_token_id + input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to( + torch_device + ) + + # normal generation with one stopping criteria + out = model.generate(input_ids, max_length=15) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets of the world.\n", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) + + # generation should stop at "secrets" for first batch only, filling the rest with eos tokens + out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) + def test_constrained_beam_search_mixin_type_checks(self): # PT-only test: TF doesn't have constrained beam search tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
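
For readers following the long docstring above, the backwards matching it describes can also be sketched with
ordinary string operations. The snippet below is purely illustrative (the function name is invented and nothing
like it appears in the diff; the real criterion uses the precomputed embedding tensor so that the check remains
compilable):

```python
from typing import List


def naive_stop_string_match(tokens: List[str], stop_string: str) -> bool:
    """Pure-Python sketch of the backwards match described in StopStringCriteria's docstring."""
    final = tokens[-1]
    # Try every possible overlap between the start of the final token and the end of the stop string
    for overlap in range(min(len(final), len(stop_string)), 0, -1):
        if not final.startswith(stop_string[-overlap:]):
            continue
        matched = overlap
        ok = True
        # Walk backwards through the earlier tokens until the stop string is fully covered
        for token in reversed(tokens[:-1]):
            if matched >= len(stop_string):
                break
            remainder = stop_string[: len(stop_string) - matched]
            # Only the trailing characters of the token have to match; characters running off
            # the start of the stop string are allowed, as in the ["las", "topper"] example
            chunk = token[-len(remainder):] if len(token) >= len(remainder) else token
            if not remainder.endswith(chunk):
                ok = False
                break
            matched += len(token)
        if ok and matched >= len(stop_string):
            return True
    return False


# A few of the cases worked through in the docstring
assert naive_stop_string_match(["s", "to", "pped"], "stop")
assert naive_stop_string_match(["las", "topper"], "stop")
assert naive_stop_string_match(["st", "opera"], "stop")
assert not naive_stop_string_match(["st", "op", "at"], "stop")
```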