Cleanup code and rename variables for clarity
Rocketknight1 committed Feb 19, 2024
1 parent 9b8a10b commit 385abb9
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions in src/transformers/generation/stopping_criteria.py
@@ -237,42 +237,50 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
flipped_ids = torch.flip(input_ids, (1,))
string_matches = []
for stop_string in self.stop_strings:
# We need the length of the stop string to know how many characters our token sequence should have
target_len = len(stop_string)
# Maximum number of internal positions a single token can match

# Size of the vector of positions a single token can match
max_valid_positions = self.max_valid_positions[stop_string]
# Maximum number of different overlap sizes a single token can have with the end of the string

# Size of the vector of overlap sizes a single token can have with the end of the string
max_valid_end_lens = self.max_valid_end_lens[stop_string]

# The embedding vec contains the valid positions, end_lengths and total lengths for each token
embedding_vec = self.embedding_vecs[stop_string]
embedded = F.embedding(flipped_ids, embedding_vec)

# Starts contains the number of characters from the string, counting from the end, that the token contains
# It can have multiple values if the same token can overlap different slices of the end of the string
# B x 1 x max_valid_end_lens
starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens]
# Lengths is the total length of each token. Unlike starts, it always has a single value
# end_lengths is the number of characters from the string, counting from the end, that the token
# contains. It can have multiple values if the same token can overlap different end lengths
end_lengths = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens]

# Lengths is the total length of each token. Unlike end_lengths, it always has a single value
lengths = embedded[:, 1:, -1:] # B x (maximum_token_len - 1) x 1
lengths = lengths.expand((-1, -1, starts.shape[-1])) # B x (maximum_token_len - 1) x max_valid_end_lens
lengths_with_starts = torch.cat([starts, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens
# We concatenate each possible starting length with the lengths of the remaining tokens in input_ids
# Then we cumsum() to get the total length of the string after each token
cumsum = lengths_with_starts.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens

# Valid positions are the positions in the string that the token can validly appear after
# B x (maximum_token_len - 1) x max_valid_positions

# Concatenate lengths onto each possible end_lengths value
lengths = lengths.expand((-1, -1, end_lengths.shape[-1]))
lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens

# cumsum() to get the number of matched characters in the stop string after each token
cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens

# The code above assumes that all tokens are in valid positions. This code masks the ones that are not.
# First we get the vector of positions tokens can validly appear in
valid_positions = embedded[:, 1:, :max_valid_positions]
# Tokens can match the start of the string if they have any valid value in the starts vector
initial_match = starts > 0 # B x 1 x max_valid_end_lens

# Tokens can match the start of the string if they have any valid value in the end_lengths vector
initial_match = end_lengths > 0

# Tokens can continue the string if the cumsum() so far is one of the valid positions for that token
# Note that we're actually tracking one cumsum() for the list for each possible start overhang length
# B x (maximum_token_len - 1) x max_valid_end_lens
# Note that we're actually tracking one cumsum() for each possible end_length
later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2)

# The match vector is a boolean vector that indicates which positions have valid tokens
match = torch.cat([initial_match, later_match], dim=1)

# Once a single position does not match, all positions following that position are masked
mask = (~match).cumsum(dim=1, dtype=torch.int32)
mask = mask == 0 # B x maximum_token_len x max_valid_end_lens
mask = mask == 0

# The string is matched if we reached a cumsum equal to or greater than the length of the string
# before hitting the mask
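
For readers skimming the hunk, here is a minimal, unvectorized sketch of the idea the comments describe (not part of the commit): walk the tokens backwards, keep a running count of covered stop-string characters (the role of the cumsum above), and check that each earlier token lands at a valid position. The helper name and toy token strings below are invented for illustration; the real code works on token IDs via the precomputed embedding vectors, tracks several candidate end overlaps at once, and runs batched.

from typing import List

def ends_with_stop_string(token_strings: List[str], stop_string: str) -> bool:
    """Scalar sketch: does the concatenation of token_strings end with stop_string?"""
    if not stop_string:
        return True
    matched_chars = 0  # stop-string characters covered so far, counting from its end (the "cumsum")
    for token in reversed(token_strings):
        remaining = len(stop_string) - matched_chars  # characters of the stop string still uncovered
        if len(token) >= remaining:
            # This token reaches (or overhangs) the start of the stop string:
            # its last `remaining` characters must equal the uncovered prefix.
            return token[-remaining:] == stop_string[:remaining]
        # Otherwise the whole token must land at a "valid position", i.e. match the
        # slice of the stop string sitting just before the characters already covered.
        if stop_string[remaining - len(token) : remaining] != token:
            return False
        matched_chars += len(token)
    # Ran out of tokens before covering the whole stop string.
    return False

# Toy usage: a stop string split across several tokens is still detected.
print(ends_with_stop_string(["Sure", "!", " obe", "y"], "obey"))      # True
print(ends_with_stop_string(["Sure", "!", " disobe", "y"], "obey"))   # True ("disobey" ends with "obey")
print(ends_with_stop_string(["Sure", "!", " ob", "aye"], "obey"))     # False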
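
As additional context (also not part of this diff), a hedged sketch of how a stop-string criterion of this kind is wired into generation. The StopStringCriteria name and constructor signature below are assumptions based on the feature as it was later released, and the model choice is arbitrary.

from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList
from transformers.generation.stopping_criteria import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stop once the generated text hits the stop string "###" (signature assumed, see note above)
criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings=["###"])])

inputs = tokenizer("List three fruits, then write ###:\n", return_tensors="pt")
outputs = model.generate(**inputs, stopping_criteria=criteria, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))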
