diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py
index 3fff4421f3977a..56b8466830b52d 100644
--- a/src/transformers/generation/stopping_criteria.py
+++ b/src/transformers/generation/stopping_criteria.py
@@ -250,7 +250,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
         string_matches = []
         for stop_string in self.stop_strings:
             target_len = len(stop_string)
+            # Maximum number of internal positions a single token can match
             max_valid_positions = self.max_valid_positions[stop_string]
+            # Maximum number of different overlap sizes a single token can have with the end of the string
             max_valid_end_lens = self.max_valid_end_lens[stop_string]
             # The embedding vec contains the valid positions, end_lengths and total lengths for each token
             embedding_vec = self.embedding_vecs[stop_string]
@@ -258,32 +260,35 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
             # Starts contains the number of characters from the string, counting from the end, that the token contains
             # It can have multiple values if the same token can overlap different slices of the end of the string
+            # B x 1 x max_valid_end_lens
             starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens]
 
             # Lengths is the total length of each token. Unlike starts, it always has a single value
-            lengths = embedded[:, 1:, -1:]
-            lengths = lengths.expand((-1, -1, starts.shape[-1]))
-            lengths_with_starts = torch.cat([starts, lengths], dim=1)
+            lengths = embedded[:, 1:, -1:]  # B x (maximum_token_len - 1) x 1
+            lengths = lengths.expand((-1, -1, starts.shape[-1]))  # B x (maximum_token_len - 1) x max_valid_end_lens
+            lengths_with_starts = torch.cat([starts, lengths], dim=1)  # B x maximum_token_len x max_valid_end_lens
             # We concatenate each possible starting length with the lengths of the remaining tokens in input_ids
             # Then we cumsum() to get the total length of the string after each token
-            cumsum = lengths_with_starts.cumsum(dim=1)
+            cumsum = lengths_with_starts.cumsum(dim=1)  # B x maximum_token_len x max_valid_end_lens
 
             # Valid positions are the positions in the string that the token can validly appear after
+            # B x (maximum_token_len - 1) x max_valid_positions
             valid_positions = embedded[:, 1:, :max_valid_positions]
             # Tokens can match the start of the string if they have any valid value in the starts vector
-            initial_match = torch.any(starts > 0, dim=-1, keepdim=True)
+            initial_match = starts > 0  # B x 1 x max_valid_end_lens
             # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token
             # Note that we're actually tracking one cumsum() for the list for each possible start overhang length
+            # B x (maximum_token_len - 1) x max_valid_end_lens
             later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2)
             # The match vector is a boolean vector that indicates which positions have valid tokens
             match = torch.cat([initial_match, later_match], dim=1)
 
             # Once a single position does not match, all positions following that position are masked
             mask = (~match).cumsum(dim=1, dtype=torch.int32)
-            mask = mask == 0
+            mask = mask == 0  # B x maximum_token_len x max_valid_end_lens
 
             # The string is matched if we reached a cumsum equal to or greater than the length of the string
             # before hitting the mask
-            string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze(1) >= target_len)
+            string_matches.append(torch.amax(cumsum * mask, dim=(1, 2)) >= target_len)
 
         # Now we concatenate the match booleans across all strings and check if any are True
         string_matches = torch.cat(string_matches, dim=0)
diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py
index 08f8ded3df9ce8..4c6878882bb7e7 100644
--- a/tests/generation/test_stopping_criteria.py
+++ b/tests/generation/test_stopping_criteria.py
@@ -110,14 +110,24 @@ def test_validate_stopping_criteria(self):
         self.assertEqual(len(stopping_criteria), 1)
 
     def test_stop_string_criteria(self):
-        true_strings = ["<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", ">><|im_start|>>stop", "stop"]
+        true_strings = [
+            "<|im_start|><|im_end|>",
+            "<|im_start|><|im_end|<|im_end|>",
+            ">><|im_start|>>stop",
+            "stop",
+            "end",
+        ]
         false_strings = [
             "<|im_start|><|im_end|",
             "<|im_start|><|im_end|<|im_end|",
             "<|im_end|><|im_start|>",
             "<|im_end|<>stop<|im_end|",
         ]
-        too_short_strings = ["<|im_end|", "|im_end|>", "s", "end"]
+        too_short_strings = [
+            "<|im_end|",
+            "|im_end|>",
+            "s",
+        ]
 
         # Use a tokenizer that won't actually have special tokens for these
         tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
@@ -129,7 +139,7 @@ def test_stop_string_criteria(self):
             too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False
         )
         scores = None
-        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"])
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])
         for i in range(len(true_strings)):
             self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
         for i in range(len(false_strings)):
@@ -145,7 +155,7 @@ def test_stop_string_criteria(self):
         too_short_input_ids = tokenizer(
             too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False
         )
-        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"])
+        criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])
         for i in range(len(true_strings)):
             self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores))
         for i in range(len(false_strings)):
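To illustrate the reduction change in the first hunk, here is a minimal, self-contained sketch with made-up shapes and values (not the real embedding vectors that `StopStringCriteria` builds): once a separate cumulative matched length is tracked for every possible end-overlap size, the trailing dimension of `cumsum * mask` is no longer size 1, so the old `torch.max(..., dim=1).values.squeeze(1)` reduction is replaced by `torch.amax(..., dim=(1, 2))` over both trailing dimensions.

```python
import torch

# Toy example: batch of 2 sequences, 3 tokens considered per sequence,
# 2 possible end-overlap lengths tracked per token.
# cumsum[b, t, o] = hypothetical matched length after token t, assuming overlap length o.
cumsum = torch.tensor(
    [
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],  # sequence 0 keeps extending the match
        [[1.0, 0.0], [2.0, 0.0], [2.0, 0.0]],  # sequence 1 stalls early
    ]
)
# mask[b, t, o] stays True only until the first token that fails to continue the string.
mask = torch.tensor(
    [
        [[True, True], [True, True], [True, True]],
        [[True, False], [False, False], [False, False]],
    ]
)
target_len = 5  # length of the stop string being checked

# Reduce over both the token and overlap-length dimensions at once.
string_matches = torch.amax(cumsum * mask, dim=(1, 2)) >= target_len
print(string_matches)  # tensor([ True, False])
```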
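And a hedged usage sketch mirroring the updated tests (assuming `StopStringCriteria` is importable from `transformers.generation`, as in the test module): the exact truth values depend on how the gpt2 tokenizer splits these strings, but the call pattern is the one exercised above, with `"end"` now included as a full stop string rather than a too-short one.

```python
from transformers import AutoTokenizer
from transformers.generation import StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"])

# A sequence ending in a stop string should trigger the criteria; one that does not should not.
for text in ["Generate until the end", "Keep going, no terminator here"]:
    input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    print(text, "->", criteria(input_ids, scores=None))
```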