Update and shorten docstring
Rocketknight1 committed Apr 11, 2024
1 parent 19df6a8 commit 8b52039
Showing 1 changed file with 46 additions and 63 deletions.
109 changes: 46 additions & 63 deletions src/transformers/generation/stopping_criteria.py
@@ -151,6 +151,7 @@ class StopStringCriteria(StoppingCriteria):
- ["st", "opera"]
- ["sto", "pper"]
- ["las", "topper"]
- ["s", "to", "pped"]
Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other
words, these sequences will not trigger a match:
@@ -181,75 +182,57 @@ class StopStringCriteria(StoppingCriteria):
- ["st", "opera"] (overlap is "op", overlap length == 2)
- ["sto", "pper"] (overlap is "p", overlap length == 1)
- ["las", "topper"] (overlap is "top", overlap length == 3)
- ["s", "to", "pped"] (overlap is "p", overlap length == 1)
It's impossible to construct a matching sequence that does not have this property (feel free to verify this
yourself). However, although this overlap between the start of the final token and the end of the stop string is
necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is
consistent with the stop string.
How do we do that? Let's say the stop string is N characters long, and the initial overlap covers the final
M characters. Then, we have N - M characters left to match. If the next token is shorter than N - M characters,
then the entire token must match: token == stop_string[-M - len(token): -M]. If the next token is longer than
N - M characters, then we consider only the final N - M characters of the token. This allows the token to have
an overhang off the start of the stop string.
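As a plain-Python sketch of this rule (the function name is ours, purely for illustration):

```python
def token_fits(token: str, stop_string: str, matched: int) -> bool:
    # Check one preceding token against the not-yet-matched part of the stop
    # string. `matched` is M above, the number of characters already matched
    # at the end of the stop string; it is assumed to be >= 1.
    remaining = len(stop_string) - matched  # N - M characters still to match
    if len(token) < remaining:
        # Short token: the whole token must match inside the stop string.
        return token == stop_string[-matched - len(token): -matched]
    # Long token: only its final N - M characters must match; anything before
    # them overhangs off the start of the stop string.
    return token[-remaining:] == stop_string[:remaining]
```

The worked example that follows exercises exactly this rule.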
Again, let's make this concrete with a worked example. We'll use the stop string "stop" and the token sequence
["las", "topper"]. The length of the stop string is 4. The final token is "topper", and its overlap with the stop
string is "top", which has length 3. We continue to the next token, "las", and we have 4 - 3 = 1 character left to
match. This is less than the length of "las", so we only need a partial match for this token to complete the string.
We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of
the stop string, and we are done.
At this point, hopefully you agree that we have an algorithm that detects the presence of a stop string, but you
may not see how we can convert this to tensor operations, particularly since we want to avoid data-dependent
conditional branching in the compiled code, and ideally vectorize everything so it can be efficiently computed on
GPU. The key is to realize that although we don't have access to string operations inside the generation loop,
we can use them in a precomputation stage!
For every token in the tokenizer vocabulary, we precompute the values
we need for the above algorithm: the length(s) of that token's overlap with the end of the stop string, the
position(s) in the stop string where that token matches perfectly, and the length of the token. We then pack
these values into a single vector per token, and stack those vectors into an embedding tensor that we can
gather from at runtime to get the values we need.
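As a sketch of that packing step (the function name and the -1 padding scheme are our illustrative assumptions, not the real `_stop_string_create_embedding_vec`):

```python
import torch

def pack_embedding(token_infos, max_overlaps, max_positions):
    # token_infos holds one (end_overlaps, valid_positions, length) triple per
    # vocabulary id; -1 marks unused slots so that a runtime gather can tell
    # real values from padding.
    rows = []
    for end_overlaps, valid_positions, length in token_infos:
        rows.append(end_overlaps + [-1] * (max_overlaps - len(end_overlaps))
                    + valid_positions + [-1] * (max_positions - len(valid_positions))
                    + [length])
    return torch.tensor(rows)  # (vocab_size, max_overlaps + max_positions + 1)

# At runtime, embedding[input_ids] is a single gather that fetches every value
# the algorithm needs for the tokens in the sequence.
```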
This is the approach we take in this class. The precomputation is done in the `_stop_string_create_embedding_vec`
function. Then, at runtime in the `__call__()` method, we implement the algorithm above in purely vectorized
fashion, starting from an input_ids vector containing the token IDs in the sequence:
- Gather from the embedding vector using input_ids as indices, and split the packed vectors into end overlap lengths,
valid token positions, and token lengths.
- Make a vector of the length of every token in the sequence, except for the final token, where we use the
end-overlap length instead.
- Compute the cumulative sum of this length vector, starting from the end of the sequence. For each token, this
represents the number of stop string characters that would be matched from that token onwards, assuming every
token is a valid fit for the sequence at that point.
- To determine if the tokens are valid at each position, we check that the cumulative length so far matches
one of the values in their valid positions vector. Where it does not, we mask that token and all tokens
preceding it.
- We then check the highest unmasked value in the cumulative sum. This represents the length of the total string
match before we reached a token that did not match the stop string. If it is equal to or greater than the length
of the stop string, the stop string is matched.
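Sketched as code, these steps look roughly like this. This is a minimal PyTorch version for a single sequence and a single end-overlap hypothesis; the function name, argument layout, and -1 padding are our illustrative assumptions, and the real `__call__` additionally batches over sequences, stop strings, and overlap lengths:

```python
import torch

def stop_string_matched(lengths, end_overlap, valid_positions, stop_len):
    # lengths:         (seq,) character length of each token, oldest first
    # end_overlap:     int, hypothesised overlap of the final token with the
    #                  end of the stop string
    # valid_positions: (seq, max_pos) valid positions per token, -1-padded
    # stop_len:        int, character length of the stop string
    lengths = lengths.clone()
    lengths[-1] = end_overlap  # the final token only contributes its overlap
    # Cumulative characters matched, counted from the end of the sequence.
    cumsum = lengths.flip(0).cumsum(0).flip(0)
    # A token is consistent if the number of characters matched by the tokens
    # *after* it equals one of its precomputed valid positions.
    matched_after = cumsum - lengths
    valid = (valid_positions == matched_after.unsqueeze(-1)).any(dim=-1)
    valid[-1] = end_overlap > 0  # the final token starts the match
    # Mask every token behind an invalid one: only an unbroken run of valid
    # tokens at the end of the sequence can form a match.
    unbroken = valid.flip(0).long().cumprod(0).flip(0)
    # Match if the unbroken run covers at least the whole stop string.
    return bool((cumsum * unbroken).max() >= stop_len)
```

For the stop string "stop" and the sequence ["las", "topper"] with end overlap 3, the adjusted lengths are [3, 3], the reversed cumulative sum is [6, 3], and "las" is valid at position 3, so the maximum unmasked value is 6 >= 4 and the function returns True.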
This is almost the complete algorithm, and the remaining details just handle edge cases. For example, what do
we do if a token can have multiple possible overlap lengths with the stop string? Consider the stop string
"banana", and the token sequences ["ba", "nana"] and ["bana", "nana"]. Both of these sequences
contain the stop string and should trigger a match. However, the overlap of the final token is different! In
the first case, the overlap is "nana". In the second case, the overlap is "na". When we start from the end
of the sequence and work backwards, we cannot know in advance which overlap length, if any, will lead to a valid
match, and therefore we must test all possible overlap lengths.
Therefore, for the stop string "banana" and the token "nana", we store two valid end overlap lengths: 2 and 4.
We then perform the above algorithm, starting from each value, and test whether each results in a match.
Thanks to vectorization, we can run these tests in parallel (in fact, we can run the test for every possible
overlap length and all stop strings in parallel).
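A tiny illustrative helper (not part of the transformers API) makes this concrete:

```python
def end_overlap_lengths(token: str, stop_string: str) -> list:
    # All lengths by which a suffix of the stop string equals a prefix of the
    # token; each one is a separate match hypothesis that must be tested.
    return [i for i in range(1, min(len(token), len(stop_string)) + 1)
            if stop_string[-i:] == token[:i]]

print(end_overlap_lengths("nana", "banana"))  # [2, 4]
print(end_overlap_lengths("topper", "stop"))  # [3]
```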
The second detail is how we handle cases where the token sequence has an overhang off the start of the stop string,
as in the case of ["las", "top"], since we do not store "start overlaps" in the same way we do for end overlaps.
Instead, we simply store (in the valid_positions vector) that the token "las" is valid before "top", in the same
way that the token "s" is. As a result, the total matched length the algorithm computes for ["las", "top"] is 6
rather than 4, because it does not truncate the match to the length of the stop string. However, since the
algorithm concludes by checking that the maximum match length is equal to or greater than the length of the stop
string, this does not affect the correctness of its final answer: both ["las", "top"] with a total length of 6,
and ["s", "top"] with a total length of 4, are correctly identified as matches, because both are >= 4.
How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an
overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already
matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to"
matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the
remainder to match is "s". We step back once more; the previous token is "s", which matches the remainder
exactly, and so we have matched the entire stop string.
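Here is the whole backward scan as a plain-Python reference that still uses string operations (an illustrative sketch of the logic only; as described below, the real implementation avoids string operations at runtime entirely):

```python
def contains_stop_string(tokens: list, stop_string: str) -> bool:
    final = tokens[-1]
    # Try every way the end of the stop string can overlap the final token.
    for overlap in range(min(len(final), len(stop_string)), 0, -1):
        if stop_string[-overlap:] != final[:overlap]:
            continue
        matched = overlap
        for token in reversed(tokens[:-1]):
            if matched >= len(stop_string):
                break  # the whole stop string is already matched
            remaining = len(stop_string) - matched
            if len(token) < remaining:
                # The whole token must match inside the stop string.
                if token != stop_string[-matched - len(token): -matched]:
                    break
            # Otherwise only the token's final characters must match; the rest
            # may overhang off the start of the stop string.
            elif token[-remaining:] != stop_string[:remaining]:
                break
            matched += len(token)
        if matched >= len(stop_string):
            return True
    return False

print(contains_stop_string(["s", "to", "pped"], "stop"))  # True
print(contains_stop_string(["las", "topper"], "stop"))    # True
```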
How does it work when the tokens run off the start of the stop string, though? Let's consider the example of
["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore,
the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to
match is just "s", with length 1, we consider only the final character of the token, which is "s". This
matches the remainder, and so the entire stop string is matched.
How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary
information for all tokens! For every token, we compute:
- Its overlap with the end of the stop string, if any
- The positions inside the stop string where the token matches, including matches that run off the start
- The total length of the token
For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions,
and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position
of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap,
a single internal matching position of 3 (again counting from the end), and a length of 1.
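A small sketch of this precomputation in plain Python (illustrative names; the real code packs these values into tensors as described earlier):

```python
def token_info(token: str, stop_string: str):
    # Overlap lengths between the end of the stop string and the token's start.
    end_overlaps = [i for i in range(1, min(len(token), len(stop_string)) + 1)
                    if stop_string[-i:] == token[:i]]
    # Internal matching positions, counting characters from the end of the
    # stop string; tokens may run off the start of the stop string.
    valid_positions = []
    for pos in range(1, len(stop_string)):
        remaining = len(stop_string) - pos
        if len(token) < remaining:
            ok = token == stop_string[-pos - len(token): -pos]
        else:
            ok = token[-remaining:] == stop_string[:remaining]
        if ok:
            valid_positions.append(pos)
    return end_overlaps, valid_positions, len(token)

for tok in ["pped", "to", "s"]:
    print(tok, token_info(tok, "stop"))
# pped ([1], [], 4)
# to   ([], [1], 2)
# s    ([], [3], 1)
```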
As long as we have this information, we can execute the algorithm above without any string comparison
operations. We simply perform the following steps:
- Check if the final token has an end overlap with the stop string
- Continue backwards, keeping track of how much of the stop string we've matched so far
- At each point, check if the next token has the current position as one of its valid positions
- Continue until either a match fails, or we completely match the whole stop string
Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match.
We have matched 1 character so far, so we check that the next token, "to", has 1 as a valid position (again,
counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched
3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length
to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the
entire stop string.
In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have
matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we
allow tokens to match positions that run off the start of the stop string. We add its length to the position
tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though:
any value equal to or greater than the stop string's length counts as a match, so we have matched the entire stop
string.
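Finally, here is that scan written against the precomputed tables only, with no string operations. For brevity this sequential sketch assumes a single end-overlap length per token and hypothetical dict-based tables keyed by made-up token ids; the real implementation runs the equivalent logic as vectorized tensor operations, testing every overlap length and every stop string in parallel:

```python
def scan(token_ids, stop_len, end_overlap, valid_positions, lengths):
    pos = end_overlap.get(token_ids[-1], 0)  # characters matched so far
    if pos == 0:
        return False  # the final token cannot end the stop string
    for tid in reversed(token_ids[:-1]):
        if pos >= stop_len:
            break  # the whole stop string is already matched
        if pos not in valid_positions.get(tid, ()):
            return False  # this token is inconsistent with the stop string
        pos += lengths[tid]
    return pos >= stop_len

# Tables for the stop string "stop" and the toy vocabulary
# {0: "s", 1: "to", 2: "pped"} (the ids are invented for the example):
end_overlap = {2: 1}                # "pped" overlaps the end by 1 character
valid_positions = {0: {3}, 1: {1}}  # "s" fits at position 3, "to" at 1
lengths = {0: 1, 1: 2, 2: 4}
print(scan([0, 1, 2], 4, end_overlap, valid_positions, lengths))  # True
```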
Args:
tokenizer (`PreTrainedTokenizer`):
