From 8b52039193b3a5064ca23c6824732a3de1d4b691 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 Apr 2024 16:39:01 +0100 Subject: [PATCH] Update and shorten docstring --- .../generation/stopping_criteria.py | 109 ++++++++---------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 985b356a1ad11c..5a42f474be2692 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -151,6 +151,7 @@ class StopStringCriteria(StoppingCriteria): - ["st", "opera"] - ["sto", "pper"] - ["las", "topper"] + - ["s", "to", "pped"] Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other words, these sequences will not trigger a match: @@ -181,75 +182,57 @@ class StopStringCriteria(StoppingCriteria): - ["st", "opera"] (overlap is "op", overlap length == 2) - ["sto", "pper"] (overlap is "p", overlap length == 1) - ["las", "topper"] (overlap is "top", overlap length == 3) + - ["s", "to", "pped"] (overlap is "p", overlap length == 1) It's impossible to construct a matching sequence that does not have this property (feel free to verify this yourself). However, although this overlap between the start of the final token and the end of the stop string is necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is consistent with the stop string. - How do we do that? Let's say the stop string is N characters long, and the initial overlap covers the final - M characters. Then, we have N - M characters left to match. If the next token is less than N - M tokens long, then - the entire token must match: token == stop_string[-M - len(token): -M]. If the next token is longer than N - M - tokens, then we consider only the final N - M characters of the token. 
This allows for the token to have an overhang - off the start of the stop string. - - Again, let's make this concrete with a worked example. We'll use the stop string "stop" and the token sequence - ["las", "topper"]. The length of the stop string is 4. The final token is "topper", and its overlap with the stop - string is "top", which has length 3. We continue to the next token, "las", and we have 4 - 3 = 1 character left to - match. This is less than the length of "las", so we only need a partial match for this token to complete the string. - We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of - the stop string, and we are done. - - At this point, hopefully you agree that we have an algorithm that detects the presence of a stop string, but you - may not see how we can convert this to tensor operations, particularly since we want to avoid data-dependent - conditional branching in the compiled code, and ideally vectorize everything so it can be efficiently computed on - GPU. The key is to realize that although we don't have access to string operations inside the generation loop, - we can use them in a precomputation stage! - - For every token in the tokenizer vocabulary, we precompute the values - we need for the above algorithm: The length of that token's overlap with the end of the stop string, the - position(s) in the stop string where that token matches perfectly, and the length of the token. We then pack - these values into a single vector per token, and stack those vectors into an embedding tensor which we can - gather from at runtime to get the values we need. - - This is the approach we take in this class. The precomputation is done in the `_stop_string_create_embedding_vec` - function. 
Then, at runtime in the `__call__()` method, we implement the algorithm above in purely vectorized - fashion, starting from an input_ids vector containing the token IDs in the sequence: - - - Gather from the embedding vector using input_ids as indices, and split the packed vectors into end overlap lengths, - valid token positions, and token lengths. - - Make a vector of the length of every token in the sequence, except for the final token, where we use the - end-overlap length instead. - - Compute the cumulative sum of the sequence, starting from the end. This represents the number of characters in the stop string that - we would match after each token, assuming that token is a valid fit for the sequence at that point. - - To determine if the tokens are valid at each position, we check that the cumulative length so far matches - one of the values in their valid positions vector. Where it does not, we mask that token and all tokens - preceding it. - - We then check the highest unmasked value in the cumulative sum. This represents the length of the total string - match before we reached a token that did not match the stop string. If it is equal to or greater than the length - of the stop string, the stop string is matched. - - This is almost the complete algorithm, and the remaining details just handle edge cases: For example, what do - we do if a token can have multiple possible overlap lengths with the stop string? For example, consider the - stop string "banana", and the token sequences ["ba", "nana"] and ["bana", "nana"]. Both of these sequences - contain the stop string and should trigger a match. However, the overlap of the final token is different! In - the first case, the overlap is "nana". In the second case, the overlap is "na". When we start from the end - of the sequence and work backwards, we cannot know in advance which overlap length, if any, will lead to a valid - match, and therefore we must test all possible overlap lengths. 
- - Therefore, for the stop string "banana" and the token "nana", we store two valid end overlap lengths: 2 and 4. - We then perform the above algorithm, starting from each value, and test whether each results in a match. - Thanks to vectorization, we can run these tests in parallel (in fact, we can run the test for every possible - overlap length and all stop strings in parallel). - - The second detail is how we handle cases when the token sequence has an overhang off the start of the stop string, - as in the case of ["las", "top"], since we do not store "start overlaps" in the same way we do for end overlaps. - Instead, we simply store (in the valid_positions vector) that the token "las" is valid before "top", in the same - way that the token "s" is. Therefore, the total length it computes in the case of ["las", "top"] is 6 rather than 4, - because it doesn't truncate the match to the length of the stop string. However, since the algorithm concludes by - checking that the maximum match length is equal to or greater than the length of the stop string, this does not - affect the correctness of its final answer; both ["las", "top"] with a total length of 6, and ["s", "top"] with a - total length of 4, will be correctly identified as matches, because both are >= 4. + How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an + overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already + matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to" + matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the + remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so + we have matched the entire stop string. + + How does it work when the tokens run off the start of the stop string, though? 
Let's consider the example of + ["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore, + the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to + match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This + matches the stop string, and so the entire string is matched. + + How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary + information for all tokens! For every token, we compute: + - Its overlap with the end of the stop string, if any + - The positions inside the stop string where the token matches, including matches that run off the start. + - The total length of the token + + For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions, + and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position + of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap, + a single internal matching position of 3 (again counting from the end) and a length of 1. + + As long as we have this information, we can execute the algorithm above without any string comparison + operations. We simply perform the following steps: + - Check if the final token has an end-overlap with the stop string + - Continue backwards, keeping track of how much of the stop string we've matched so far + - At each point, check if the next token has the current position as one of its valid positions + - Continue until either a match fails, or we completely match the whole stop string + + Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match. + We have matched 1 character so far, so we check that the next token "to" has 1 as a valid position (again, + counting from the end). 
It does, so we add the length of "to" to our position tracker. We have now matched + 3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length + to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the + entire stop string. + + In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have + matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we + allow tokens to match positions that run off the start of the stop string. We add its length to the position + tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though - + this also counts as a match of the stop string. We have matched the entire stop string. + Args: tokenizer (`PreTrainedTokenizer`):