Fix length calculation and update tests
Rocketknight1 committed Feb 13, 2024
1 parent ecbe599 commit 6a92a31
Showing 2 changed files with 26 additions and 11 deletions.
19 changes: 12 additions & 7 deletions src/transformers/generation/stopping_criteria.py
@@ -250,40 +250,45 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
string_matches = []
for stop_string in self.stop_strings:
target_len = len(stop_string)
# Maximum number of internal positions a single token can match
max_valid_positions = self.max_valid_positions[stop_string]
# Maximum number of different overlap sizes a single token can have with the end of the string
max_valid_end_lens = self.max_valid_end_lens[stop_string]
# The embedding vec contains the valid positions, end_lengths and total lengths for each token
embedding_vec = self.embedding_vecs[stop_string]
embedded = F.embedding(flipped_ids, embedding_vec)

# Starts contains the number of characters from the string, counting from the end, that the token contains
# It can have multiple values if the same token can overlap different slices of the end of the string
# B x 1 x max_valid_end_lens
starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens]
# Lengths is the total length of each token. Unlike starts, it always has a single value
lengths = embedded[:, 1:, -1:]
lengths = lengths.expand((-1, -1, starts.shape[-1]))
lengths_with_starts = torch.cat([starts, lengths], dim=1)
lengths = embedded[:, 1:, -1:] # B x (maximum_token_len - 1) x 1
lengths = lengths.expand((-1, -1, starts.shape[-1])) # B x (maximum_token_len - 1) x max_valid_end_lens
lengths_with_starts = torch.cat([starts, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens
# We concatenate each possible starting length with the lengths of the remaining tokens in input_ids
# Then we cumsum() to get the total length of the string after each token
cumsum = lengths_with_starts.cumsum(dim=1)
cumsum = lengths_with_starts.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens

# Valid positions are the positions in the string that the token can validly appear after
# B x (maximum_token_len - 1) x max_valid_positions
valid_positions = embedded[:, 1:, :max_valid_positions]
# Tokens can match the start of the string if they have any valid value in the starts vector
initial_match = torch.any(starts > 0, dim=-1, keepdim=True)
initial_match = starts > 0 # B x 1 x max_valid_end_lens
# Tokens can continue the string if the cumsum() so far is one of the valid positions for that token
# Note that we're actually tracking one cumsum() for the list for each possible start overhang length
# B x (maximum_token_len - 1) x max_valid_end_lens
later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2)
# The match vector is a boolean vector that indicates which positions have valid tokens
match = torch.cat([initial_match, later_match], dim=1)

# Once a single position does not match, all positions following that position are masked
mask = (~match).cumsum(dim=1, dtype=torch.int32)
mask = mask == 0
mask = mask == 0 # B x maximum_token_len x max_valid_end_lens

# The string is matched if we reached a cumsum equal to or greater than the length of the string
# before hitting the mask
string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze(1) >= target_len)
string_matches.append(torch.amax(cumsum * mask, dim=(1, 2)) >= target_len)

# Now we concatenate the match booleans across all strings and check if any are True
string_matches = torch.cat(string_matches, dim=0)
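Not part of the commit itself: the toy walk-through below uses made-up token lengths and positions (batch size 1, a single hypothetical stop string "stop") to sketch the starts/lengths/cumsum/mask pipeline commented above, ending with the corrected torch.amax reduction over both the token and end-length dimensions.

import torch

# Toy numbers only: imagine the stop string "stop" (target_len = 4), where the final
# generated token covers "op" and the token before it covers "st".
target_len = 4

# B x 1 x max_valid_end_lens: characters of the stop string's end covered by the final token.
starts = torch.tensor([[[2]]])

# B x (maximum_token_len - 1) x 1: total character length of each earlier token,
# broadcast across the possible end-overlap lengths.
lengths = torch.tensor([[[2]]]).expand((-1, -1, starts.shape[-1]))

lengths_with_starts = torch.cat([starts, lengths], dim=1)  # B x maximum_token_len x max_valid_end_lens
cumsum = lengths_with_starts.cumsum(dim=1)                 # [[[2], [4]]]

# B x (maximum_token_len - 1) x max_valid_positions: "st" may validly follow 2 already-matched
# characters, counted from the end of the stop string.
valid_positions = torch.tensor([[[2]]])

initial_match = starts > 0
later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2)
match = torch.cat([initial_match, later_match], dim=1)

# Everything after the first mismatching position is masked out.
mask = (~match).cumsum(dim=1, dtype=torch.int32) == 0

# The corrected reduction collapses both the token and the end-length dimensions, giving one
# boolean per batch row; torch.max(...).values.squeeze(1) kept a trailing max_valid_end_lens
# dimension whenever that dimension was larger than 1.
print(torch.amax(cumsum * mask, dim=(1, 2)) >= target_len)  # tensor([True])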
18 changes: 14 additions & 4 deletions tests/generation/test_stopping_criteria.py
@@ -110,14 +110,24 @@ def test_validate_stopping_criteria(self):
self.assertEqual(len(stopping_criteria), 1)

def test_stop_string_criteria(self):
true_strings = ["<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", ">><|im_start|>>stop", "stop"]
true_strings = [
"<|im_start|><|im_end|>",
"<|im_start|><|im_end|<|im_end|>",
">><|im_start|>>stop",
"stop",
"end",
]
false_strings = [
"<|im_start|><|im_end|",
"<|im_start|><|im_end|<|im_end|",
"<|im_end|><|im_start|>",
"<|im_end|<>stop<|im_end|",
]
too_short_strings = ["<|im_end|", "|im_end|>", "s", "end"]
too_short_strings = [
"<|im_end|",
"|im_end|>",
"s",
]

# Use a tokenizer that won't actually have special tokens for these
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
@@ -129,7 +139,7 @@ def test_stop_string_criteria(self):
too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False
)
scores = None
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"])
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])
for i in range(len(true_strings)):
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
for i in range(len(false_strings)):
@@ -145,7 +155,7 @@ def test_stop_string_criteria(self):
too_short_input_ids = tokenizer(
too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False
)
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"])
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])
for i in range(len(true_strings)):
self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
for i in range(len(false_strings)):
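Not part of the commit: a minimal standalone usage sketch of the updated criteria, reusing one string the test expects to match and one it expects not to match; it assumes the openai-community/gpt2 checkpoint can be loaded.

from transformers import AutoTokenizer
from transformers.generation.stopping_criteria import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])

# One string from the test's true_strings list and one from false_strings.
ends_with_stop_string = tokenizer("<|im_start|><|im_end|>", return_tensors="pt", add_special_tokens=False)
does_not_end_with_it = tokenizer("<|im_end|><|im_start|>", return_tensors="pt", add_special_tokens=False)

# As in the test above, the scores argument is unused by this criterion and may be None.
print(criteria(ends_with_stop_string["input_ids"], None))  # expected truthy
print(criteria(does_not_end_with_it["input_ids"], None))   # expected falsy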
