From 262537f0107c7f1728d4fe4bc22ec1ee9e652070 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 6 Feb 2024 14:27:43 +0000 Subject: [PATCH 01/68] stash commit (will discard all of this) --- .../generation/stopping_criteria.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index bac537b71b96ec..b04c5d779d5bca 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,7 +2,9 @@ import warnings from abc import ABC from copy import deepcopy -from typing import List, Optional, Union +from typing import Optional, List, Union + +from ..tokenization_utils_base import PreTrainedTokenizerBase import torch @@ -129,6 +131,56 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) +class TerminationSequenceCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever specific string sequences are encountered. Because the same + substring can be tokenized in different ways depending on context, this class expands strings up into every possible + token sequence that could contain them in a preprocessing step, then does a vectorized comparison against + `input_ids` during generation. This is much faster than doing detokenization inside the generation loop. + + Args: + tokenizer (`PreTrainedTokenizer`): + The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) + termination_sequences (`Union[str, List[str]]`): + The sequences that should end generation. If a string is passed, it will be treated like a + list with a single element. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase, termination_sequences: Union[str, List[str]]): + vocab = tokenizer.get_vocab() + tok_list = list(vocab.keys()) + if isinstance(termination_sequences, str): + termination_sequences = [termination_sequences] + termination_tokens = [] + for seq in termination_sequences: + if seq in tokenizer.special_tokens_map.values(): + # If it's a special token it won't be split, so we can just use it directly + termination_tokens.append(vocab[seq]) + continue + # If it isn't a special token, we need to figure out all sequences of tokens which contain this string. + # This is horribly inefficient, but it'll do to start. + bridging_seqs = [] + for prefix_len in range(1, len(seq) + 1): + for suffix_len in range(len(seq), len(seq) - prefix_len, -1): + prefix = seq[:prefix_len] + suffix = seq[-suffix_len:] + middle = seq[prefix_len:-suffix_len] + possible_starts = [token for token in tok_list if token.endswith(prefix)] + possible_ends = [token for token in tok_list if token.startswith(suffix)] + if not possible_starts or not possible_ends: + continue + bridging_seqs.extend([start + middle + end for start in possible_starts for end in possible_ends]) + if not bridging_seqs: + raise ValueError("Couldn't find any set of tokens spanning the termination sequence " + seq) + bridging_seqs = list(set(bridging_seqs)) # Uniquify just in case + termination_tokens.extend(tokenizer(bridging_seqs, add_special_tokens=False)['input_ids']) + + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + return time.time() - self.initial_timestamp > self.max_time + + class EosTokenCriteria(StoppingCriteria): """ This class can be used to stop generation whenever the "end-of-sequence" token is generated. From cfa538b01e46a0fa0cd833f8f1bc72d12b48690d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 13:29:06 +0000 Subject: [PATCH 02/68] stash commit --- .../generation/stopping_criteria.py | 58 +++++++++---------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b04c5d779d5bca..cdce7f85c437ee 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -131,7 +131,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) -class TerminationSequenceCriteria(StoppingCriteria): +class StopStringCriteria(StoppingCriteria): """ This class can be used to stop generation whenever specific string sequences are encountered. Because the same substring can be tokenized in different ways depending on context, this class expands strings up into every possible @@ -141,39 +141,39 @@ class TerminationSequenceCriteria(StoppingCriteria): Args: tokenizer (`PreTrainedTokenizer`): The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) - termination_sequences (`Union[str, List[str]]`): - The sequences that should end generation. If a string is passed, it will be treated like a + stop_strings (`Union[str, List[str]]`): + A list of strings that should end generation. If a string is passed, it will be treated like a list with a single element. """ - def __init__(self, tokenizer: PreTrainedTokenizerBase, termination_sequences: Union[str, List[str]]): + def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): vocab = tokenizer.get_vocab() tok_list = list(vocab.keys()) - if isinstance(termination_sequences, str): - termination_sequences = [termination_sequences] - termination_tokens = [] - for seq in termination_sequences: - if seq in tokenizer.special_tokens_map.values(): - # If it's a special token it won't be split, so we can just use it directly - termination_tokens.append(vocab[seq]) - continue - # If it isn't a special token, we need to figure out all sequences of tokens which contain this string. - # This is horribly inefficient, but it'll do to start. - bridging_seqs = [] - for prefix_len in range(1, len(seq) + 1): - for suffix_len in range(len(seq), len(seq) - prefix_len, -1): - prefix = seq[:prefix_len] - suffix = seq[-suffix_len:] - middle = seq[prefix_len:-suffix_len] - possible_starts = [token for token in tok_list if token.endswith(prefix)] - possible_ends = [token for token in tok_list if token.startswith(suffix)] - if not possible_starts or not possible_ends: - continue - bridging_seqs.extend([start + middle + end for start in possible_starts for end in possible_ends]) - if not bridging_seqs: - raise ValueError("Couldn't find any set of tokens spanning the termination sequence " + seq) - bridging_seqs = list(set(bridging_seqs)) # Uniquify just in case - termination_tokens.extend(tokenizer(bridging_seqs, add_special_tokens=False)['input_ids']) + if isinstance(stop_strings, str): + stop_strings = [stop_strings] + for stop_string in stop_strings: + for token in tok_list: + matching_positions = [] + for i in range(1 - len(token), len(stop_string)): + if i < 0: + token = token[:i] + if not token: + raise ValueError("Token is null - this is a bug!") + i = 0 + stop_string = stop_string[i: i + len(token)] + if not stop_string: + raise ValueError("Stop string is null - this is a bug!") + if len(token) > len(stop_string): + token = token[-len(stop_string):] + if not token: + raise ValueError("Token is null after stop string truncation - this is a bug!") + if token == stop_string: + matching_positions.append((i, len(token))) + + + + + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) From 127182a99e685499be8c9a7d3991390895c04ff9 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 16:37:30 +0000 Subject: [PATCH 03/68] First commit - needs a lot of testing! --- .../generation/stopping_criteria.py | 156 ++++++++++++++---- 1 file changed, 128 insertions(+), 28 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index cdce7f85c437ee..d60dfec83248e0 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,12 +2,13 @@ import warnings from abc import ABC from copy import deepcopy -from typing import Optional, List, Union - -from ..tokenization_utils_base import PreTrainedTokenizerBase +from typing import List, Optional, Union +import numpy as np import torch +from torch.nn import functional as F +from ..tokenization_utils_base import PreTrainedTokenizerBase from ..utils import add_start_docstrings, logging @@ -133,10 +134,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class StopStringCriteria(StoppingCriteria): """ - This class can be used to stop generation whenever specific string sequences are encountered. Because the same - substring can be tokenized in different ways depending on context, this class expands strings up into every possible - token sequence that could contain them in a preprocessing step, then does a vectorized comparison against - `input_ids` during generation. This is much faster than doing detokenization inside the generation loop. + This class can be used to stop generation whenever specific string sequences are encountered. It preprocesses + the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings. Args: tokenizer (`PreTrainedTokenizer`): @@ -147,38 +146,139 @@ class StopStringCriteria(StoppingCriteria): """ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): - vocab = tokenizer.get_vocab() - tok_list = list(vocab.keys()) if isinstance(stop_strings, str): stop_strings = [stop_strings] + + self.tokenizer = tokenizer + self.stop_strings: List[str] = stop_strings + self.strings_to_valid_positions, self.strings_to_end_lengths = self.get_matching_positions( + tokenizer, stop_strings + ) + + self.max_valid_positions = { + stop_string: max([len(val) for val in self.strings_to_valid_positions[stop_string].values()]) + for stop_string in stop_strings + } + self.global_max_position = max(self.max_valid_positions.values()) + self.max_valid_end_lens = { + stop_string: max([len(val) for val in self.strings_to_end_lengths[stop_string].values()]) + for stop_string in stop_strings + } + self.embedding_vecs = self.create_embedding_vecs() + + @staticmethod + def get_matching_positions(tokenizer, stop_strings): + """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can + validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of + valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the + end of the stop string.""" + vocab = tokenizer.get_vocab() + tok_list = list(vocab.keys()) + strings_to_valid_positions = {} + strings_to_end_lengths = {} for stop_string in stop_strings: + strings_to_valid_positions[stop_string] = {} + strings_to_end_lengths[stop_string] = {} for token in tok_list: matching_positions = [] + possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): + tok = token[::-1].replace("▁", " ") + stop = stop_string[::-1] if i < 0: - token = token[:i] - if not token: - raise ValueError("Token is null - this is a bug!") + tok = tok[-i:] + if not tok: + raise ValueError("Tok is null - this is a bug!") i = 0 - stop_string = stop_string[i: i + len(token)] - if not stop_string: - raise ValueError("Stop string is null - this is a bug!") - if len(token) > len(stop_string): - token = token[-len(stop_string):] - if not token: - raise ValueError("Token is null after stop string truncation - this is a bug!") - if token == stop_string: - matching_positions.append((i, len(token))) - - - - - - + stop = stop[i : i + len(tok)] + if not stop: + raise ValueError("Stop is null - this is a bug!") + if len(tok) > len(stop): + tok = tok[: len(stop)] + if not tok: + raise ValueError("Tok is null after stop string truncation - this is a bug!") + if len(tok) != len(stop): + raise ValueError("Truncated token and stop string have different lengths - this is a bug!") + if tok == stop: + if i == 0: + possible_end_lengths.append(len(tok)) + else: + matching_positions.append(i) + if matching_positions: + strings_to_valid_positions[stop_string][token] = matching_positions + if possible_end_lengths: + strings_to_end_lengths[stop_string][token] = possible_end_lengths + return strings_to_valid_positions, strings_to_end_lengths + + def create_embedding_vecs(self): + """ + This function builds an embedding matrix for each stop string, consisting of possible valid positions + and possible end lengths for each token, and the total length of the token string. When tokens have + fewer valid positions or end lengths than the maximum, we pad the vectors with -1000. + The value of -1000 is chosen to be very negative and thus overwhelm any positive values + in the cumsum() calls later. + """ + vocab = self.tokenizer.get_vocab() + embedding_vecs = {} + for stop_string in self.stop_strings: + positions = self.strings_to_valid_positions[stop_string] + end_lens = self.strings_to_end_lengths[stop_string] + # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? + + # Since this is lots of very small assignments of lists, we build it with numpy rather + # than torch for speed + simplicity, then convert to torch at the end + max_valid_positions = self.max_valid_positions[stop_string] + max_valid_end_lens = self.max_valid_end_lens[stop_string] + vec_size = max_valid_positions + max_valid_end_lens + 1 + gather_vec = np.full((len(self.tokenizer), vec_size), dtype=np.int32, fill_value=-1000) + for token, valid_positions in positions.items(): + token_idx = vocab[token] + gather_vec[token_idx, : len(valid_positions)] = valid_positions + for token, possible_end_lens in end_lens.items(): + token_idx = vocab[token] + gather_vec[ + token_idx, max_valid_positions : max_valid_positions + len(possible_end_lens) + ] = possible_end_lens + for token, token_idx in vocab.items(): + gather_vec[token_idx, -1] = len(token) + embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32) + return embedding_vecs @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - return time.time() - self.initial_timestamp > self.max_time + # TODO Joao - I'm not using the scores at all and just checking the most recent tokens in input_ids + # Is this correct? Should I be sampling from scores? + # Note that input_ids can also be *shorter* than the global max position, and the code below should be + # ready for that + input_ids = input_ids[:, -self.global_max_position :] + flipped_ids = torch.flip(input_ids, (1,)) + string_matches = [] + for stop_string in self.stop_strings: + target_len = len(stop_string) + max_valid_positions = self.max_valid_positions[stop_string] + max_valid_end_lens = self.max_valid_end_lens[stop_string] + embedding_vec = self.embedding_vecs[stop_string] + embedded = F.embedding(flipped_ids, embedding_vec) + + starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] + lengths = embedded[:, 1:, -1:] + lengths = lengths.expand((-1, -1, starts.shape[-1])) + lengths_with_starts = torch.cat([starts, lengths], dim=1) + cumsum = lengths_with_starts.cumsum(dim=1) + valid_positions = embedded[:, 1:, :max_valid_positions] + + initial_match = torch.any(starts > 0, dim=-1, keepdim=True) + later_match = torch.isin(cumsum[:, :-1], valid_positions) + match = torch.cat([initial_match, later_match], dim=1) + + mask = (~match).cumsum(dim=1, dtype=torch.int32) + mask = mask == 0 + + string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze() >= target_len) + + # Now we concatenate matches across all strings and check if any are True + string_matches = torch.cat(string_matches, dim=0) + return torch.any(string_matches).item() class EosTokenCriteria(StoppingCriteria): From 8cd605917191f458a88baa917b77d1b152f3fd89 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 16:52:43 +0000 Subject: [PATCH 04/68] Add a test --- tests/generation/test_stopping_criteria.py | 34 +++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 0c770972a7fdff..6a37335a390b6e 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -16,7 +16,7 @@ import time import unittest -from transformers import is_torch_available +from transformers import is_torch_available, AutoTokenizer from transformers.testing_utils import require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -30,6 +30,7 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, + StopStringCriteria, StoppingCriteriaList, validate_stopping_criteria, ) @@ -124,3 +125,34 @@ def test_validate_stopping_criteria(self): stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11) self.assertEqual(len(stopping_criteria), 1) + + def test_stop_string_criteria(self): + # Use a tokenizer that won't actually have special tokens for these + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + + true_strings = [ + "<|im_start|><|im_end|>", + "<|im_start|><|im_end|<|im_end|>", + ">><|im_start|>><|im_end|>", + ] + false_strings = [ + "<|im_start|><|im_end|", + "<|im_start|><|im_end|<|im_end|", + "<|im_end|><|im_start|>", + "<|im_end|><|im_start|", + ] + tokenizer.pad_token_id = tokenizer.eos_token_id + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", padding_side="left") + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", padding_side="left") + scores = None + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") + self.assertTrue(criteria(true_input_ids["input_ids"], scores)) + self.assertFalse(criteria(false_input_ids["input_ids"], scores)) + + # Now try it with a tokenizer where those are actually special tokens + tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", padding_side="left") + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", padding_side="left") + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") + self.assertTrue(criteria(true_input_ids["input_ids"], scores)) + self.assertFalse(criteria(false_input_ids["input_ids"], scores)) \ No newline at end of file From 5fde7aeb3bfb70ff60048c438fd040d667905543 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 17:06:50 +0000 Subject: [PATCH 05/68] Fix imports and make the tests actually test something --- src/transformers/generation/__init__.py | 2 ++ tests/generation/test_stopping_criteria.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index 6653f3c8d123e9..e4cbfb75661412 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -86,6 +86,7 @@ "StoppingCriteria", "StoppingCriteriaList", "validate_stopping_criteria", + "StopStringCriteria", ] _import_structure["utils"] = [ "GenerationMixin", @@ -225,6 +226,7 @@ StoppingCriteria, StoppingCriteriaList, validate_stopping_criteria, + StopStringCriteria, ) from .utils import ( BeamSampleDecoderOnlyOutput, diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 6a37335a390b6e..9e6383b5a82f4d 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -139,11 +139,12 @@ def test_stop_string_criteria(self): "<|im_start|><|im_end|", "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", - "<|im_end|><|im_start|", + "<|im_end|<><|im_end|", ] tokenizer.pad_token_id = tokenizer.eos_token_id - true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", padding_side="left") - false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", padding_side="left") + tokenizer.padding_side = "left" + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest") + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest") scores = None criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") self.assertTrue(criteria(true_input_ids["input_ids"], scores)) @@ -151,8 +152,10 @@ def test_stop_string_criteria(self): # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") - true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", padding_side="left") - false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", padding_side="left") + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "left" + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest") + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest") criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") self.assertTrue(criteria(true_input_ids["input_ids"], scores)) self.assertFalse(criteria(false_input_ids["input_ids"], scores)) \ No newline at end of file From ff02b0cd1d721f9560c034db8b00e6cee1108995 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 17:40:55 +0000 Subject: [PATCH 06/68] Tests pass! --- src/transformers/generation/__init__.py | 2 +- .../generation/stopping_criteria.py | 8 +++---- tests/generation/test_stopping_criteria.py | 22 ++++++++++++------- 3 files changed, 19 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py index e4cbfb75661412..a669d6ed0659cf 100644 --- a/src/transformers/generation/__init__.py +++ b/src/transformers/generation/__init__.py @@ -225,8 +225,8 @@ MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, - validate_stopping_criteria, StopStringCriteria, + validate_stopping_criteria, ) from .utils import ( BeamSampleDecoderOnlyOutput, diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index d60dfec83248e0..3755abe81ebb14 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -159,7 +159,6 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, stop_string: max([len(val) for val in self.strings_to_valid_positions[stop_string].values()]) for stop_string in stop_strings } - self.global_max_position = max(self.max_valid_positions.values()) self.max_valid_end_lens = { stop_string: max([len(val) for val in self.strings_to_end_lengths[stop_string].values()]) for stop_string in stop_strings @@ -248,9 +247,10 @@ def create_embedding_vecs(self): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: # TODO Joao - I'm not using the scores at all and just checking the most recent tokens in input_ids # Is this correct? Should I be sampling from scores? - # Note that input_ids can also be *shorter* than the global max position, and the code below should be - # ready for that - input_ids = input_ids[:, -self.global_max_position :] + # The maximum length we need to consider is 1 token per character. Note that input_ids can also be + # *shorter* than the global max, and the code below should be ready for that + maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) + input_ids = input_ids[:, -maximum_token_len:] flipped_ids = torch.flip(input_ids, (1,)) string_matches = [] for stop_string in self.stop_strings: diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 9e6383b5a82f4d..0c0cef0d04f251 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -16,7 +16,7 @@ import time import unittest -from transformers import is_torch_available, AutoTokenizer +from transformers import AutoTokenizer, is_torch_available from transformers.testing_utils import require_torch, torch_device from ..test_modeling_common import ids_tensor @@ -30,8 +30,8 @@ MaxLengthCriteria, MaxNewTokensCriteria, MaxTimeCriteria, - StopStringCriteria, StoppingCriteriaList, + StopStringCriteria, validate_stopping_criteria, ) @@ -140,22 +140,28 @@ def test_stop_string_criteria(self): "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", "<|im_end|<><|im_end|", - ] + ] + too_short_strings = ["<|im_end|", "|im_end|>"] tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" - true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest") - false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest") + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + too_short_input_ids = tokenizer( + too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False + ) scores = None criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") self.assertTrue(criteria(true_input_ids["input_ids"], scores)) self.assertFalse(criteria(false_input_ids["input_ids"], scores)) + self.assertFalse(criteria(too_short_input_ids["input_ids"], scores)) # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" - true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest") - false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest") + true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) + false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") self.assertTrue(criteria(true_input_ids["input_ids"], scores)) - self.assertFalse(criteria(false_input_ids["input_ids"], scores)) \ No newline at end of file + self.assertFalse(criteria(false_input_ids["input_ids"], scores)) + self.assertFalse(criteria(too_short_input_ids["input_ids"], scores)) From 4ce1aba0831660d015536997d153de7154d6d04e Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 17:50:03 +0000 Subject: [PATCH 07/68] Rearrange test --- tests/generation/test_stopping_criteria.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 0c0cef0d04f251..69b4e27a28bfcc 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -127,9 +127,6 @@ def test_validate_stopping_criteria(self): self.assertEqual(len(stopping_criteria), 1) def test_stop_string_criteria(self): - # Use a tokenizer that won't actually have special tokens for these - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - true_strings = [ "<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", @@ -142,6 +139,9 @@ def test_stop_string_criteria(self): "<|im_end|<><|im_end|", ] too_short_strings = ["<|im_end|", "|im_end|>"] + + # Use a tokenizer that won't actually have special tokens for these + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) From 1742b681492cc8d0f992ea0af34f0b3535ebd83d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 18:03:24 +0000 Subject: [PATCH 08/68] Add comments (but it's still a bit confusing) --- .../generation/stopping_criteria.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 3755abe81ebb14..c01af4162578c6 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -213,9 +213,7 @@ def create_embedding_vecs(self): """ This function builds an embedding matrix for each stop string, consisting of possible valid positions and possible end lengths for each token, and the total length of the token string. When tokens have - fewer valid positions or end lengths than the maximum, we pad the vectors with -1000. - The value of -1000 is chosen to be very negative and thus overwhelm any positive values - in the cumsum() calls later. + fewer valid positions or end lengths than the maximum, we pad the vectors with -1. """ vocab = self.tokenizer.get_vocab() embedding_vecs = {} @@ -229,7 +227,7 @@ def create_embedding_vecs(self): max_valid_positions = self.max_valid_positions[stop_string] max_valid_end_lens = self.max_valid_end_lens[stop_string] vec_size = max_valid_positions + max_valid_end_lens + 1 - gather_vec = np.full((len(self.tokenizer), vec_size), dtype=np.int32, fill_value=-1000) + gather_vec = np.full((len(self.tokenizer), vec_size), dtype=np.int32, fill_value=-1) for token, valid_positions in positions.items(): token_idx = vocab[token] gather_vec[token_idx, : len(valid_positions)] = valid_positions @@ -251,32 +249,47 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # *shorter* than the global max, and the code below should be ready for that maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) input_ids = input_ids[:, -maximum_token_len:] + # Flip input_ids because we're only matching strings at the end of the generated sequence flipped_ids = torch.flip(input_ids, (1,)) string_matches = [] for stop_string in self.stop_strings: target_len = len(stop_string) max_valid_positions = self.max_valid_positions[stop_string] max_valid_end_lens = self.max_valid_end_lens[stop_string] + # The embedding vec contains the valid positions, end_lengths and total lengths for each token embedding_vec = self.embedding_vecs[stop_string] embedded = F.embedding(flipped_ids, embedding_vec) + # Starts contains the number of characters from the string, counting from the end, that the token contains + # It can have multiple values if the same token can overlap different slices of the end of the string starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] + # Lengths is the total length of each token. Unlike starts, it always has a single value lengths = embedded[:, 1:, -1:] lengths = lengths.expand((-1, -1, starts.shape[-1])) lengths_with_starts = torch.cat([starts, lengths], dim=1) + # We concatenate each possible starting length with the lengths of the remaining tokens in input_ids + # Then we cumsum() to get the total length of the string after each token cumsum = lengths_with_starts.cumsum(dim=1) - valid_positions = embedded[:, 1:, :max_valid_positions] + # Valid positions are the positions in the string that the token can validly appear after + valid_positions = embedded[:, 1:, :max_valid_positions] + # Tokens can match the start of the string if they have any valid value in the starts vector initial_match = torch.any(starts > 0, dim=-1, keepdim=True) + # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token + # Note that we're actually tracking one cumsum() for the list for each possible start overhang length later_match = torch.isin(cumsum[:, :-1], valid_positions) + # The match vector is a boolean vector that indicates which positions have valid tokens match = torch.cat([initial_match, later_match], dim=1) + # Once a single position does not match, all positions following that position are masked mask = (~match).cumsum(dim=1, dtype=torch.int32) mask = mask == 0 + # The string is matched if we reached a cumsum equal to or greater than the length of the string + # before hitting the masked run string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze() >= target_len) - # Now we concatenate matches across all strings and check if any are True + # Now we concatenate the match booleans across all strings and check if any are True string_matches = torch.cat(string_matches, dim=0) return torch.any(string_matches).item() From 9fb77e33801d0359c6625fda9f93dba25522cce0 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 18:05:24 +0000 Subject: [PATCH 09/68] Stop storing the tokenizer --- src/transformers/generation/stopping_criteria.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index c01af4162578c6..6671d9f2df7e8a 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -149,10 +149,10 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, if isinstance(stop_strings, str): stop_strings = [stop_strings] - self.tokenizer = tokenizer + self.vocab = tokenizer.get_vocab() self.stop_strings: List[str] = stop_strings self.strings_to_valid_positions, self.strings_to_end_lengths = self.get_matching_positions( - tokenizer, stop_strings + self.vocab, stop_strings ) self.max_valid_positions = { @@ -166,12 +166,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.embedding_vecs = self.create_embedding_vecs() @staticmethod - def get_matching_positions(tokenizer, stop_strings): + def get_matching_positions(vocab, stop_strings): """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the end of the stop string.""" - vocab = tokenizer.get_vocab() tok_list = list(vocab.keys()) strings_to_valid_positions = {} strings_to_end_lengths = {} @@ -215,7 +214,7 @@ def create_embedding_vecs(self): and possible end lengths for each token, and the total length of the token string. When tokens have fewer valid positions or end lengths than the maximum, we pad the vectors with -1. """ - vocab = self.tokenizer.get_vocab() + vocab = self.vocab embedding_vecs = {} for stop_string in self.stop_strings: positions = self.strings_to_valid_positions[stop_string] @@ -227,7 +226,7 @@ def create_embedding_vecs(self): max_valid_positions = self.max_valid_positions[stop_string] max_valid_end_lens = self.max_valid_end_lens[stop_string] vec_size = max_valid_positions + max_valid_end_lens + 1 - gather_vec = np.full((len(self.tokenizer), vec_size), dtype=np.int32, fill_value=-1) + gather_vec = np.full((len(self.vocab), vec_size), dtype=np.int32, fill_value=-1) for token, valid_positions in positions.items(): token_idx = vocab[token] gather_vec[token_idx, : len(valid_positions)] = valid_positions From 667d6d880915907ec85f598c26ee204f5c902a8d Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 8 Feb 2024 18:07:09 +0000 Subject: [PATCH 10/68] Comment fixup --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 6671d9f2df7e8a..7ba39634ece991 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -285,7 +285,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa mask = mask == 0 # The string is matched if we reached a cumsum equal to or greater than the length of the string - # before hitting the masked run + # before hitting the mask string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze() >= target_len) # Now we concatenate the match booleans across all strings and check if any are True From 070a76e864ce1f24dda045bb0f5568642ab14ef3 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 14:01:58 +0000 Subject: [PATCH 11/68] Fix for input_ids with a single sequence --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 7ba39634ece991..3fb849c08d92f4 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -286,7 +286,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # The string is matched if we reached a cumsum equal to or greater than the length of the string # before hitting the mask - string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze() >= target_len) + string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze(1) >= target_len) # Now we concatenate the match booleans across all strings and check if any are True string_matches = torch.cat(string_matches, dim=0) From 4c436f2cda149c9fc66729c2669eb8939c2ed7f5 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 14:02:06 +0000 Subject: [PATCH 12/68] Update tests to test single sequences --- tests/generation/test_stopping_criteria.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 69b4e27a28bfcc..b1b454c30c6e47 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -151,9 +151,12 @@ def test_stop_string_criteria(self): ) scores = None criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") - self.assertTrue(criteria(true_input_ids["input_ids"], scores)) - self.assertFalse(criteria(false_input_ids["input_ids"], scores)) - self.assertFalse(criteria(too_short_input_ids["input_ids"], scores)) + for i in range(len(true_strings)): + self.assertTrue(criteria(true_input_ids["input_ids"][i: i+1], scores)) + for i in range(len(false_strings)): + self.assertFalse(criteria(false_input_ids["input_ids"][i: i+1], scores)) + for i in range(len(too_short_strings)): + self.assertFalse(criteria(too_short_input_ids["input_ids"][i: i+1], scores)) # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") @@ -162,6 +165,9 @@ def test_stop_string_criteria(self): true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") - self.assertTrue(criteria(true_input_ids["input_ids"], scores)) - self.assertFalse(criteria(false_input_ids["input_ids"], scores)) - self.assertFalse(criteria(too_short_input_ids["input_ids"], scores)) + for i in range(len(true_strings)): + self.assertTrue(criteria(true_input_ids["input_ids"][i: i+1], scores)) + for i in range(len(false_strings)): + self.assertFalse(criteria(false_input_ids["input_ids"][i: i+1], scores)) + for i in range(len(too_short_strings)): + self.assertFalse(criteria(too_short_input_ids["input_ids"][i: i+1], scores)) \ No newline at end of file From 78b0f247548ab26d11b66b6e9a2dbec7b90f246c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 14:04:31 +0000 Subject: [PATCH 13/68] make fixup --- tests/generation/test_stopping_criteria.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index b1b454c30c6e47..23c83fa0dd3b4e 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -152,11 +152,11 @@ def test_stop_string_criteria(self): scores = None criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") for i in range(len(true_strings)): - self.assertTrue(criteria(true_input_ids["input_ids"][i: i+1], scores)) + self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): - self.assertFalse(criteria(false_input_ids["input_ids"][i: i+1], scores)) + self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(too_short_strings)): - self.assertFalse(criteria(too_short_input_ids["input_ids"][i: i+1], scores)) + self.assertFalse(criteria(too_short_input_ids["input_ids"][i : i + 1], scores)) # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") @@ -166,8 +166,8 @@ def test_stop_string_criteria(self): false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") for i in range(len(true_strings)): - self.assertTrue(criteria(true_input_ids["input_ids"][i: i+1], scores)) + self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): - self.assertFalse(criteria(false_input_ids["input_ids"][i: i+1], scores)) + self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(too_short_strings)): - self.assertFalse(criteria(too_short_input_ids["input_ids"][i: i+1], scores)) \ No newline at end of file + self.assertFalse(criteria(too_short_input_ids["input_ids"][i : i + 1], scores)) From 8ee5762e1bc41915d463911e6c382a848ff4bf95 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 16:52:46 +0000 Subject: [PATCH 14/68] Fix incorrect use of isin() --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 3fb849c08d92f4..8acaedc1f4e8dd 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -276,7 +276,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa initial_match = torch.any(starts > 0, dim=-1, keepdim=True) # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token # Note that we're actually tracking one cumsum() for the list for each possible start overhang length - later_match = torch.isin(cumsum[:, :-1], valid_positions) + later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2) # The match vector is a boolean vector that indicates which positions have valid tokens match = torch.cat([initial_match, later_match], dim=1) From 9f43a2a63f91b4f564aaf7dd3f1f9508f072cf42 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 16:52:57 +0000 Subject: [PATCH 15/68] Expand tests to catch more cases --- tests/generation/test_stopping_criteria.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 23c83fa0dd3b4e..8e2c6e3b2f7f40 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -130,15 +130,16 @@ def test_stop_string_criteria(self): true_strings = [ "<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", - ">><|im_start|>><|im_end|>", + ">><|im_start|>>stop", + "stop" ] false_strings = [ "<|im_start|><|im_end|", "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", - "<|im_end|<><|im_end|", + "<|im_end|<>stop<|im_end|", ] - too_short_strings = ["<|im_end|", "|im_end|>"] + too_short_strings = ["<|im_end|", "|im_end|>", "s"] # Use a tokenizer that won't actually have special tokens for these tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") @@ -150,7 +151,7 @@ def test_stop_string_criteria(self): too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False ) scores = None - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"]) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): @@ -160,11 +161,13 @@ def test_stop_string_criteria(self): # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") - tokenizer.pad_token_id = tokenizer.eos_token_id tokenizer.padding_side = "left" true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings="<|im_end|>") + too_short_input_ids = tokenizer( + too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False + ) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"]) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): From f0fa7074971a9b0ac9d7fecd26f8728bad1bfaa5 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 17:00:51 +0000 Subject: [PATCH 16/68] Expand tests to catch more cases --- tests/generation/test_stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 8e2c6e3b2f7f40..b3b2a20eab03d6 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -139,7 +139,7 @@ def test_stop_string_criteria(self): "<|im_end|><|im_start|>", "<|im_end|<>stop<|im_end|", ] - too_short_strings = ["<|im_end|", "|im_end|>", "s"] + too_short_strings = ["<|im_end|", "|im_end|>", "s", "end"] # Use a tokenizer that won't actually have special tokens for these tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") From 5bcf5e476beb730b80062a7f1287df0c57e883a2 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 17:06:48 +0000 Subject: [PATCH 17/68] make fixup --- tests/generation/test_stopping_criteria.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index b3b2a20eab03d6..3aabaeb5f2f260 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -127,12 +127,7 @@ def test_validate_stopping_criteria(self): self.assertEqual(len(stopping_criteria), 1) def test_stop_string_criteria(self): - true_strings = [ - "<|im_start|><|im_end|>", - "<|im_start|><|im_end|<|im_end|>", - ">><|im_start|>>stop", - "stop" - ] + true_strings = ["<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", ">><|im_start|>>stop", "stop"] false_strings = [ "<|im_start|><|im_end|", "<|im_start|><|im_end|<|im_end|", From 8cca9a403a44ec8998428a563f8e4589c1a0094d Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 13 Feb 2024 14:22:32 +0000 Subject: [PATCH 18/68] Fix length calculation and update tests --- .../generation/stopping_criteria.py | 19 ++++++++++++------- tests/generation/test_stopping_criteria.py | 18 ++++++++++++++---- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 8acaedc1f4e8dd..17310c0f1c7967 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -253,7 +253,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa string_matches = [] for stop_string in self.stop_strings: target_len = len(stop_string) + # Maximum number of internal positions a single token can match max_valid_positions = self.max_valid_positions[stop_string] + # Maximum number of different overlap sizes a single token can have with the end of the string max_valid_end_lens = self.max_valid_end_lens[stop_string] # The embedding vec contains the valid positions, end_lengths and total lengths for each token embedding_vec = self.embedding_vecs[stop_string] @@ -261,32 +263,35 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # Starts contains the number of characters from the string, counting from the end, that the token contains # It can have multiple values if the same token can overlap different slices of the end of the string + # B x 1 x max_valid_end_lens starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] # Lengths is the total length of each token. Unlike starts, it always has a single value - lengths = embedded[:, 1:, -1:] - lengths = lengths.expand((-1, -1, starts.shape[-1])) - lengths_with_starts = torch.cat([starts, lengths], dim=1) + lengths = embedded[:, 1:, -1:] # B x (maximum_token_len - 1) x 1 + lengths = lengths.expand((-1, -1, starts.shape[-1])) # B x (maximum_token_len - 1) x max_valid_end_lens + lengths_with_starts = torch.cat([starts, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens # We concatenate each possible starting length with the lengths of the remaining tokens in input_ids # Then we cumsum() to get the total length of the string after each token - cumsum = lengths_with_starts.cumsum(dim=1) + cumsum = lengths_with_starts.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens # Valid positions are the positions in the string that the token can validly appear after + # B x (maximum_token_len - 1) x max_valid_positions valid_positions = embedded[:, 1:, :max_valid_positions] # Tokens can match the start of the string if they have any valid value in the starts vector - initial_match = torch.any(starts > 0, dim=-1, keepdim=True) + initial_match = starts > 0 # B x 1 x max_valid_end_lens # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token # Note that we're actually tracking one cumsum() for the list for each possible start overhang length + # B x (maximum_token_len - 1) x max_valid_end_lens later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2) # The match vector is a boolean vector that indicates which positions have valid tokens match = torch.cat([initial_match, later_match], dim=1) # Once a single position does not match, all positions following that position are masked mask = (~match).cumsum(dim=1, dtype=torch.int32) - mask = mask == 0 + mask = mask == 0 # B x maximum_token_len x max_valid_end_lens # The string is matched if we reached a cumsum equal to or greater than the length of the string # before hitting the mask - string_matches.append(torch.max(cumsum * mask, dim=1).values.squeeze(1) >= target_len) + string_matches.append(torch.amax(cumsum * mask, dim=(1, 2)) >= target_len) # Now we concatenate the match booleans across all strings and check if any are True string_matches = torch.cat(string_matches, dim=0) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 3aabaeb5f2f260..1a8f109b70989f 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -127,14 +127,24 @@ def test_validate_stopping_criteria(self): self.assertEqual(len(stopping_criteria), 1) def test_stop_string_criteria(self): - true_strings = ["<|im_start|><|im_end|>", "<|im_start|><|im_end|<|im_end|>", ">><|im_start|>>stop", "stop"] + true_strings = [ + "<|im_start|><|im_end|>", + "<|im_start|><|im_end|<|im_end|>", + ">><|im_start|>>stop", + "stop", + "end", + ] false_strings = [ "<|im_start|><|im_end|", "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", "<|im_end|<>stop<|im_end|", ] - too_short_strings = ["<|im_end|", "|im_end|>", "s", "end"] + too_short_strings = [ + "<|im_end|", + "|im_end|>", + "s", + ] # Use a tokenizer that won't actually have special tokens for these tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") @@ -146,7 +156,7 @@ def test_stop_string_criteria(self): too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False ) scores = None - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"]) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"]) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): @@ -162,7 +172,7 @@ def test_stop_string_criteria(self): too_short_input_ids = tokenizer( too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False ) - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop"]) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"]) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): From ec6f72659e977a8445804c7749b3138124cbfc2c Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Feb 2024 16:19:39 +0000 Subject: [PATCH 19/68] =?UTF-8?q?Handle=20=C4=A0=20as=20a=20space=20replac?= =?UTF-8?q?ement=20too?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 17310c0f1c7967..f1989cf5ca691a 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -181,7 +181,7 @@ def get_matching_positions(vocab, stop_strings): matching_positions = [] possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): - tok = token[::-1].replace("▁", " ") + tok = token[::-1].replace("▁", " ").replace("Ġ", " ") stop = stop_string[::-1] if i < 0: tok = tok[-i:] From 0e632c2fd9ebb295156f53a4edc3ad426c4b9f3a Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:33:52 +0000 Subject: [PATCH 20/68] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index f1989cf5ca691a..0980de8738da97 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -166,7 +166,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.embedding_vecs = self.create_embedding_vecs() @staticmethod - def get_matching_positions(vocab, stop_strings): + def get_matching_positions(vocab: List[str], stop_strings: List[str]) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the From ac1135c24ed61d9d5e1fb4e5ebc77c288ecba565 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:39:47 +0000 Subject: [PATCH 21/68] Add optimizations from Joao's suggestion --- src/transformers/generation/stopping_criteria.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 0980de8738da97..63a9eecd78374d 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,7 +2,7 @@ import warnings from abc import ABC from copy import deepcopy -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple, Dict import numpy as np import torch @@ -172,17 +172,19 @@ def get_matching_positions(vocab: List[str], stop_strings: List[str]) -> Tuple[D valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the end of the stop string.""" tok_list = list(vocab.keys()) + reversed_filtered_tok_list = [token[::-1].replace("▁", " ").replace("Ġ", " ") for token in tok_list] strings_to_valid_positions = {} strings_to_end_lengths = {} for stop_string in stop_strings: + reversed_stop_string = stop_string[::-1] strings_to_valid_positions[stop_string] = {} strings_to_end_lengths[stop_string] = {} - for token in tok_list: + for token, reversed_filtered_token in zip(tok_list, reversed_filtered_tok_list): matching_positions = [] possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): - tok = token[::-1].replace("▁", " ").replace("Ġ", " ") - stop = stop_string[::-1] + tok = reversed_filtered_token + stop = reversed_stop_string if i < 0: tok = tok[-i:] if not tok: From 27318270b4a5ce7769d3635c7a553cd9b737907c Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:40:09 +0000 Subject: [PATCH 22/68] Remove TODO --- src/transformers/generation/stopping_criteria.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 63a9eecd78374d..89b768bc02f985 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -244,8 +244,6 @@ def create_embedding_vecs(self): @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - # TODO Joao - I'm not using the scores at all and just checking the most recent tokens in input_ids - # Is this correct? Should I be sampling from scores? # The maximum length we need to consider is 1 token per character. Note that input_ids can also be # *shorter* than the global max, and the code below should be ready for that maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) From 9213298f3fa3f1cecd40af1779ba6b54e7dc2ee7 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:40:58 +0000 Subject: [PATCH 23/68] Update src/transformers/generation/stopping_criteria.py Co-authored-by: Joao Gante --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 89b768bc02f985..202e5a6fb94287 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -210,7 +210,7 @@ def get_matching_positions(vocab: List[str], stop_strings: List[str]) -> Tuple[D strings_to_end_lengths[stop_string][token] = possible_end_lengths return strings_to_valid_positions, strings_to_end_lengths - def create_embedding_vecs(self): + def create_embedding_vecs(self) -> Dict[str, torch.tensor]: """ This function builds an embedding matrix for each stop string, consisting of possible valid positions and possible end lengths for each token, and the total length of the token string. When tokens have From f48522e1420b62f1bb52e7189fab7a71f9779e98 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:42:10 +0000 Subject: [PATCH 24/68] Update tests/generation/test_stopping_criteria.py Co-authored-by: Joao Gante --- tests/generation/test_stopping_criteria.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 1a8f109b70989f..3328bae598fdec 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -139,6 +139,8 @@ def test_stop_string_criteria(self): "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", "<|im_end|<>stop<|im_end|", + "en d", + "eNd", ] too_short_strings = [ "<|im_end|", From 7a772b8f9175a0ee90bed7d9bcab6950d5a12020 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 13:42:58 +0000 Subject: [PATCH 25/68] make fixup --- src/transformers/generation/stopping_criteria.py | 6 ++++-- tests/generation/test_stopping_criteria.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 202e5a6fb94287..a7d6fa4e973819 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,7 +2,7 @@ import warnings from abc import ABC from copy import deepcopy -from typing import List, Optional, Union, Tuple, Dict +from typing import Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -166,7 +166,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.embedding_vecs = self.create_embedding_vecs() @staticmethod - def get_matching_positions(vocab: List[str], stop_strings: List[str]) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: + def get_matching_positions( + vocab: List[str], stop_strings: List[str] + ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 3328bae598fdec..edb7a67f8ae171 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -147,6 +147,7 @@ def test_stop_string_criteria(self): "|im_end|>", "s", ] + stop_strings = ["<|im_end|>", "stop", "end"] # Use a tokenizer that won't actually have special tokens for these tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") @@ -158,7 +159,7 @@ def test_stop_string_criteria(self): too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False ) scores = None - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"]) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): @@ -174,7 +175,7 @@ def test_stop_string_criteria(self): too_short_input_ids = tokenizer( too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False ) - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=["<|im_end|>", "stop", "end"]) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): From c604a2ba13467040a9111d515ef264d6bd6febcd Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 15:38:10 +0000 Subject: [PATCH 26/68] Rename some variables and remove some debugging clauses for clarity --- .../generation/stopping_criteria.py | 54 +++++++------------ 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index a7d6fa4e973819..f81f093e978e52 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -151,36 +151,31 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.vocab = tokenizer.get_vocab() self.stop_strings: List[str] = stop_strings - self.strings_to_valid_positions, self.strings_to_end_lengths = self.get_matching_positions( - self.vocab, stop_strings - ) + self.token_valid_positions, self.token_end_overlaps = self._get_matching_positions() self.max_valid_positions = { - stop_string: max([len(val) for val in self.strings_to_valid_positions[stop_string].values()]) + stop_string: max([len(val) for val in self.token_valid_positions[stop_string].values()]) for stop_string in stop_strings } self.max_valid_end_lens = { - stop_string: max([len(val) for val in self.strings_to_end_lengths[stop_string].values()]) + stop_string: max([len(val) for val in self.token_end_overlaps[stop_string].values()]) for stop_string in stop_strings } - self.embedding_vecs = self.create_embedding_vecs() + self.embedding_vecs = self._create_embedding_vecs() - @staticmethod - def get_matching_positions( - vocab: List[str], stop_strings: List[str] - ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: + def _get_matching_positions(self) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the end of the stop string.""" - tok_list = list(vocab.keys()) + tok_list = list(self.vocab.keys()) reversed_filtered_tok_list = [token[::-1].replace("▁", " ").replace("Ġ", " ") for token in tok_list] - strings_to_valid_positions = {} - strings_to_end_lengths = {} - for stop_string in stop_strings: + token_valid_positions = {} + token_end_overlaps = {} + for stop_string in self.stop_strings: reversed_stop_string = stop_string[::-1] - strings_to_valid_positions[stop_string] = {} - strings_to_end_lengths[stop_string] = {} + token_valid_positions[stop_string] = {} + token_end_overlaps[stop_string] = {} for token, reversed_filtered_token in zip(tok_list, reversed_filtered_tok_list): matching_positions = [] possible_end_lengths = [] @@ -189,47 +184,38 @@ def get_matching_positions( stop = reversed_stop_string if i < 0: tok = tok[-i:] - if not tok: - raise ValueError("Tok is null - this is a bug!") i = 0 stop = stop[i : i + len(tok)] - if not stop: - raise ValueError("Stop is null - this is a bug!") if len(tok) > len(stop): tok = tok[: len(stop)] - if not tok: - raise ValueError("Tok is null after stop string truncation - this is a bug!") - if len(tok) != len(stop): - raise ValueError("Truncated token and stop string have different lengths - this is a bug!") if tok == stop: if i == 0: possible_end_lengths.append(len(tok)) else: matching_positions.append(i) if matching_positions: - strings_to_valid_positions[stop_string][token] = matching_positions + token_valid_positions[stop_string][token] = matching_positions if possible_end_lengths: - strings_to_end_lengths[stop_string][token] = possible_end_lengths - return strings_to_valid_positions, strings_to_end_lengths + token_end_overlaps[stop_string][token] = possible_end_lengths + return token_valid_positions, token_end_overlaps - def create_embedding_vecs(self) -> Dict[str, torch.tensor]: + def _create_embedding_vecs(self) -> Dict[str, torch.tensor]: """ This function builds an embedding matrix for each stop string, consisting of possible valid positions and possible end lengths for each token, and the total length of the token string. When tokens have fewer valid positions or end lengths than the maximum, we pad the vectors with -1. """ + # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? vocab = self.vocab embedding_vecs = {} for stop_string in self.stop_strings: - positions = self.strings_to_valid_positions[stop_string] - end_lens = self.strings_to_end_lengths[stop_string] - # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? - - # Since this is lots of very small assignments of lists, we build it with numpy rather - # than torch for speed + simplicity, then convert to torch at the end + positions = self.token_valid_positions[stop_string] + end_lens = self.token_end_overlaps[stop_string] max_valid_positions = self.max_valid_positions[stop_string] max_valid_end_lens = self.max_valid_end_lens[stop_string] vec_size = max_valid_positions + max_valid_end_lens + 1 + # Since this is lots of very small assignments of lists, we build it with numpy rather + # than torch for speed + simplicity, then convert to torch at the end gather_vec = np.full((len(self.vocab), vec_size), dtype=np.int32, fill_value=-1) for token, valid_positions in positions.items(): token_idx = vocab[token] From 7dd346af3d9bf3e1a0da84813aa9a60ac4257bf0 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 15:38:23 +0000 Subject: [PATCH 27/68] Add tests for the sub-methods --- tests/generation/test_stopping_criteria.py | 89 ++++++++++++++++++---- 1 file changed, 75 insertions(+), 14 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index edb7a67f8ae171..afb2c784f6f745 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -132,22 +132,21 @@ def test_stop_string_criteria(self): "<|im_start|><|im_end|<|im_end|>", ">><|im_start|>>stop", "stop", - "end", + "e nd", ] false_strings = [ "<|im_start|><|im_end|", "<|im_start|><|im_end|<|im_end|", "<|im_end|><|im_start|>", "<|im_end|<>stop<|im_end|", + "end", "en d", "eNd", - ] - too_short_strings = [ "<|im_end|", "|im_end|>", "s", ] - stop_strings = ["<|im_end|>", "stop", "end"] + stop_strings = ["<|im_end|>", "stop", "e nd"] # Use a tokenizer that won't actually have special tokens for these tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") @@ -155,30 +154,92 @@ def test_stop_string_criteria(self): tokenizer.padding_side = "left" true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) - too_short_input_ids = tokenizer( - too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False - ) + scores = None criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) - for i in range(len(too_short_strings)): - self.assertFalse(criteria(too_short_input_ids["input_ids"][i : i + 1], scores)) # Now try it with a tokenizer where those are actually special tokens tokenizer = AutoTokenizer.from_pretrained("cognitivecomputations/dolphin-2.5-mixtral-8x7b") tokenizer.padding_side = "left" true_input_ids = tokenizer(true_strings, return_tensors="pt", padding="longest", add_special_tokens=False) false_input_ids = tokenizer(false_strings, return_tensors="pt", padding="longest", add_special_tokens=False) - too_short_input_ids = tokenizer( - too_short_strings, return_tensors="pt", padding="longest", add_special_tokens=False - ) + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) for i in range(len(true_strings)): self.assertTrue(criteria(true_input_ids["input_ids"][i : i + 1], scores)) for i in range(len(false_strings)): self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) - for i in range(len(too_short_strings)): - self.assertFalse(criteria(too_short_input_ids["input_ids"][i : i + 1], scores)) + + def test_stop_string_matching_positions(self): + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + all_token_valid_positions, all_token_end_overlaps = criteria._get_matching_positions() + for stop_string in stop_strings: + token_valid_positions = all_token_valid_positions[stop_string] + token_end_overlaps = all_token_end_overlaps[stop_string] + for token, valid_positions in token_valid_positions.items(): + token = token.replace("▁", " ").replace("Ġ", " ") + for position in valid_positions: + trim_length = position + len(token) - len(stop_string) + if trim_length > 0: + # This token runs off the start of the string + self.assertTrue(stop_string.startswith(token[trim_length:])) + else: + self.assertTrue(stop_string[-position - len(token) : -position] == token) + for token, end_overlaps in token_end_overlaps.items(): + token = token.replace("▁", " ").replace("Ġ", " ") + for overlap in end_overlaps: + # Either this token runs off the end of the string, + # or the entire stop string is a substring of the token + self.assertTrue( + ( + stop_string.endswith(token[:overlap]) + or (stop_string in token and overlap == len(stop_string)) + ) + ) + + def test_stop_string_embedding_vecs(self): + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + all_embedding_vecs = criteria._create_embedding_vecs() + for stop_string in stop_strings: + embedding_vecs = all_embedding_vecs[stop_string] + max_valid_positions = criteria.max_valid_positions[stop_string] + max_valid_end_lens = criteria.max_valid_end_lens[stop_string] + for token, token_idx in criteria.vocab.items(): + vec = embedding_vecs[token_idx].tolist() + # The embedding contains packed valid positions, end overlap lengths, and the total token length + token = token.replace("▁", " ").replace("Ġ", " ") + + token_valid_positions = vec[:max_valid_positions] + for position in token_valid_positions: + if position == -1: + continue # Padding value + trim_length = position + len(token) - len(stop_string) + if trim_length > 0: + # This token runs off the start of the string + self.assertTrue(stop_string.startswith(token[trim_length:])) + else: + self.assertTrue(stop_string[-position - len(token) : -position] == token) + + token_end_overlaps = vec[max_valid_positions : max_valid_positions + max_valid_end_lens] + for overlap in token_end_overlaps: + if overlap == -1: + continue # Padding value + # Either this token runs off the end of the string, + # or the entire stop string is a substring of the token + self.assertTrue( + ( + stop_string.endswith(token[:overlap]) + or (stop_string in token and overlap == len(stop_string)) + ) + ) + + token_length = vec[-1] + self.assertTrue(len(token) == token_length) From 641ba7272095f92a89b2941beff930dd1c09049d Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 15:40:36 +0000 Subject: [PATCH 28/68] Clarify one test slightly --- tests/generation/test_stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index afb2c784f6f745..736812df5665b6 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -190,7 +190,7 @@ def test_stop_string_matching_positions(self): # This token runs off the start of the string self.assertTrue(stop_string.startswith(token[trim_length:])) else: - self.assertTrue(stop_string[-position - len(token) : -position] == token) + self.assertTrue(stop_string[-position - len(token):].startswith(token)) for token, end_overlaps in token_end_overlaps.items(): token = token.replace("▁", " ").replace("Ġ", " ") for overlap in end_overlaps: From f6721a5ab98e17fcbd0d6d4add3d808eb6f8b470 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 16:30:08 +0000 Subject: [PATCH 29/68] Add stop_strings to GenerationConfig --- src/transformers/generation/configuration_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index f40960c213ea67..4504b8b183941b 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -115,6 +115,8 @@ class GenerationConfig(PushToHubMixin): max_time(`float`, *optional*): The maximum amount of time you allow the computation to run for in seconds. generation will still finish the current pass after allocated time has been passed. + stop_strings(`str or List[str]`, *optional*): + A string or a list of strings that should terminate generation if the model outputs them. > Parameters that control the generation strategy used @@ -306,6 +308,7 @@ def __init__(self, **kwargs): self.min_new_tokens = kwargs.pop("min_new_tokens", None) self.early_stopping = kwargs.pop("early_stopping", False) self.max_time = kwargs.pop("max_time", None) + self.stop_strings = kwargs.pop("stop_strings", None) # Parameters that control the generation strategy used self.do_sample = kwargs.pop("do_sample", False) From 8772bcbb484d1d8e2cf7484f4f97f21b42c05473 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 16:35:55 +0000 Subject: [PATCH 30/68] generate() supports stop_string arg, asks for tokenizer if not provided --- src/transformers/generation/utils.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 36e62794a435a1..c68eaef1206e9c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -80,6 +80,7 @@ MaxTimeCriteria, StoppingCriteria, StoppingCriteriaList, + StopStringCriteria, validate_stopping_criteria, ) @@ -881,7 +882,7 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], **kwargs ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -894,6 +895,14 @@ def _get_stopping_criteria( ) if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) + if generation_config.stop_strings is not None: + if "tokenizer" not in kwargs: + raise ValueError( + "To generate with stop strings, you need to pass the model's tokenizer to the `generate` method." + ) + criteria.append( + StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=kwargs["tokenizer"]) + ) if generation_config.eos_token_id is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) @@ -1527,7 +1536,7 @@ def generate( # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria + generation_config=generation_config, stopping_criteria=stopping_criteria, **kwargs ) # 10. go into different generation modes if generation_mode == GenerationMode.ASSISTED_GENERATION: From e423417ad75dd2f36f1aacf34313c212a2bf7457 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 16:36:02 +0000 Subject: [PATCH 31/68] make fixup --- tests/generation/test_stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 736812df5665b6..dc8986ed6c6dbe 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -190,7 +190,7 @@ def test_stop_string_matching_positions(self): # This token runs off the start of the string self.assertTrue(stop_string.startswith(token[trim_length:])) else: - self.assertTrue(stop_string[-position - len(token):].startswith(token)) + self.assertTrue(stop_string[-position - len(token) :].startswith(token)) for token, end_overlaps in token_end_overlaps.items(): token = token.replace("▁", " ").replace("Ġ", " ") for overlap in end_overlaps: From 398a799f0e936cdd967373d732192440cefd62eb Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 13:44:33 +0000 Subject: [PATCH 32/68] Cleanup code and rename variables for clarity --- .../generation/stopping_criteria.py | 48 +++++++++++-------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index f81f093e978e52..0cbe2dfc2fdc7b 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -240,42 +240,50 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa flipped_ids = torch.flip(input_ids, (1,)) string_matches = [] for stop_string in self.stop_strings: + # We need the length of the stop string to know how many characters our token sequence should have target_len = len(stop_string) - # Maximum number of internal positions a single token can match + + # Size of the vector of positions a single token can match max_valid_positions = self.max_valid_positions[stop_string] - # Maximum number of different overlap sizes a single token can have with the end of the string + + # Size of the vector of overlap sizes a single token can have with the end of the string max_valid_end_lens = self.max_valid_end_lens[stop_string] + # The embedding vec contains the valid positions, end_lengths and total lengths for each token embedding_vec = self.embedding_vecs[stop_string] embedded = F.embedding(flipped_ids, embedding_vec) - # Starts contains the number of characters from the string, counting from the end, that the token contains - # It can have multiple values if the same token can overlap different slices of the end of the string - # B x 1 x max_valid_end_lens - starts = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] - # Lengths is the total length of each token. Unlike starts, it always has a single value + # end_lengths is the number of characters from the string, counting from the end, that the token + # contains. It can have multiple values if the same token can overlap different end lengths + end_lengths = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] + + # Lengths is the total length of each token. Unlike end_lengths, it always has a single value lengths = embedded[:, 1:, -1:] # B x (maximum_token_len - 1) x 1 - lengths = lengths.expand((-1, -1, starts.shape[-1])) # B x (maximum_token_len - 1) x max_valid_end_lens - lengths_with_starts = torch.cat([starts, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens - # We concatenate each possible starting length with the lengths of the remaining tokens in input_ids - # Then we cumsum() to get the total length of the string after each token - cumsum = lengths_with_starts.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens - - # Valid positions are the positions in the string that the token can validly appear after - # B x (maximum_token_len - 1) x max_valid_positions + + # Concatenate lengths onto each possible end_lengths value + lengths = lengths.expand((-1, -1, end_lengths.shape[-1])) + lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens + + # cumsum() to get the number of matched characters in the stop string after each token + cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens + + # The code above assumes that all tokens are in valid positions. This code masks the ones that are not. + # First we get the vector of positions tokens can validly appear in valid_positions = embedded[:, 1:, :max_valid_positions] - # Tokens can match the start of the string if they have any valid value in the starts vector - initial_match = starts > 0 # B x 1 x max_valid_end_lens + + # Tokens can match the start of the string if they have any valid value in the end_lengths vector + initial_match = end_lengths > 0 + # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token - # Note that we're actually tracking one cumsum() for the list for each possible start overhang length - # B x (maximum_token_len - 1) x max_valid_end_lens + # Note that we're actually tracking one cumsum() for for each possible end_length later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2) + # The match vector is a boolean vector that indicates which positions have valid tokens match = torch.cat([initial_match, later_match], dim=1) # Once a single position does not match, all positions following that position are masked mask = (~match).cumsum(dim=1, dtype=torch.int32) - mask = mask == 0 # B x maximum_token_len x max_valid_end_lens + mask = mask == 0 # The string is matched if we reached a cumsum equal to or greater than the length of the string # before hitting the mask From e3140a685aee4b8526b3b8bfa977760ad2a42971 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 13:55:56 +0000 Subject: [PATCH 33/68] Update tokenizer error --- src/transformers/generation/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c68eaef1206e9c..6bb6f4620e078c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -898,7 +898,9 @@ def _get_stopping_criteria( if generation_config.stop_strings is not None: if "tokenizer" not in kwargs: raise ValueError( - "To generate with stop strings, you need to pass the model's tokenizer to the `generate` method." + "There are one or more stop strings, either in the arguments to `generate` or in the " + "model's generation config, but we could not locate a tokenizer. When generating with " + "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." ) criteria.append( StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=kwargs["tokenizer"]) From 0008722de399f38b8b9c0654de6e54ee8c49ea47 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 14:18:53 +0000 Subject: [PATCH 34/68] Update tokenizer passing, handle generation on GPU --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 0cbe2dfc2fdc7b..471d15be5be502 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -250,7 +250,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa max_valid_end_lens = self.max_valid_end_lens[stop_string] # The embedding vec contains the valid positions, end_lengths and total lengths for each token - embedding_vec = self.embedding_vecs[stop_string] + embedding_vec = self.embedding_vecs[stop_string].to(flipped_ids.device) embedded = F.embedding(flipped_ids, embedding_vec) # end_lengths is the number of characters from the string, counting from the end, that the token From a29c131e61315cfb2965692672796fb5b7d1c447 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 14:30:37 +0000 Subject: [PATCH 35/68] Slightly more explanation cleanup --- src/transformers/generation/stopping_criteria.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 471d15be5be502..b62762cf02b831 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -236,8 +236,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # *shorter* than the global max, and the code below should be ready for that maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) input_ids = input_ids[:, -maximum_token_len:] + # Flip input_ids because we're only matching strings at the end of the generated sequence flipped_ids = torch.flip(input_ids, (1,)) + string_matches = [] for stop_string in self.stop_strings: # We need the length of the stop string to know how many characters our token sequence should have @@ -258,11 +260,11 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa end_lengths = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] # Lengths is the total length of each token. Unlike end_lengths, it always has a single value - lengths = embedded[:, 1:, -1:] # B x (maximum_token_len - 1) x 1 + lengths = embedded[:, 1:, -1:] # Concatenate lengths onto each possible end_lengths value lengths = lengths.expand((-1, -1, end_lengths.shape[-1])) - lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) # B x maximum_token_len x max_valid_end_lens + lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) # cumsum() to get the number of matched characters in the stop string after each token cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens @@ -271,10 +273,10 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # First we get the vector of positions tokens can validly appear in valid_positions = embedded[:, 1:, :max_valid_positions] - # Tokens can match the start of the string if they have any valid value in the end_lengths vector + # Tokens match the start of the string if they have a positive value in the end_lengths vector initial_match = end_lengths > 0 - # Tokens can continue the string if the cumsum() so far is one of the valid positions for that token + # Tokens continue the string if the cumsum() so far is one of the valid positions for that token # Note that we're actually tracking one cumsum() for for each possible end_length later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2) From 9c359ffe2fc5b6f6156049931549adfd7f371055 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 14:52:37 +0000 Subject: [PATCH 36/68] More comment cleanup --- src/transformers/generation/stopping_criteria.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b62762cf02b831..2625adba46e710 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -254,12 +254,12 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # The embedding vec contains the valid positions, end_lengths and total lengths for each token embedding_vec = self.embedding_vecs[stop_string].to(flipped_ids.device) embedded = F.embedding(flipped_ids, embedding_vec) - + # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit + valid_positions = embedded[:, 1:, :max_valid_positions] # end_lengths is the number of characters from the string, counting from the end, that the token # contains. It can have multiple values if the same token can overlap different end lengths end_lengths = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] - - # Lengths is the total length of each token. Unlike end_lengths, it always has a single value + # Lengths is the total length of each token. Unlike the others, it always has a single value lengths = embedded[:, 1:, -1:] # Concatenate lengths onto each possible end_lengths value @@ -269,11 +269,8 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # cumsum() to get the number of matched characters in the stop string after each token cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens - # The code above assumes that all tokens are in valid positions. This code masks the ones that are not. - # First we get the vector of positions tokens can validly appear in - valid_positions = embedded[:, 1:, :max_valid_positions] - - # Tokens match the start of the string if they have a positive value in the end_lengths vector + # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not. + # First, tokens match the start of the string if they have a positive value in the end_lengths vector initial_match = end_lengths > 0 # Tokens continue the string if the cumsum() so far is one of the valid positions for that token From 602222dca8669951bc561f1bd682fce5ddd9364d Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 15:23:42 +0000 Subject: [PATCH 37/68] Factor out the token cleanup so it's more obvious what we're doing, and we can change it later --- src/transformers/generation/stopping_criteria.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 2625adba46e710..d0f8ca94547177 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -163,13 +163,19 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, } self.embedding_vecs = self._create_embedding_vecs() + @staticmethod + def _cleanup_token(token: str) -> str: + if token[0] in ["▁", "Ġ"]: + token = " " + token[1:] + return token + def _get_matching_positions(self) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the end of the stop string.""" tok_list = list(self.vocab.keys()) - reversed_filtered_tok_list = [token[::-1].replace("▁", " ").replace("Ġ", " ") for token in tok_list] + reversed_filtered_tok_list = [self._cleanup_token(token[::-1]) for token in tok_list] token_valid_positions = {} token_end_overlaps = {} for stop_string in self.stop_strings: From 4c7a7777991353e6f21d3c44b48a41c774d97c43 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 19 Feb 2024 15:37:16 +0000 Subject: [PATCH 38/68] Careful with that cleanup! --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index d0f8ca94547177..b99567a774c723 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -175,7 +175,7 @@ def _get_matching_positions(self) -> Tuple[Dict[str, Dict[str, List[int]]], Dict valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the end of the stop string.""" tok_list = list(self.vocab.keys()) - reversed_filtered_tok_list = [self._cleanup_token(token[::-1]) for token in tok_list] + reversed_filtered_tok_list = [self._cleanup_token(token)[::-1] for token in tok_list] token_valid_positions = {} token_end_overlaps = {} for stop_string in self.stop_strings: From b6e01639162933e9567b60ef5428aa5d478f5802 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 16:34:15 +0000 Subject: [PATCH 39/68] Cleanup + optimizations to _get_matching_positions --- src/transformers/generation/stopping_criteria.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index b99567a774c723..faa634e743869a 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -186,19 +186,18 @@ def _get_matching_positions(self) -> Tuple[Dict[str, Dict[str, List[int]]], Dict matching_positions = [] possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): - tok = reversed_filtered_token - stop = reversed_stop_string if i < 0: - tok = tok[-i:] + tok = reversed_filtered_token[-i:] i = 0 - stop = stop[i : i + len(tok)] - if len(tok) > len(stop): - tok = tok[: len(stop)] - if tok == stop: + else: + tok = reversed_filtered_token + stop = reversed_stop_string[i : i + len(tok)] + if tok.startswith(stop): if i == 0: - possible_end_lengths.append(len(tok)) + possible_end_lengths.append(min(len(tok), len(stop))) else: matching_positions.append(i) + if matching_positions: token_valid_positions[stop_string][token] = matching_positions if possible_end_lengths: From 43d9e084ae8d39cda8b6d8b9937937db3d259443 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 17:14:26 +0000 Subject: [PATCH 40/68] More minor performance tweaks --- src/transformers/generation/stopping_criteria.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index faa634e743869a..5d2aff7b1154ac 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -232,7 +232,7 @@ def _create_embedding_vecs(self) -> Dict[str, torch.tensor]: ] = possible_end_lens for token, token_idx in vocab.items(): gather_vec[token_idx, -1] = len(token) - embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32) + embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32).pin_memory() return embedding_vecs @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @@ -247,6 +247,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa string_matches = [] for stop_string in self.stop_strings: + embedding_vec = self.embedding_vecs[stop_string].to(flipped_ids.device, non_blocking=True) # We need the length of the stop string to know how many characters our token sequence should have target_len = len(stop_string) @@ -257,7 +258,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa max_valid_end_lens = self.max_valid_end_lens[stop_string] # The embedding vec contains the valid positions, end_lengths and total lengths for each token - embedding_vec = self.embedding_vecs[stop_string].to(flipped_ids.device) + embedded = F.embedding(flipped_ids, embedding_vec) # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit valid_positions = embedded[:, 1:, :max_valid_positions] From 60eb5769e935b3fa51b327fa87bfca52a3a183d4 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 18:01:58 +0000 Subject: [PATCH 41/68] Implement caching and eliminate some expensive ops (startup time: 200ms -> 9ms) --- .../generation/stopping_criteria.py | 176 +++++++++--------- tests/generation/test_stopping_criteria.py | 24 ++- 2 files changed, 110 insertions(+), 90 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 5d2aff7b1154ac..2a850b1f650a9c 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -2,6 +2,7 @@ import warnings from abc import ABC from copy import deepcopy +from functools import lru_cache from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -150,90 +151,12 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, stop_strings = [stop_strings] self.vocab = tokenizer.get_vocab() - self.stop_strings: List[str] = stop_strings - self.token_valid_positions, self.token_end_overlaps = self._get_matching_positions() - - self.max_valid_positions = { - stop_string: max([len(val) for val in self.token_valid_positions[stop_string].values()]) - for stop_string in stop_strings - } - self.max_valid_end_lens = { - stop_string: max([len(val) for val in self.token_end_overlaps[stop_string].values()]) - for stop_string in stop_strings - } - self.embedding_vecs = self._create_embedding_vecs() - - @staticmethod - def _cleanup_token(token: str) -> str: - if token[0] in ["▁", "Ġ"]: - token = " " + token[1:] - return token + self.tok_list, self.tok_indices = tuple(self.vocab.keys()), tuple(self.vocab.values()) + self.stop_strings: Tuple[str, ...] = tuple(stop_strings) - def _get_matching_positions(self) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: - """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can - validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of - valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the - end of the stop string.""" - tok_list = list(self.vocab.keys()) - reversed_filtered_tok_list = [self._cleanup_token(token)[::-1] for token in tok_list] - token_valid_positions = {} - token_end_overlaps = {} - for stop_string in self.stop_strings: - reversed_stop_string = stop_string[::-1] - token_valid_positions[stop_string] = {} - token_end_overlaps[stop_string] = {} - for token, reversed_filtered_token in zip(tok_list, reversed_filtered_tok_list): - matching_positions = [] - possible_end_lengths = [] - for i in range(1 - len(token), len(stop_string)): - if i < 0: - tok = reversed_filtered_token[-i:] - i = 0 - else: - tok = reversed_filtered_token - stop = reversed_stop_string[i : i + len(tok)] - if tok.startswith(stop): - if i == 0: - possible_end_lengths.append(min(len(tok), len(stop))) - else: - matching_positions.append(i) - - if matching_positions: - token_valid_positions[stop_string][token] = matching_positions - if possible_end_lengths: - token_end_overlaps[stop_string][token] = possible_end_lengths - return token_valid_positions, token_end_overlaps - - def _create_embedding_vecs(self) -> Dict[str, torch.tensor]: - """ - This function builds an embedding matrix for each stop string, consisting of possible valid positions - and possible end lengths for each token, and the total length of the token string. When tokens have - fewer valid positions or end lengths than the maximum, we pad the vectors with -1. - """ - # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? - vocab = self.vocab - embedding_vecs = {} - for stop_string in self.stop_strings: - positions = self.token_valid_positions[stop_string] - end_lens = self.token_end_overlaps[stop_string] - max_valid_positions = self.max_valid_positions[stop_string] - max_valid_end_lens = self.max_valid_end_lens[stop_string] - vec_size = max_valid_positions + max_valid_end_lens + 1 - # Since this is lots of very small assignments of lists, we build it with numpy rather - # than torch for speed + simplicity, then convert to torch at the end - gather_vec = np.full((len(self.vocab), vec_size), dtype=np.int32, fill_value=-1) - for token, valid_positions in positions.items(): - token_idx = vocab[token] - gather_vec[token_idx, : len(valid_positions)] = valid_positions - for token, possible_end_lens in end_lens.items(): - token_idx = vocab[token] - gather_vec[ - token_idx, max_valid_positions : max_valid_positions + len(possible_end_lens) - ] = possible_end_lens - for token, token_idx in vocab.items(): - gather_vec[token_idx, -1] = len(token) - embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32).pin_memory() - return embedding_vecs + self.embedding_vecs, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vecs( + self.tok_list, self.tok_indices, self.stop_strings + ) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: @@ -346,3 +269,90 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng elif stopping_max_length is None: new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) return new_stopping_criteria + + +def _stop_string_get_matching_positions( + tok_list, tok_indices, stop_strings +) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: + """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can + validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of + valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the + end of the stop string.""" + + def _cleanup_token(token: str) -> str: + if token[0] in ["▁", "Ġ"]: + token = " " + token[1:] + return token + + reversed_filtered_tok_list = [_cleanup_token(token)[::-1] for token in tok_list] + token_valid_positions = {} + token_end_overlaps = {} + for stop_string in stop_strings: + reversed_stop_string = stop_string[::-1] + token_valid_positions[stop_string] = {} + token_end_overlaps[stop_string] = {} + for token, reversed_filtered_token, tok_idx in zip(tok_list, reversed_filtered_tok_list, tok_indices): + matching_positions = [] + possible_end_lengths = [] + for i in range(1 - len(token), len(stop_string)): + if i < 0: + tok = reversed_filtered_token[-i:] + i = 0 + else: + tok = reversed_filtered_token + stop = reversed_stop_string[i : i + len(tok)] + if tok.startswith(stop): + if i == 0: + possible_end_lengths.append(min(len(tok), len(stop))) + else: + matching_positions.append(i) + + if matching_positions: + token_valid_positions[stop_string][tok_idx] = matching_positions + if possible_end_lengths: + token_end_overlaps[stop_string][tok_idx] = possible_end_lengths + return token_valid_positions, token_end_overlaps + + +@lru_cache(8) +def _stop_string_create_embedding_vecs(tok_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: + """ + This function builds an embedding matrix for each stop string, consisting of possible valid positions + and possible end lengths for each token, and the total length of the token string. When tokens have + fewer valid positions or end lengths than the maximum, we pad the vectors with -1. + """ + # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? + token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( + tok_list, tok_indices, stop_strings + ) + + embedding_vecs = {} + for stop_string in stop_strings: + positions = token_valid_positions[stop_string] + end_lens = token_end_overlaps[stop_string] + max_valid_positions = max([len(val) for val in positions.values()]) + max_valid_end_lens = max([len(val) for val in end_lens.values()]) + vec_size = max_valid_positions + max_valid_end_lens + 1 + # Since this is lots of very small assignments of lists, we build it with numpy rather + # than torch for speed + simplicity, then convert to torch at the end + gather_vec = np.full((len(tok_list), vec_size), dtype=np.int32, fill_value=-1) + for token_idx, valid_positions in positions.items(): + gather_vec[token_idx, : len(valid_positions)] = valid_positions + for token_idx, possible_end_lens in end_lens.items(): + gather_vec[ + token_idx, max_valid_positions : max_valid_positions + len(possible_end_lens) + ] = possible_end_lens + for token, token_idx in zip(tok_list, tok_indices): + gather_vec[token_idx, -1] = len(token) + embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32).pin_memory() + + # TODO Remove this block and stop returning these values after the embedding vec is merged + max_valid_positions = { + stop_string: max([len(val) for val in token_valid_positions[stop_string].values()]) + for stop_string in stop_strings + } + max_valid_end_lens = { + stop_string: max([len(val) for val in token_end_overlaps[stop_string].values()]) + for stop_string in stop_strings + } + return embedding_vecs, max_valid_positions, max_valid_end_lens diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index dc8986ed6c6dbe..dcf242944a171d 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -35,6 +35,11 @@ validate_stopping_criteria, ) +from transformers.generation.stopping_criteria import ( + _stop_string_create_embedding_vecs, + _stop_string_get_matching_positions, +) + @require_torch class StoppingCriteriaTestCase(unittest.TestCase): @@ -178,12 +183,15 @@ def test_stop_string_matching_positions(self): tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - all_token_valid_positions, all_token_end_overlaps = criteria._get_matching_positions() + idx_to_token = {v: k for k, v in tokenizer.get_vocab().items()} + all_token_valid_positions, all_token_end_overlaps = _stop_string_get_matching_positions( + tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + ) for stop_string in stop_strings: token_valid_positions = all_token_valid_positions[stop_string] token_end_overlaps = all_token_end_overlaps[stop_string] - for token, valid_positions in token_valid_positions.items(): - token = token.replace("▁", " ").replace("Ġ", " ") + for token_idx, valid_positions in token_valid_positions.items(): + token = idx_to_token[token_idx].replace("▁", " ").replace("Ġ", " ") for position in valid_positions: trim_length = position + len(token) - len(stop_string) if trim_length > 0: @@ -191,8 +199,8 @@ def test_stop_string_matching_positions(self): self.assertTrue(stop_string.startswith(token[trim_length:])) else: self.assertTrue(stop_string[-position - len(token) :].startswith(token)) - for token, end_overlaps in token_end_overlaps.items(): - token = token.replace("▁", " ").replace("Ġ", " ") + for token_idx, end_overlaps in token_end_overlaps.items(): + token = idx_to_token[token_idx].replace("▁", " ").replace("Ġ", " ") for overlap in end_overlaps: # Either this token runs off the end of the string, # or the entire stop string is a substring of the token @@ -207,12 +215,14 @@ def test_stop_string_embedding_vecs(self): tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - all_embedding_vecs = criteria._create_embedding_vecs() + all_embedding_vecs, *_ = _stop_string_create_embedding_vecs( + tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + ) for stop_string in stop_strings: embedding_vecs = all_embedding_vecs[stop_string] max_valid_positions = criteria.max_valid_positions[stop_string] max_valid_end_lens = criteria.max_valid_end_lens[stop_string] - for token, token_idx in criteria.vocab.items(): + for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): vec = embedding_vecs[token_idx].tolist() # The embedding contains packed valid positions, end overlap lengths, and the total token length token = token.replace("▁", " ").replace("Ġ", " ") From ff422118a4d9fe898a6824e3a81079a0a0797388 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 18:12:19 +0000 Subject: [PATCH 42/68] Remove the pin_memory call --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 2a850b1f650a9c..17b42093aa5675 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -344,7 +344,7 @@ def _stop_string_create_embedding_vecs(tok_list, tok_indices, stop_strings) -> D ] = possible_end_lens for token, token_idx in zip(tok_list, tok_indices): gather_vec[token_idx, -1] = len(token) - embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32).pin_memory() + embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32) # TODO Remove this block and stop returning these values after the embedding vec is merged max_valid_positions = { From ae800a66c6f90b7cc3081db368af6d957f01003e Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 18:58:55 +0000 Subject: [PATCH 43/68] Parallelize across all stop strings! --- .../generation/stopping_criteria.py | 125 +++++++++--------- tests/generation/test_stopping_criteria.py | 85 ++++++------ 2 files changed, 103 insertions(+), 107 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 17b42093aa5675..40634fba0b3a74 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -154,71 +154,69 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.tok_list, self.tok_indices = tuple(self.vocab.keys()), tuple(self.vocab.values()) self.stop_strings: Tuple[str, ...] = tuple(stop_strings) - self.embedding_vecs, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vecs( + self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vec( self.tok_list, self.tok_indices, self.stop_strings ) + self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) + self.num_stop_strings = len(self.stop_strings) + self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + embedding_vec = self.embedding_vec.to(input_ids.device) # The maximum length we need to consider is 1 token per character. Note that input_ids can also be # *shorter* than the global max, and the code below should be ready for that - maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) - input_ids = input_ids[:, -maximum_token_len:] + input_ids = input_ids[:, -self.maximum_token_len :] # Flip input_ids because we're only matching strings at the end of the generated sequence flipped_ids = torch.flip(input_ids, (1,)) - string_matches = [] - for stop_string in self.stop_strings: - embedding_vec = self.embedding_vecs[stop_string].to(flipped_ids.device, non_blocking=True) - # We need the length of the stop string to know how many characters our token sequence should have - target_len = len(stop_string) + # Size of the vector of positions a single token can match + max_valid_positions = self.max_valid_positions - # Size of the vector of positions a single token can match - max_valid_positions = self.max_valid_positions[stop_string] + # The embedding vec contains the valid positions, end_lengths and total lengths for each token + embedded = F.embedding(flipped_ids, embedding_vec) - # Size of the vector of overlap sizes a single token can have with the end of the string - max_valid_end_lens = self.max_valid_end_lens[stop_string] - - # The embedding vec contains the valid positions, end_lengths and total lengths for each token - - embedded = F.embedding(flipped_ids, embedding_vec) - # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit - valid_positions = embedded[:, 1:, :max_valid_positions] - # end_lengths is the number of characters from the string, counting from the end, that the token - # contains. It can have multiple values if the same token can overlap different end lengths - end_lengths = embedded[:, :1, max_valid_positions : max_valid_positions + max_valid_end_lens] - # Lengths is the total length of each token. Unlike the others, it always has a single value - lengths = embedded[:, 1:, -1:] + # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit + valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten( + -1, (self.num_stop_strings, -1) + ) + # end_lengths is the number of characters from the string, counting from the end, that the token + # contains. It can have multiple values if the same token can overlap different end lengths + end_lengths = embedded[:, :1, max_valid_positions * self.num_stop_strings : -1].unflatten( + -1, (self.num_stop_strings, -1) + ) + # Lengths is the total length of each token. Unlike the others, it always has a single value + lengths = embedded[:, 1:, None, -1:] # Insert a dummy dimension for stop_strings even though lengths are const - # Concatenate lengths onto each possible end_lengths value - lengths = lengths.expand((-1, -1, end_lengths.shape[-1])) - lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) + # Concatenate lengths onto each possible end_lengths value + lengths = lengths.expand((-1, -1, end_lengths.shape[-2], end_lengths.shape[-1])) + lengths_with_ends = torch.cat([end_lengths, lengths], dim=1) - # cumsum() to get the number of matched characters in the stop string after each token - cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x max_valid_end_lens + # cumsum() to get the number of matched characters in the stop string after each token + cumsum = lengths_with_ends.cumsum(dim=1) # B x maximum_token_len x num_stop_strings x max_valid_end_lens - # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not. - # First, tokens match the start of the string if they have a positive value in the end_lengths vector - initial_match = end_lengths > 0 + # The calculation above assumes that all tokens are in valid positions. Now we mask the ones that are not. + # First, tokens match the start of the string if they have a positive value in the end_lengths vector + initial_match = end_lengths > 0 - # Tokens continue the string if the cumsum() so far is one of the valid positions for that token - # Note that we're actually tracking one cumsum() for for each possible end_length - later_match = torch.any(cumsum[:, :-1, None] == valid_positions[:, :, :, None], axis=2) + # Tokens continue the string if the cumsum() so far is one of the valid positions for that token + # Note that we're actually tracking one cumsum() for for each possible end_length + later_match = torch.any(cumsum[:, :-1, :, None] == valid_positions[:, :, :, :, None], axis=-2) - # The match vector is a boolean vector that indicates which positions have valid tokens - match = torch.cat([initial_match, later_match], dim=1) + # The match vector is a boolean vector that indicates which positions have valid tokens + match = torch.cat([initial_match, later_match], dim=1) - # Once a single position does not match, all positions following that position are masked - mask = (~match).cumsum(dim=1, dtype=torch.int32) - mask = mask == 0 + # Once a single position does not match, all positions following that position are masked + mask = (~match).cumsum(dim=1, dtype=torch.int32) + mask = mask == 0 - # The string is matched if we reached a cumsum equal to or greater than the length of the string - # before hitting the mask - string_matches.append(torch.amax(cumsum * mask, dim=(1, 2)) >= target_len) + # The string is matched if we reached a cumsum equal to or greater than the length of the string + # before hitting the mask + string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] # Now we concatenate the match booleans across all strings and check if any are True - string_matches = torch.cat(string_matches, dim=0) + # TODO After Raushan's PR, return a per-sample vector here return torch.any(string_matches).item() @@ -315,7 +313,7 @@ def _cleanup_token(token: str) -> str: @lru_cache(8) -def _stop_string_create_embedding_vecs(tok_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: +def _stop_string_create_embedding_vec(tok_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: """ This function builds an embedding matrix for each stop string, consisting of possible valid positions and possible end lengths for each token, and the total length of the token string. When tokens have @@ -326,33 +324,32 @@ def _stop_string_create_embedding_vecs(tok_list, tok_indices, stop_strings) -> D tok_list, tok_indices, stop_strings ) - embedding_vecs = {} - for stop_string in stop_strings: + max_valid_positions = max([len(val) for positions in token_valid_positions.values() for val in positions.values()]) + max_valid_end_lens = max([len(val) for positions in token_end_overlaps.values() for val in positions.values()]) + vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 + gather_vec = np.full((len(tok_list), vec_size), dtype=np.int32, fill_value=-1) + + for i, stop_string in enumerate(stop_strings): positions = token_valid_positions[stop_string] end_lens = token_end_overlaps[stop_string] - max_valid_positions = max([len(val) for val in positions.values()]) - max_valid_end_lens = max([len(val) for val in end_lens.values()]) - vec_size = max_valid_positions + max_valid_end_lens + 1 + # Since this is lots of very small assignments of lists, we build it with numpy rather # than torch for speed + simplicity, then convert to torch at the end - gather_vec = np.full((len(tok_list), vec_size), dtype=np.int32, fill_value=-1) for token_idx, valid_positions in positions.items(): - gather_vec[token_idx, : len(valid_positions)] = valid_positions + gather_vec[ + token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions) + ] = valid_positions for token_idx, possible_end_lens in end_lens.items(): gather_vec[ - token_idx, max_valid_positions : max_valid_positions + len(possible_end_lens) + token_idx, + max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions + * len(stop_strings) + + max_valid_end_lens * i + + len(possible_end_lens), ] = possible_end_lens for token, token_idx in zip(tok_list, tok_indices): gather_vec[token_idx, -1] = len(token) - embedding_vecs[stop_string] = torch.tensor(gather_vec, dtype=torch.int32) - - # TODO Remove this block and stop returning these values after the embedding vec is merged - max_valid_positions = { - stop_string: max([len(val) for val in token_valid_positions[stop_string].values()]) - for stop_string in stop_strings - } - max_valid_end_lens = { - stop_string: max([len(val) for val in token_end_overlaps[stop_string].values()]) - for stop_string in stop_strings - } - return embedding_vecs, max_valid_positions, max_valid_end_lens + + gather_vec = torch.tensor(gather_vec, dtype=torch.int32) + + return gather_vec, max_valid_positions, max_valid_end_lens diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index dcf242944a171d..e9b2c11a77593e 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -36,7 +36,6 @@ ) from transformers.generation.stopping_criteria import ( - _stop_string_create_embedding_vecs, _stop_string_get_matching_positions, ) @@ -211,45 +210,45 @@ def test_stop_string_matching_positions(self): ) ) - def test_stop_string_embedding_vecs(self): - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - all_embedding_vecs, *_ = _stop_string_create_embedding_vecs( - tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings - ) - for stop_string in stop_strings: - embedding_vecs = all_embedding_vecs[stop_string] - max_valid_positions = criteria.max_valid_positions[stop_string] - max_valid_end_lens = criteria.max_valid_end_lens[stop_string] - for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): - vec = embedding_vecs[token_idx].tolist() - # The embedding contains packed valid positions, end overlap lengths, and the total token length - token = token.replace("▁", " ").replace("Ġ", " ") - - token_valid_positions = vec[:max_valid_positions] - for position in token_valid_positions: - if position == -1: - continue # Padding value - trim_length = position + len(token) - len(stop_string) - if trim_length > 0: - # This token runs off the start of the string - self.assertTrue(stop_string.startswith(token[trim_length:])) - else: - self.assertTrue(stop_string[-position - len(token) : -position] == token) - - token_end_overlaps = vec[max_valid_positions : max_valid_positions + max_valid_end_lens] - for overlap in token_end_overlaps: - if overlap == -1: - continue # Padding value - # Either this token runs off the end of the string, - # or the entire stop string is a substring of the token - self.assertTrue( - ( - stop_string.endswith(token[:overlap]) - or (stop_string in token and overlap == len(stop_string)) - ) - ) - - token_length = vec[-1] - self.assertTrue(len(token) == token_length) + # def test_stop_string_embedding_vecs(self): + # tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + # stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] + # criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + # all_embedding_vecs, *_ = _stop_string_create_embedding_vec( + # tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + # ) + # for stop_string in stop_strings: + # embedding_vecs = all_embedding_vecs[stop_string] + # max_valid_positions = criteria.max_valid_positions[stop_string] + # max_valid_end_lens = criteria.max_valid_end_lens[stop_string] + # for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): + # vec = embedding_vecs[token_idx].tolist() + # # The embedding contains packed valid positions, end overlap lengths, and the total token length + # token = token.replace("▁", " ").replace("Ġ", " ") + # + # token_valid_positions = vec[:max_valid_positions] + # for position in token_valid_positions: + # if position == -1: + # continue # Padding value + # trim_length = position + len(token) - len(stop_string) + # if trim_length > 0: + # # This token runs off the start of the string + # self.assertTrue(stop_string.startswith(token[trim_length:])) + # else: + # self.assertTrue(stop_string[-position - len(token) : -position] == token) + # + # token_end_overlaps = vec[max_valid_positions : max_valid_positions + max_valid_end_lens] + # for overlap in token_end_overlaps: + # if overlap == -1: + # continue # Padding value + # # Either this token runs off the end of the string, + # # or the entire stop string is a substring of the token + # self.assertTrue( + # ( + # stop_string.endswith(token[:overlap]) + # or (stop_string in token and overlap == len(stop_string)) + # ) + # ) + # + # token_length = vec[-1] + # self.assertTrue(len(token) == token_length) From 46c0a9c6df8cd21d4fc4842727a9a48ca529189a Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 19:03:53 +0000 Subject: [PATCH 44/68] Quick fix for tensor devices --- src/transformers/generation/stopping_criteria.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 40634fba0b3a74..d8dc01b5a4fd9d 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -164,6 +164,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: embedding_vec = self.embedding_vec.to(input_ids.device) + target_lens = self.target_lens.to(input_ids.device) # The maximum length we need to consider is 1 token per character. Note that input_ids can also be # *shorter* than the global max, and the code below should be ready for that input_ids = input_ids[:, -self.maximum_token_len :] @@ -213,7 +214,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # The string is matched if we reached a cumsum equal to or greater than the length of the string # before hitting the mask - string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] + string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= target_lens[None, :] # Now we concatenate the match booleans across all strings and check if any are True # TODO After Raushan's PR, return a per-sample vector here From b9a066d34d9d35f7f26c2a72a70bf1a2db6c04da Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 19:11:59 +0000 Subject: [PATCH 45/68] Update embeddings test for the new format --- tests/generation/test_stopping_criteria.py | 89 ++++++++++++---------- 1 file changed, 47 insertions(+), 42 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index e9b2c11a77593e..fdf792873c0d1c 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -36,6 +36,7 @@ ) from transformers.generation.stopping_criteria import ( + _stop_string_create_embedding_vec, _stop_string_get_matching_positions, ) @@ -210,45 +211,49 @@ def test_stop_string_matching_positions(self): ) ) - # def test_stop_string_embedding_vecs(self): - # tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - # stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] - # criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - # all_embedding_vecs, *_ = _stop_string_create_embedding_vec( - # tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings - # ) - # for stop_string in stop_strings: - # embedding_vecs = all_embedding_vecs[stop_string] - # max_valid_positions = criteria.max_valid_positions[stop_string] - # max_valid_end_lens = criteria.max_valid_end_lens[stop_string] - # for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): - # vec = embedding_vecs[token_idx].tolist() - # # The embedding contains packed valid positions, end overlap lengths, and the total token length - # token = token.replace("▁", " ").replace("Ġ", " ") - # - # token_valid_positions = vec[:max_valid_positions] - # for position in token_valid_positions: - # if position == -1: - # continue # Padding value - # trim_length = position + len(token) - len(stop_string) - # if trim_length > 0: - # # This token runs off the start of the string - # self.assertTrue(stop_string.startswith(token[trim_length:])) - # else: - # self.assertTrue(stop_string[-position - len(token) : -position] == token) - # - # token_end_overlaps = vec[max_valid_positions : max_valid_positions + max_valid_end_lens] - # for overlap in token_end_overlaps: - # if overlap == -1: - # continue # Padding value - # # Either this token runs off the end of the string, - # # or the entire stop string is a substring of the token - # self.assertTrue( - # ( - # stop_string.endswith(token[:overlap]) - # or (stop_string in token and overlap == len(stop_string)) - # ) - # ) - # - # token_length = vec[-1] - # self.assertTrue(len(token) == token_length) + def test_stop_string_embedding_vecs(self): + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] + criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + embedding_vec, max_valid_positions, max_valid_end_lens = _stop_string_create_embedding_vec( + tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + ) + valid_positions_vec = embedding_vec[:, : max_valid_positions * len(stop_strings)].unflatten( + -1, (len(stop_strings), -1) + ) + end_overlaps_vec = embedding_vec[:, max_valid_positions * len(stop_strings) : -1].unflatten( + -1, (len(stop_strings), -1) + ) + token_lengths = embedding_vec[:, -1] + + for i, stop_string in enumerate(stop_strings): + for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): + # The embedding contains packed valid positions, end overlap lengths, and the total token length + token = token.replace("▁", " ").replace("Ġ", " ") + + token_valid_positions = valid_positions_vec[token_idx, i].tolist() + for position in token_valid_positions: + if position == -1: + continue # Padding value + trim_length = position + len(token) - len(stop_string) + if trim_length > 0: + # This token runs off the start of the string + self.assertTrue(stop_string.startswith(token[trim_length:])) + else: + self.assertTrue(stop_string[-position - len(token) : -position] == token) + + token_end_overlaps = end_overlaps_vec[token_idx, i].tolist() + for overlap in token_end_overlaps: + if overlap == -1: + continue # Padding value + # Either this token runs off the end of the string, + # or the entire stop string is a substring of the token + self.assertTrue( + ( + stop_string.endswith(token[:overlap]) + or (stop_string in token and overlap == len(stop_string)) + ) + ) + + token_length = token_lengths[token_idx].item() + self.assertTrue(len(token) == token_length) From 692523c4fff3188370c366cd2ce4de80f41bf0db Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 23 Feb 2024 19:25:58 +0000 Subject: [PATCH 46/68] Fix test imports --- src/transformers/generation/stopping_criteria.py | 9 ++++----- tests/generation/test_stopping_criteria.py | 9 ++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index d8dc01b5a4fd9d..76bfc8677a2b5c 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -163,8 +163,8 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: - embedding_vec = self.embedding_vec.to(input_ids.device) - target_lens = self.target_lens.to(input_ids.device) + self.embedding_vec = self.embedding_vec.to(input_ids.device) + self.target_lens = self.target_lens.to(input_ids.device) # The maximum length we need to consider is 1 token per character. Note that input_ids can also be # *shorter* than the global max, and the code below should be ready for that input_ids = input_ids[:, -self.maximum_token_len :] @@ -176,7 +176,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa max_valid_positions = self.max_valid_positions # The embedding vec contains the valid positions, end_lengths and total lengths for each token - embedded = F.embedding(flipped_ids, embedding_vec) + embedded = F.embedding(flipped_ids, self.embedding_vec) # Now we split the embedding vector. valid_positions is the positions in the stop string the token can fit valid_positions = embedded[:, 1:, : max_valid_positions * self.num_stop_strings].unflatten( @@ -214,7 +214,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # The string is matched if we reached a cumsum equal to or greater than the length of the string # before hitting the mask - string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= target_lens[None, :] + string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] # Now we concatenate the match booleans across all strings and check if any are True # TODO After Raushan's PR, return a per-sample vector here @@ -320,7 +320,6 @@ def _stop_string_create_embedding_vec(tok_list, tok_indices, stop_strings) -> Di and possible end lengths for each token, and the total length of the token string. When tokens have fewer valid positions or end lengths than the maximum, we pad the vectors with -1. """ - # TODO Matt: Merge the embeddings across all stop strings to save space and reduce gather calls? token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( tok_list, tok_indices, stop_strings ) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index fdf792873c0d1c..bafbecb32979a5 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -34,11 +34,10 @@ StopStringCriteria, validate_stopping_criteria, ) - -from transformers.generation.stopping_criteria import ( - _stop_string_create_embedding_vec, - _stop_string_get_matching_positions, -) + from transformers.generation.stopping_criteria import ( + _stop_string_create_embedding_vec, + _stop_string_get_matching_positions, + ) @require_torch From 2ba7f8eda8933a246461af412c5e05044a084b92 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 26 Feb 2024 14:12:29 +0000 Subject: [PATCH 47/68] Manual patching for BERT-like tokenizers --- src/transformers/generation/stopping_criteria.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 76bfc8677a2b5c..96b2d08d14072b 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -281,6 +281,8 @@ def _stop_string_get_matching_positions( def _cleanup_token(token: str) -> str: if token[0] in ["▁", "Ġ"]: token = " " + token[1:] + elif token[0] == "##": + token = token[2:] return token reversed_filtered_tok_list = [_cleanup_token(token)[::-1] for token in tok_list] From 8b95ec152bd0f435478b41166f116d502c08ccda Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 26 Feb 2024 16:36:27 +0000 Subject: [PATCH 48/68] Return a bool vector instead of a single True/False --- src/transformers/generation/stopping_criteria.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 96b2d08d14072b..11f9b6a020ac78 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -162,7 +162,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: self.embedding_vec = self.embedding_vec.to(input_ids.device) self.target_lens = self.target_lens.to(input_ids.device) # The maximum length we need to consider is 1 token per character. Note that input_ids can also be @@ -217,8 +217,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] # Now we concatenate the match booleans across all strings and check if any are True - # TODO After Raushan's PR, return a per-sample vector here - return torch.any(string_matches).item() + return torch.any(string_matches, dim=-1) class EosTokenCriteria(StoppingCriteria): From 1b46b208859d8a9b34b027583b42168d2025e666 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 26 Feb 2024 16:37:23 +0000 Subject: [PATCH 49/68] Better comment --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 11f9b6a020ac78..0b4577cecd605c 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -216,7 +216,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # before hitting the mask string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] - # Now we concatenate the match booleans across all strings and check if any are True + # We return a per-sample vector that is True is any stop string is matched for that sample return torch.any(string_matches, dim=-1) From 350a850e50535f91bfb6bed7449ef31098f5ab68 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 26 Feb 2024 16:37:36 +0000 Subject: [PATCH 50/68] Better comment --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 0b4577cecd605c..2612d37cb33d8d 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -216,7 +216,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa # before hitting the mask string_matches = torch.amax(cumsum * mask, dim=(1, -1)) >= self.target_lens[None, :] - # We return a per-sample vector that is True is any stop string is matched for that sample + # We return a per-sample vector that is True if any stop string is matched for that sample return torch.any(string_matches, dim=-1) From 0b85c6c6159f844e2f1dc4fa6f88a03de1c02d03 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 27 Feb 2024 17:08:09 +0000 Subject: [PATCH 51/68] Add tests from @zucchini-nlp --- tests/generation/test_stopping_criteria.py | 49 ++++++++++++++++++++++ tests/generation/test_utils.py | 37 ++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index bafbecb32979a5..97172c3fd1e452 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -256,3 +256,52 @@ def test_stop_string_embedding_vecs(self): token_length = token_lengths[token_idx].item() self.assertTrue(len(token) == token_length) + + def test_criterias_per_row(self): + text = "They completed the challenging puzzle, revealing the hidden image at the end" + stop_strings = ["end"] + + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False) + + scores = None + criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=20), + StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings), + ] + ) + + # trigger stopping when at leat one criteria is satisfied, one value per batch + self.assertTrue(criteria(inputs["input_ids"], scores)) + + # return False when neither is satisfied + self.assertFalse(criteria(inputs["input_ids"][:, :-1], scores)) + + def test_criterias_per_row_batched(self): + text = [ + "They completed the challenging puzzle, revealing the hidden image at the end", + "Today a dragon flew over France", + "The aroma of freshly baked pizza filled the kitchen", + ] + stop_strings = ["end"] + + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.pad_token_id = tokenizer.eos_token_id + tokenizer.padding_side = "left" + inputs = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False) + + scores = None + criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=20), + StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings), + ] + ) + + # trigger stopping when at leat one criteria is satisfied + self.assertListEqual(criteria(inputs["input_ids"], scores).tolist(), [True, False, False]) + + # False when neither is satisfied + self.assertListEqual(criteria(inputs["input_ids"][:, :-1], scores).tolist(), [False, False, False]) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d6b4840c4910c9..37b9d3dba5ec0b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2355,6 +2355,43 @@ def test_constrained_beam_search_example_integration(self): self.assertListEqual(outputs, ["Wie alt sind Sie?"]) + @slow + def test_per_row_stopping_criteria(self): + text = [ + "They completed the challenging puzzle, revealing the hidden", + "Today a dragon flew over France", + "The aroma of freshly baked pizza filled the kitchen", + ] + stop_strings = ["secrets"] + + model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + tokenizer.padding_side = "left" + tokenizer.pad_token_id = tokenizer.eos_token_id + input_ids = tokenizer(text, return_tensors="pt", padding="longest", add_special_tokens=False).input_ids.to( + torch_device + ) + + # normal generation with one stopping criteria + out = model.generate(input_ids, max_length=15) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets of the world.\n", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) + + # generation should stop at "secrets" for first batch only, filling the rest with eos tokens + out = model.generate(input_ids, max_length=15, stop_strings=stop_strings, tokenizer=tokenizer) + out_text = tokenizer.batch_decode(out) + expected_out = [ + "They completed the challenging puzzle, revealing the hidden secrets<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", + "<|endoftext|><|endoftext|><|endoftext|>Today a dragon flew over France and the French government was forced", + "The aroma of freshly baked pizza filled the kitchen with a sense of freshness", + ] + self.assertListEqual(out_text, expected_out) + def test_constrained_beam_search_mixin_type_checks(self): # PT-only test: TF doesn't have constrained beam search tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") From e8c769d2b4aaa26a2c5b784c1bfd230aeb346047 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 12 Mar 2024 17:28:44 +0000 Subject: [PATCH 52/68] Amy's list creation nit --- src/transformers/generation/stopping_criteria.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 2612d37cb33d8d..295256af5c65c8 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -325,8 +325,8 @@ def _stop_string_create_embedding_vec(tok_list, tok_indices, stop_strings) -> Di tok_list, tok_indices, stop_strings ) - max_valid_positions = max([len(val) for positions in token_valid_positions.values() for val in positions.values()]) - max_valid_end_lens = max([len(val) for positions in token_end_overlaps.values() for val in positions.values()]) + max_valid_positions = max(len(val) for positions in token_valid_positions.values() for val in positions.values()) + max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 gather_vec = np.full((len(tok_list), vec_size), dtype=np.int32, fill_value=-1) From 14de1c3c00bd6b79a5566c6794a400c8c432f2e1 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 12 Mar 2024 17:32:10 +0000 Subject: [PATCH 53/68] tok_list -> token_list --- .../generation/stopping_criteria.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 295256af5c65c8..cf162456b3f215 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -151,11 +151,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, stop_strings = [stop_strings] self.vocab = tokenizer.get_vocab() - self.tok_list, self.tok_indices = tuple(self.vocab.keys()), tuple(self.vocab.values()) + self.token_list, self.tok_indices = tuple(self.vocab.keys()), tuple(self.vocab.values()) self.stop_strings: Tuple[str, ...] = tuple(stop_strings) self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vec( - self.tok_list, self.tok_indices, self.stop_strings + self.token_list, self.tok_indices, self.stop_strings ) self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) self.num_stop_strings = len(self.stop_strings) @@ -270,7 +270,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng def _stop_string_get_matching_positions( - tok_list, tok_indices, stop_strings + token_list, tok_indices, stop_strings ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of @@ -284,14 +284,14 @@ def _cleanup_token(token: str) -> str: token = token[2:] return token - reversed_filtered_tok_list = [_cleanup_token(token)[::-1] for token in tok_list] + reversed_filtered_token_list = [_cleanup_token(token)[::-1] for token in token_list] token_valid_positions = {} token_end_overlaps = {} for stop_string in stop_strings: reversed_stop_string = stop_string[::-1] token_valid_positions[stop_string] = {} token_end_overlaps[stop_string] = {} - for token, reversed_filtered_token, tok_idx in zip(tok_list, reversed_filtered_tok_list, tok_indices): + for token, reversed_filtered_token, tok_idx in zip(token_list, reversed_filtered_token_list, tok_indices): matching_positions = [] possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): @@ -315,20 +315,20 @@ def _cleanup_token(token: str) -> str: @lru_cache(8) -def _stop_string_create_embedding_vec(tok_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: +def _stop_string_create_embedding_vec(token_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: """ This function builds an embedding matrix for each stop string, consisting of possible valid positions and possible end lengths for each token, and the total length of the token string. When tokens have fewer valid positions or end lengths than the maximum, we pad the vectors with -1. """ token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( - tok_list, tok_indices, stop_strings + token_list, tok_indices, stop_strings ) max_valid_positions = max(len(val) for positions in token_valid_positions.values() for val in positions.values()) max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 - gather_vec = np.full((len(tok_list), vec_size), dtype=np.int32, fill_value=-1) + gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) for i, stop_string in enumerate(stop_strings): positions = token_valid_positions[stop_string] @@ -348,7 +348,7 @@ def _stop_string_create_embedding_vec(tok_list, tok_indices, stop_strings) -> Di + max_valid_end_lens * i + len(possible_end_lens), ] = possible_end_lens - for token, token_idx in zip(tok_list, tok_indices): + for token, token_idx in zip(token_list, tok_indices): gather_vec[token_idx, -1] = len(token) gather_vec = torch.tensor(gather_vec, dtype=torch.int32) From b8961e8dd3e2ddbf558f922c02889eda71771e4f Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 13 Mar 2024 15:35:48 +0000 Subject: [PATCH 54/68] Push a big expanded docstring (should we put it somewhere else?) --- .../generation/stopping_criteria.py | 108 +++++++++++++++++- 1 file changed, 102 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index cf162456b3f215..5065f6e41632af 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -135,9 +135,110 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa class StopStringCriteria(StoppingCriteria): """ - This class can be used to stop generation whenever specific string sequences are encountered. It preprocesses + This class can be used to stop generation whenever specific string sequences are generated. It preprocesses the strings together with the tokenizer vocab to find positions where tokens can validly complete the stop strings. + Generation is stopped as soon as a token is generated that completes any of the stop strings. + We want to catch any instance in which the stop string would be present in the decoded output, which means + we must also catch cases with "overhangs" off one or both ends. To make this more concrete, for the stop string + "stop", any of the following token sequences would trigger the match: + + - ["st", "op"] + - ["stop"] + - ["st", "opera"] + - ["sto", "opper"] + - ["las", "topper"] + + Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other + words, these sequences will not trigger a match: + + - ["stop", "at"] + - ["st", "op", "at"] + - ["st", "opera", "tion"] + + The reason these are not a match is that the stop string does not overlap with the final token. If you can remove + one or more tokens from the end of the sequence without destroying the stop string, then this criterion will not + match that stop string. This is by design; because this check is run after each token is generated, we can't miss a + valid stop string if one is generated, but we don't want to halt generation just because the stop string exists + somewhere in the past input_ids. + + How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match + process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible, + with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use + with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations + at generation time. + + The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at + the end of the sequence and work backwards. Specifically, we check that there is an overlap between the *start* of + the final token and the *end* of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for + some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this + property: + + - ["st", "op"] (overlap is "op") + - ["stop"] (overlap is "stop") + - ["st", "opera"] (overlap is "op") + - ["sto", "pper"] (overlap is "p") + - ["las", "topper"] (overlap is "top") + + It's impossible to construct a matching sequence that does not have this property (feel free to verify this + yourself). However, although this overlap between the start of the final token and the end of the stop string is + necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is + consistent with the stop string. + + How do we do that? Let's say the stop string is N characters long, and the initial overlap covers the final + M characters. Then, we have N - M characters left to match. If the next token is less than M - N tokens long, then + the entire token must match: token == stop_string[-(M + len(token)): -M]. If the next token is longer than M - N + tokens, then we consider only the final M - N characters of the token. This allows for the token to have an overhang + off the start of the stop string. + + Again, let's make this concrete with a worked example. We'll use the stop string "stop" and the token sequence + ["las", "topper"]. The length of the stop string is 4. The final token is "topper", and its overlap with the stop + string is "top", which has length 3. We continue to the next token, "las", and we have 4 - 3 = 1 character left to + match. We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of + the stop string, and we are done. + + At this point, hopefully you agree that we have an algorithm that detects the presence of a stop string, but you + may not see how we can convert this to tensor operations, particularly since we want to avoid data-dependent + conditional branching in the compiled code, and ideally vectorize everything so it can be efficiently computed on + GPU. The key is to realize that although we don't have access to string operations inside the generation loop, + we can use them in a precomputation stage! + + For every token in the tokenizer vocabulary, we precompute the values + we need for the above algorithm: The length of that token's overlap with the end of the stop string, the + position(s) in the stop string where that token matches perfectly, and the length of the token. We then pack + these values into a single vector per token, and stack those vectors into an embedding tensor which we can + gather from at runtime to get the values we need. + + This is the approach we take in this class. The precomputation is done in the `_stop_string_create_embedding_vec` + function. Then, at runtime in the `__call__()` method, we implement the algorithm above in purely vectorized + fashion, starting from an input_ids vector containing the token IDs in the sequence: + + - Gather from the embedding vector using input_ids as indices, and split the packed vectors into end overlap lengths, + valid token positions, and token lengths. + - Make a vector of the length of every token in the sequence, except for the final token, where we use the + end-overlap length instead. + - Compute the cumulative sum of the sequence, starting from the end. This represents the number of characters in the stop string that + we would match after each token, assuming that token is a valid fit for the sequence at that point. + - To determine if the tokens are valid at each position, we check that the cumulative length so far matches + one of the values in their valid positions vector. Where it does not, we mask that token and all tokens + following it. + - We then check the highest unmasked value in the cumulative sum. This represents the length of the total string + match before we reached a token that did not match the stop string. If it is equal to or greater than the length + of the stop string, the stop string is matched. + + This is almost the complete algorithm, and the remaining details just handle edge cases: For example, what do + we do if a token can have multiple possible overlap lengths with the stop string? For example, consider the + stop string "banana", and the token sequences ["ba", "nana"] and ["bana", "nana"]. Both of these sequences + contain the stop string and should trigger a match. However, the overlap of the final token is different! In + the first case, the overlap is "nana". In the second case, the overlap is "na". When we start from the end + of the sequence and work backwards, we cannot know in advance which overlap length, if any, will lead to a valid + match, and therefore we must test all possible overlap lengths. + + Therefore, for the stop string "banana" and the token "nana", we store two valid end overlap lengths: 2 and 4. + We then perform the above algorithm, starting from each value, and test whether each results in a match. + Thanks to vectorization, we can run these tests in parallel (in fact, we can run the test for every possible + overlap length and all stop strings in parallel). + Args: tokenizer (`PreTrainedTokenizer`): The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) @@ -316,11 +417,6 @@ def _cleanup_token(token: str) -> str: @lru_cache(8) def _stop_string_create_embedding_vec(token_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: - """ - This function builds an embedding matrix for each stop string, consisting of possible valid positions - and possible end lengths for each token, and the total length of the token string. When tokens have - fewer valid positions or end lengths than the maximum, we pad the vectors with -1. - """ token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( token_list, tok_indices, stop_strings ) From 7ed55ad2301b8e1710d5396fd663b3f7cc3fa5f2 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 13 Mar 2024 16:02:57 +0000 Subject: [PATCH 55/68] Expand docstrings --- .../generation/stopping_criteria.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 5065f6e41632af..1141c8be2b907a 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -186,15 +186,16 @@ class StopStringCriteria(StoppingCriteria): consistent with the stop string. How do we do that? Let's say the stop string is N characters long, and the initial overlap covers the final - M characters. Then, we have N - M characters left to match. If the next token is less than M - N tokens long, then - the entire token must match: token == stop_string[-(M + len(token)): -M]. If the next token is longer than M - N - tokens, then we consider only the final M - N characters of the token. This allows for the token to have an overhang + M characters. Then, we have N - M characters left to match. If the next token is less than N - M tokens long, then + the entire token must match: token == stop_string[-M - len(token): -M]. If the next token is longer than N - M + tokens, then we consider only the final N - M characters of the token. This allows for the token to have an overhang off the start of the stop string. Again, let's make this concrete with a worked example. We'll use the stop string "stop" and the token sequence ["las", "topper"]. The length of the stop string is 4. The final token is "topper", and its overlap with the stop string is "top", which has length 3. We continue to the next token, "las", and we have 4 - 3 = 1 character left to - match. We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of + match. This is less than the length of "las", so we only need a partial match for this token to complete the string. + We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of the stop string, and we are done. At this point, hopefully you agree that we have an algorithm that detects the presence of a stop string, but you @@ -221,7 +222,7 @@ class StopStringCriteria(StoppingCriteria): we would match after each token, assuming that token is a valid fit for the sequence at that point. - To determine if the tokens are valid at each position, we check that the cumulative length so far matches one of the values in their valid positions vector. Where it does not, we mask that token and all tokens - following it. + preceding it. - We then check the highest unmasked value in the cumulative sum. This represents the length of the total string match before we reached a token that did not match the stop string. If it is equal to or greater than the length of the stop string, the stop string is matched. @@ -374,9 +375,12 @@ def _stop_string_get_matching_positions( token_list, tok_indices, stop_strings ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can - validly appear in the stop strings. For each stop string, it returns a dictionary mapping tokens to a list of - valid positions, as well as a dictionary mapping tokens to a list of possible overlap lengths at the - end of the stop string.""" + validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the + token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters + from the end of the stop string that overlap with the start of the token, which can have more than one value. + + The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full + explanation of what these values are for!""" def _cleanup_token(token: str) -> str: if token[0] in ["▁", "Ġ"]: @@ -417,6 +421,9 @@ def _cleanup_token(token: str) -> str: @lru_cache(8) def _stop_string_create_embedding_vec(token_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: + """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs + them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values + that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( token_list, tok_indices, stop_strings ) From cbb9d1478021557c4d094fd052ed59e5fd79c285 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 13 Mar 2024 16:17:47 +0000 Subject: [PATCH 56/68] Docstring fixups --- .../generation/stopping_criteria.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 1141c8be2b907a..47578ebae1d71b 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -165,20 +165,19 @@ class StopStringCriteria(StoppingCriteria): How is the match actually performed, though? We do it in quite a confusing way, because we want the entire match process to be compilable with Torch or XLA, which means we cannot use standard string methods. However, it is possible, with some work, to do string matching with pure tensor operations. We'll begin by describing the algorithm we use - with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations - at generation time. + with standard string operations, and then at the end we'll explain how this is converted to pure tensor operations. The key to the algorithm is an observation: Because the stop string must overlap with the end of the token sequence, we can start at - the end of the sequence and work backwards. Specifically, we check that there is an overlap between the *start* of - the final token and the *end* of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for + the end of the sequence and work backwards. Specifically, we check that there is an overlap between the start of + the final token and the end of the stop_string, or to put it another way, stop_string[-i:] == token[:i] for some i > 0. If you look at the positive examples above, you'll see the last token in all of them fulfills this property: - - ["st", "op"] (overlap is "op") - - ["stop"] (overlap is "stop") - - ["st", "opera"] (overlap is "op") - - ["sto", "pper"] (overlap is "p") - - ["las", "topper"] (overlap is "top") + - ["st", "op"] (overlap is "op", overlap length == 2) + - ["stop"] (overlap is "stop", overlap length == 4) + - ["st", "opera"] (overlap is "op", overlap length == 2) + - ["sto", "pper"] (overlap is "p", overlap length == 1) + - ["las", "topper"] (overlap is "top", overlap length == 3) It's impossible to construct a matching sequence that does not have this property (feel free to verify this yourself). However, although this overlap between the start of the final token and the end of the stop string is @@ -240,6 +239,15 @@ class StopStringCriteria(StoppingCriteria): Thanks to vectorization, we can run these tests in parallel (in fact, we can run the test for every possible overlap length and all stop strings in parallel). + The second detail is how we handle cases when the token sequence has an overhang off the start of the stop string, + as in the case of ["las", "top"], since we do not store "start overlaps" in the same way we do for end overlaps. + Instead, we simply store (in the valid_positions vector) that the token "las" is valid before "top", in the same + way that the token "s" is. Therefore, the total length it computes in the case of ["las", "top"] is 6 rather than 4, + because it doesn't truncate the match to the length of the stop string. However, since the algorithm concludes by + checking that the maximum match length is equal to or greater than the length of the stop string, this does not + affect the correctness of its final answer; both ["las", "top"] with a total length of 6, and ["s", "top"] with a + total length of 4, will be correctly identified as matches, because both are >= 4. + Args: tokenizer (`PreTrainedTokenizer`): The model's associated tokenizer (necessary to extract vocab and tokenize the termination sequences) From 7db95c1aa77b5d71745984140b976def4f6d2824 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 14:38:21 +0000 Subject: [PATCH 57/68] Rebase --- src/transformers/generation/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 6bb6f4620e078c..8859185477f1d5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -88,6 +88,7 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel from .streamers import BaseStreamer + from ..tokenization_utils_base import PreTrainedTokenizerBase logger = logging.get_logger(__name__) @@ -882,7 +883,7 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], **kwargs + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], tokenizer: Optional["PreTrainedTokenizerBase"] = None, **kwargs ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -1396,6 +1397,7 @@ def generate( synced_gpus = True else: synced_gpus = False + tokenizer = kwargs.pop("tokenizer", None) logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -1538,7 +1540,7 @@ def generate( # 9. prepare stopping criteria prepared_stopping_criteria = self._get_stopping_criteria( - generation_config=generation_config, stopping_criteria=stopping_criteria, **kwargs + generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs ) # 10. go into different generation modes if generation_mode == GenerationMode.ASSISTED_GENERATION: From 49b0f21e3a745325329a2beba5cf704d6247ed05 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 14:44:00 +0000 Subject: [PATCH 58/68] make fixup --- src/transformers/generation/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8859185477f1d5..32e0bca49d97c5 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -87,8 +87,8 @@ if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel - from .streamers import BaseStreamer from ..tokenization_utils_base import PreTrainedTokenizerBase + from .streamers import BaseStreamer logger = logging.get_logger(__name__) @@ -883,7 +883,11 @@ def _get_logits_processor( return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList], tokenizer: Optional["PreTrainedTokenizerBase"] = None, **kwargs + self, + generation_config: GenerationConfig, + stopping_criteria: Optional[StoppingCriteriaList], + tokenizer: Optional["PreTrainedTokenizerBase"] = None, + **kwargs, ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: From c9aefe6403d8ef66ec6780a2d11568098c62ae3c Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 17:19:34 +0000 Subject: [PATCH 59/68] Make a properly general method for figuring out token strings --- .../generation/stopping_criteria.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 47578ebae1d71b..7016f190427589 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -260,8 +260,7 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, if isinstance(stop_strings, str): stop_strings = [stop_strings] - self.vocab = tokenizer.get_vocab() - self.token_list, self.tok_indices = tuple(self.vocab.keys()), tuple(self.vocab.values()) + self.token_list, self.tok_indices = self.clean_tokenizer_vocab(tokenizer) self.stop_strings: Tuple[str, ...] = tuple(stop_strings) self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vec( @@ -271,6 +270,21 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, self.num_stop_strings = len(self.stop_strings) self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) + @staticmethod + def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): + # TODO Matt this work is done every time the criterion is initialized - can we cache it? + vocab = tokenizer.get_vocab() + clean_token_list = [] + clean_token_indices = [] + sentence_base = tokenizer(static_prefix, add_special_tokens=False)["input_ids"] + tokens_base = [tokenizer._convert_id_to_token(tok) for tok in sentence_base] + for token, token_idx in vocab.items(): + token_string = tokenizer.convert_tokens_to_string(tokens_base + [token]) + token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :] + clean_token_list.append(token_string) + clean_token_indices.append(token_idx) + return clean_token_list, clean_token_indices + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: self.embedding_vec = self.embedding_vec.to(input_ids.device) @@ -390,29 +404,22 @@ def _stop_string_get_matching_positions( The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full explanation of what these values are for!""" - def _cleanup_token(token: str) -> str: - if token[0] in ["▁", "Ġ"]: - token = " " + token[1:] - elif token[0] == "##": - token = token[2:] - return token - - reversed_filtered_token_list = [_cleanup_token(token)[::-1] for token in token_list] token_valid_positions = {} token_end_overlaps = {} for stop_string in stop_strings: reversed_stop_string = stop_string[::-1] token_valid_positions[stop_string] = {} token_end_overlaps[stop_string] = {} - for token, reversed_filtered_token, tok_idx in zip(token_list, reversed_filtered_token_list, tok_indices): + for token, tok_idx in zip(token_list, tok_indices): + reversed_token = token[::-1] matching_positions = [] possible_end_lengths = [] for i in range(1 - len(token), len(stop_string)): if i < 0: - tok = reversed_filtered_token[-i:] + tok = reversed_token[-i:] i = 0 else: - tok = reversed_filtered_token + tok = reversed_token stop = reversed_stop_string[i : i + len(tok)] if tok.startswith(stop): if i == 0: From 443cd5d615fb7e6ca80d78ad00f3ea9bbc235373 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 17:41:07 +0000 Subject: [PATCH 60/68] Fix naming throughout the functions --- .../generation/stopping_criteria.py | 23 ++++++++++++------- tests/generation/test_stopping_criteria.py | 6 ++--- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 7016f190427589..86730df40099b9 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -260,11 +260,11 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, if isinstance(stop_strings, str): stop_strings = [stop_strings] - self.token_list, self.tok_indices = self.clean_tokenizer_vocab(tokenizer) + self.token_list, self.token_indices = self.clean_tokenizer_vocab(tokenizer) self.stop_strings: Tuple[str, ...] = tuple(stop_strings) self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vec( - self.token_list, self.tok_indices, self.stop_strings + self.token_list, self.token_indices, self.stop_strings ) self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) self.num_stop_strings = len(self.stop_strings) @@ -272,6 +272,13 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, @staticmethod def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): + """ + This method turns a tokenizer vocab into a "clean" vocab where each token represents the actual string + it will yield, without any special prefixes like "##" or "Ġ". This is trickier than it looks - the method + tokenizer.convert_tokens_to_string() does not always return the correct string because of issues with prefix + space addition/removal. To work around this, we add a static prefix to the start of the token, then remove + it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). + """ # TODO Matt this work is done every time the criterion is initialized - can we cache it? vocab = tokenizer.get_vocab() clean_token_list = [] @@ -283,7 +290,7 @@ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): token_string = token_string[token_string.index(static_prefix) + len(static_prefix) :] clean_token_list.append(token_string) clean_token_indices.append(token_idx) - return clean_token_list, clean_token_indices + return tuple(clean_token_list), tuple(clean_token_indices) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: @@ -394,7 +401,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng def _stop_string_get_matching_positions( - token_list, tok_indices, stop_strings + token_list, token_indices, stop_strings ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the @@ -410,7 +417,7 @@ def _stop_string_get_matching_positions( reversed_stop_string = stop_string[::-1] token_valid_positions[stop_string] = {} token_end_overlaps[stop_string] = {} - for token, tok_idx in zip(token_list, tok_indices): + for token, tok_idx in zip(token_list, token_indices): reversed_token = token[::-1] matching_positions = [] possible_end_lengths = [] @@ -435,12 +442,12 @@ def _stop_string_get_matching_positions( @lru_cache(8) -def _stop_string_create_embedding_vec(token_list, tok_indices, stop_strings) -> Dict[str, torch.tensor]: +def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]: """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( - token_list, tok_indices, stop_strings + token_list, token_indices, stop_strings ) max_valid_positions = max(len(val) for positions in token_valid_positions.values() for val in positions.values()) @@ -466,7 +473,7 @@ def _stop_string_create_embedding_vec(token_list, tok_indices, stop_strings) -> + max_valid_end_lens * i + len(possible_end_lens), ] = possible_end_lens - for token, token_idx in zip(token_list, tok_indices): + for token, token_idx in zip(token_list, token_indices): gather_vec[token_idx, -1] = len(token) gather_vec = torch.tensor(gather_vec, dtype=torch.int32) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 97172c3fd1e452..3f6e2b93f26200 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -184,7 +184,7 @@ def test_stop_string_matching_positions(self): criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) idx_to_token = {v: k for k, v in tokenizer.get_vocab().items()} all_token_valid_positions, all_token_end_overlaps = _stop_string_get_matching_positions( - tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + token_list=criteria.token_list, token_indices=criteria.token_indices, stop_strings=criteria.stop_strings ) for stop_string in stop_strings: token_valid_positions = all_token_valid_positions[stop_string] @@ -215,7 +215,7 @@ def test_stop_string_embedding_vecs(self): stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) embedding_vec, max_valid_positions, max_valid_end_lens = _stop_string_create_embedding_vec( - tok_list=criteria.tok_list, tok_indices=criteria.tok_indices, stop_strings=criteria.stop_strings + token_list=criteria.token_list, token_indices=criteria.token_indices, stop_strings=criteria.stop_strings ) valid_positions_vec = embedding_vec[:, : max_valid_positions * len(stop_strings)].unflatten( -1, (len(stop_strings), -1) @@ -226,7 +226,7 @@ def test_stop_string_embedding_vecs(self): token_lengths = embedding_vec[:, -1] for i, stop_string in enumerate(stop_strings): - for token, token_idx in zip(criteria.tok_list, criteria.tok_indices): + for token, token_idx in zip(criteria.token_list, criteria.token_indices): # The embedding contains packed valid positions, end overlap lengths, and the total token length token = token.replace("▁", " ").replace("Ġ", " ") From e1c9c0e02cb8047d3ff2f12eaf222dc7d6b73228 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 18:46:19 +0000 Subject: [PATCH 61/68] Move cache, refactor, fix tests --- .../generation/stopping_criteria.py | 199 ++++++++++-------- tests/generation/test_stopping_criteria.py | 16 +- 2 files changed, 119 insertions(+), 96 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 86730df40099b9..a1cdaa265a496d 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -1,8 +1,8 @@ import time import warnings from abc import ABC +from collections import OrderedDict from copy import deepcopy -from functools import lru_cache from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -14,6 +14,9 @@ logger = logging.get_logger(__name__) +# We maintain a module-level cache of the embedding vectors for the stop string criterion +# because they are slow to compute +STOP_STRING_EMBEDDING_CACHE = OrderedDict() STOPPING_CRITERIA_INPUTS_DOCSTRING = r""" @@ -259,17 +262,38 @@ class StopStringCriteria(StoppingCriteria): def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): if isinstance(stop_strings, str): stop_strings = [stop_strings] - - self.token_list, self.token_indices = self.clean_tokenizer_vocab(tokenizer) self.stop_strings: Tuple[str, ...] = tuple(stop_strings) - - self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = _stop_string_create_embedding_vec( - self.token_list, self.token_indices, self.stop_strings + vocab = tokenizer.get_vocab() + token_list, token_indices = tuple(vocab.keys()), tuple(vocab.values()) + self.embedding_vec, self.max_valid_positions, self.max_valid_end_lens = self.clean_and_embed_tokens_with_cache( + token_list, token_indices, self.stop_strings, tokenizer ) + self.maximum_token_len = max([len(stop_string) for stop_string in self.stop_strings]) self.num_stop_strings = len(self.stop_strings) self.target_lens = torch.tensor([len(stop_string) for stop_string in stop_strings], dtype=torch.int32) + def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_strings, tokenizer): + # We don't use the tokenizer in the cache key, because I don't trust it to have well-behaved equality + if (token_list, token_indices, stop_strings) in STOP_STRING_EMBEDDING_CACHE: + embedding_vec, max_valid_positions, max_valid_end_lens = STOP_STRING_EMBEDDING_CACHE[ + (token_list, token_indices, self.stop_strings) + ] + STOP_STRING_EMBEDDING_CACHE.move_to_end((token_list, token_indices, stop_strings)) + else: + clean_token_list, clean_token_indices = self.clean_tokenizer_vocab(tokenizer) + embedding_vec, max_valid_positions, max_valid_end_lens = self._stop_string_create_embedding_vec( + clean_token_list, clean_token_indices, stop_strings + ) + STOP_STRING_EMBEDDING_CACHE[(token_list, token_indices, stop_strings)] = ( + embedding_vec, + max_valid_positions, + max_valid_end_lens, + ) + if len(STOP_STRING_EMBEDDING_CACHE) > 8: + STOP_STRING_EMBEDDING_CACHE.popitem(last=False) + return embedding_vec, max_valid_positions, max_valid_end_lens + @staticmethod def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): """ @@ -292,6 +316,88 @@ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): clean_token_indices.append(token_idx) return tuple(clean_token_list), tuple(clean_token_indices) + @staticmethod + def _stop_string_get_matching_positions( + token_list, token_indices, stop_strings + ) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: + """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can + validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the + token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters + from the end of the stop string that overlap with the start of the token, which can have more than one value. + + The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full + explanation of what these values are for!""" + + token_valid_positions = {} + token_end_overlaps = {} + for stop_string in stop_strings: + reversed_stop_string = stop_string[::-1] + token_valid_positions[stop_string] = {} + token_end_overlaps[stop_string] = {} + for token, tok_idx in zip(token_list, token_indices): + reversed_token = token[::-1] + matching_positions = [] + possible_end_lengths = [] + for i in range(1 - len(token), len(stop_string)): + if i < 0: + tok = reversed_token[-i:] + i = 0 + else: + tok = reversed_token + stop = reversed_stop_string[i : i + len(tok)] + if tok.startswith(stop): + if i == 0: + possible_end_lengths.append(min(len(tok), len(stop))) + else: + matching_positions.append(i) + + if matching_positions: + token_valid_positions[stop_string][tok_idx] = matching_positions + if possible_end_lengths: + token_end_overlaps[stop_string][tok_idx] = possible_end_lengths + return token_valid_positions, token_end_overlaps + + @staticmethod + def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]: + """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs + them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values + that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" + token_valid_positions, token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions( + token_list, token_indices, stop_strings + ) + + max_valid_positions = max( + len(val) for positions in token_valid_positions.values() for val in positions.values() + ) + max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) + vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 + gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) + + for i, stop_string in enumerate(stop_strings): + positions = token_valid_positions[stop_string] + end_lens = token_end_overlaps[stop_string] + + # Since this is lots of very small assignments of lists, we build it with numpy rather + # than torch for speed + simplicity, then convert to torch at the end + for token_idx, valid_positions in positions.items(): + gather_vec[ + token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions) + ] = valid_positions + for token_idx, possible_end_lens in end_lens.items(): + gather_vec[ + token_idx, + max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions + * len(stop_strings) + + max_valid_end_lens * i + + len(possible_end_lens), + ] = possible_end_lens + for token, token_idx in zip(token_list, token_indices): + gather_vec[token_idx, -1] = len(token) + + gather_vec = torch.tensor(gather_vec, dtype=torch.int32) + + return gather_vec, max_valid_positions, max_valid_end_lens + @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.Tensor: self.embedding_vec = self.embedding_vec.to(input_ids.device) @@ -398,84 +504,3 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng elif stopping_max_length is None: new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) return new_stopping_criteria - - -def _stop_string_get_matching_positions( - token_list, token_indices, stop_strings -) -> Tuple[Dict[str, Dict[str, List[int]]], Dict[str, Dict[str, List[int]]]]: - """This function preprocesses stop strings and the tokenizer vocabulary to determine where tokens can - validly appear in the stop strings. For each token, it computes a list of positions in the stop string where the - token appears, as well as a list of the possible "end overlaps" for that token - that is, the number of characters - from the end of the stop string that overlap with the start of the token, which can have more than one value. - - The reason for computing these may seem a bit cryptic - please see the docstring for StopStringCriteria for a full - explanation of what these values are for!""" - - token_valid_positions = {} - token_end_overlaps = {} - for stop_string in stop_strings: - reversed_stop_string = stop_string[::-1] - token_valid_positions[stop_string] = {} - token_end_overlaps[stop_string] = {} - for token, tok_idx in zip(token_list, token_indices): - reversed_token = token[::-1] - matching_positions = [] - possible_end_lengths = [] - for i in range(1 - len(token), len(stop_string)): - if i < 0: - tok = reversed_token[-i:] - i = 0 - else: - tok = reversed_token - stop = reversed_stop_string[i : i + len(tok)] - if tok.startswith(stop): - if i == 0: - possible_end_lengths.append(min(len(tok), len(stop))) - else: - matching_positions.append(i) - - if matching_positions: - token_valid_positions[stop_string][tok_idx] = matching_positions - if possible_end_lengths: - token_end_overlaps[stop_string][tok_idx] = possible_end_lengths - return token_valid_positions, token_end_overlaps - - -@lru_cache(8) -def _stop_string_create_embedding_vec(token_list, token_indices, stop_strings) -> Dict[str, torch.tensor]: - """This function precomputes everything needed for the run-time checks in StopStringCriteria, and packs - them into an embedding tensor that can be accessed with pure tensor operations. For the specifics of the values - that are precomputed and what they are used for, please refer to the StopStringCriteria docstring!""" - token_valid_positions, token_end_overlaps = _stop_string_get_matching_positions( - token_list, token_indices, stop_strings - ) - - max_valid_positions = max(len(val) for positions in token_valid_positions.values() for val in positions.values()) - max_valid_end_lens = max(len(val) for positions in token_end_overlaps.values() for val in positions.values()) - vec_size = len(stop_strings) * (max_valid_positions + max_valid_end_lens) + 1 - gather_vec = np.full((len(token_list), vec_size), dtype=np.int32, fill_value=-1) - - for i, stop_string in enumerate(stop_strings): - positions = token_valid_positions[stop_string] - end_lens = token_end_overlaps[stop_string] - - # Since this is lots of very small assignments of lists, we build it with numpy rather - # than torch for speed + simplicity, then convert to torch at the end - for token_idx, valid_positions in positions.items(): - gather_vec[ - token_idx, max_valid_positions * i : max_valid_positions * i + len(valid_positions) - ] = valid_positions - for token_idx, possible_end_lens in end_lens.items(): - gather_vec[ - token_idx, - max_valid_positions * len(stop_strings) + max_valid_end_lens * i : max_valid_positions - * len(stop_strings) - + max_valid_end_lens * i - + len(possible_end_lens), - ] = possible_end_lens - for token, token_idx in zip(token_list, token_indices): - gather_vec[token_idx, -1] = len(token) - - gather_vec = torch.tensor(gather_vec, dtype=torch.int32) - - return gather_vec, max_valid_positions, max_valid_end_lens diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 3f6e2b93f26200..7bf8d54c0c07e4 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -34,10 +34,6 @@ StopStringCriteria, validate_stopping_criteria, ) - from transformers.generation.stopping_criteria import ( - _stop_string_create_embedding_vec, - _stop_string_get_matching_positions, - ) @require_torch @@ -182,9 +178,10 @@ def test_stop_string_matching_positions(self): tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) + token_list, token_indices = criteria.clean_tokenizer_vocab(tokenizer) idx_to_token = {v: k for k, v in tokenizer.get_vocab().items()} - all_token_valid_positions, all_token_end_overlaps = _stop_string_get_matching_positions( - token_list=criteria.token_list, token_indices=criteria.token_indices, stop_strings=criteria.stop_strings + all_token_valid_positions, all_token_end_overlaps = criteria._stop_string_get_matching_positions( + token_list=token_list, token_indices=token_indices, stop_strings=criteria.stop_strings ) for stop_string in stop_strings: token_valid_positions = all_token_valid_positions[stop_string] @@ -214,8 +211,9 @@ def test_stop_string_embedding_vecs(self): tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - embedding_vec, max_valid_positions, max_valid_end_lens = _stop_string_create_embedding_vec( - token_list=criteria.token_list, token_indices=criteria.token_indices, stop_strings=criteria.stop_strings + token_list, token_indices = criteria.clean_tokenizer_vocab(tokenizer) + embedding_vec, max_valid_positions, max_valid_end_lens = criteria._stop_string_create_embedding_vec( + token_list=token_list, token_indices=token_indices, stop_strings=criteria.stop_strings ) valid_positions_vec = embedding_vec[:, : max_valid_positions * len(stop_strings)].unflatten( -1, (len(stop_strings), -1) @@ -226,7 +224,7 @@ def test_stop_string_embedding_vecs(self): token_lengths = embedding_vec[:, -1] for i, stop_string in enumerate(stop_strings): - for token, token_idx in zip(criteria.token_list, criteria.token_indices): + for token, token_idx in zip(token_list, token_indices): # The embedding contains packed valid positions, end overlap lengths, and the total token length token = token.replace("▁", " ").replace("Ġ", " ") From f49ec00b1fa2ff5389d5185d1d20a4b4dc0fd1eb Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 18:57:32 +0000 Subject: [PATCH 62/68] Add comment --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index a1cdaa265a496d..dd6cd1d2e2a2f9 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -291,7 +291,7 @@ def clean_and_embed_tokens_with_cache(self, token_list, token_indices, stop_stri max_valid_end_lens, ) if len(STOP_STRING_EMBEDDING_CACHE) > 8: - STOP_STRING_EMBEDDING_CACHE.popitem(last=False) + STOP_STRING_EMBEDDING_CACHE.popitem(last=False) # Pop from the start, the least recently used item return embedding_vec, max_valid_positions, max_valid_end_lens @staticmethod From e90aaba5c1f2b7f57fb45f1d9ae9b9ddf87f3f2c Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 21 Mar 2024 19:14:12 +0000 Subject: [PATCH 63/68] Remove finished TODO --- src/transformers/generation/stopping_criteria.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index dd6cd1d2e2a2f9..1afa15ebed6d96 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -303,7 +303,6 @@ def clean_tokenizer_vocab(tokenizer, static_prefix="abcdef"): space addition/removal. To work around this, we add a static prefix to the start of the token, then remove it (and any prefix that may have been introduced with it) after calling convert_tokens_to_string(). """ - # TODO Matt this work is done every time the criterion is initialized - can we cache it? vocab = tokenizer.get_vocab() clean_token_list = [] clean_token_indices = [] From bb27d82ec3d9ea819403450e1700c95d676be03c Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 22 Mar 2024 15:14:51 +0000 Subject: [PATCH 64/68] Remove finished TODO --- .../generation/stopping_criteria.py | 25 +++++++++++++++++++ src/transformers/generation/utils.py | 7 +++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 1afa15ebed6d96..50afc9c8b05149 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -257,6 +257,31 @@ class StopStringCriteria(StoppingCriteria): stop_strings (`Union[str, List[str]]`): A list of strings that should end generation. If a string is passed, it will be treated like a list with a single element. + + Examples: + + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2") + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2") + >>> inputs = tokenizer("The biggest states in the USA by land area:", return_tensors="pt") + + >>> gen_out = model.generate(**inputs) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + The biggest states in the USA by land area: + - Alaska + - Texas + - California + + >>> # Passing one or more stop strings will halt generation after those strings are emitted + >>> # Note that generating with stop strings requires you to pass the tokenizer too + >>> gen_out = model.generate(**inputs, stop_strings=["Texas"], tokenizer=tokenizer) + >>> print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0]) + The biggest states in the USA by land area: + - Alaska + - Texas + ``` """ def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_strings: Union[str, List[str]]): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 32e0bca49d97c5..360ea6ca8f714c 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -901,14 +901,14 @@ def _get_stopping_criteria( if generation_config.max_time is not None: criteria.append(MaxTimeCriteria(max_time=generation_config.max_time)) if generation_config.stop_strings is not None: - if "tokenizer" not in kwargs: + if tokenizer is None: raise ValueError( "There are one or more stop strings, either in the arguments to `generate` or in the " "model's generation config, but we could not locate a tokenizer. When generating with " "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." ) criteria.append( - StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=kwargs["tokenizer"]) + StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer) ) if generation_config.eos_token_id is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) @@ -1392,6 +1392,7 @@ def generate( """ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) self._validate_model_kwargs(model_kwargs.copy()) @@ -1401,7 +1402,7 @@ def generate( synced_gpus = True else: synced_gpus = False - tokenizer = kwargs.pop("tokenizer", None) + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() From 431701970bdee8a5af45eaa9dc482118724577b6 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 2 Apr 2024 17:17:55 +0100 Subject: [PATCH 65/68] make fixup --- src/transformers/generation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 360ea6ca8f714c..2a046196d85046 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -907,9 +907,7 @@ def _get_stopping_criteria( "model's generation config, but we could not locate a tokenizer. When generating with " "stop strings, you must pass the model's tokenizer to the `tokenizer` argument of `generate`." ) - criteria.append( - StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer) - ) + criteria.append(StopStringCriteria(stop_strings=generation_config.stop_strings, tokenizer=tokenizer)) if generation_config.eos_token_id is not None: criteria.append(EosTokenCriteria(eos_token_id=generation_config.eos_token_id)) criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) From 19df6a82b69666b6a05f9a16caa6ccd127ee5c55 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 Apr 2024 14:59:04 +0100 Subject: [PATCH 66/68] Update src/transformers/generation/stopping_criteria.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/generation/stopping_criteria.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 50afc9c8b05149..985b356a1ad11c 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -149,7 +149,7 @@ class StopStringCriteria(StoppingCriteria): - ["st", "op"] - ["stop"] - ["st", "opera"] - - ["sto", "opper"] + - ["sto", "pper"] - ["las", "topper"] Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other From 8b52039193b3a5064ca23c6824732a3de1d4b691 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 Apr 2024 16:39:01 +0100 Subject: [PATCH 67/68] Update and shorten docstring --- .../generation/stopping_criteria.py | 109 ++++++++---------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index 985b356a1ad11c..5a42f474be2692 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -151,6 +151,7 @@ class StopStringCriteria(StoppingCriteria): - ["st", "opera"] - ["sto", "pper"] - ["las", "topper"] + - ["s", "to", "pped"] Note that a match will only be triggered if the stop string is at the end of the generated sequence. In other words, these sequences will not trigger a match: @@ -181,75 +182,57 @@ class StopStringCriteria(StoppingCriteria): - ["st", "opera"] (overlap is "op", overlap length == 2) - ["sto", "pper"] (overlap is "p", overlap length == 1) - ["las", "topper"] (overlap is "top", overlap length == 3) + - ["s", "to", "pped"] (overlap is "p", overlap length == 1) It's impossible to construct a matching sequence that does not have this property (feel free to verify this yourself). However, although this overlap between the start of the final token and the end of the stop string is necessary for a match, it is not sufficient. We also need to check that the rest of the token sequence is consistent with the stop string. - How do we do that? Let's say the stop string is N characters long, and the initial overlap covers the final - M characters. Then, we have N - M characters left to match. If the next token is less than N - M tokens long, then - the entire token must match: token == stop_string[-M - len(token): -M]. If the next token is longer than N - M - tokens, then we consider only the final N - M characters of the token. This allows for the token to have an overhang - off the start of the stop string. - - Again, let's make this concrete with a worked example. We'll use the stop string "stop" and the token sequence - ["las", "topper"]. The length of the stop string is 4. The final token is "topper", and its overlap with the stop - string is "top", which has length 3. We continue to the next token, "las", and we have 4 - 3 = 1 character left to - match. This is less than the length of "las", so we only need a partial match for this token to complete the string. - We check that "las"[-1:] == stop[:1], which is true. We have now matched 4 characters, which is the length of - the stop string, and we are done. - - At this point, hopefully you agree that we have an algorithm that detects the presence of a stop string, but you - may not see how we can convert this to tensor operations, particularly since we want to avoid data-dependent - conditional branching in the compiled code, and ideally vectorize everything so it can be efficiently computed on - GPU. The key is to realize that although we don't have access to string operations inside the generation loop, - we can use them in a precomputation stage! - - For every token in the tokenizer vocabulary, we precompute the values - we need for the above algorithm: The length of that token's overlap with the end of the stop string, the - position(s) in the stop string where that token matches perfectly, and the length of the token. We then pack - these values into a single vector per token, and stack those vectors into an embedding tensor which we can - gather from at runtime to get the values we need. - - This is the approach we take in this class. The precomputation is done in the `_stop_string_create_embedding_vec` - function. Then, at runtime in the `__call__()` method, we implement the algorithm above in purely vectorized - fashion, starting from an input_ids vector containing the token IDs in the sequence: - - - Gather from the embedding vector using input_ids as indices, and split the packed vectors into end overlap lengths, - valid token positions, and token lengths. - - Make a vector of the length of every token in the sequence, except for the final token, where we use the - end-overlap length instead. - - Compute the cumulative sum of the sequence, starting from the end. This represents the number of characters in the stop string that - we would match after each token, assuming that token is a valid fit for the sequence at that point. - - To determine if the tokens are valid at each position, we check that the cumulative length so far matches - one of the values in their valid positions vector. Where it does not, we mask that token and all tokens - preceding it. - - We then check the highest unmasked value in the cumulative sum. This represents the length of the total string - match before we reached a token that did not match the stop string. If it is equal to or greater than the length - of the stop string, the stop string is matched. - - This is almost the complete algorithm, and the remaining details just handle edge cases: For example, what do - we do if a token can have multiple possible overlap lengths with the stop string? For example, consider the - stop string "banana", and the token sequences ["ba", "nana"] and ["bana", "nana"]. Both of these sequences - contain the stop string and should trigger a match. However, the overlap of the final token is different! In - the first case, the overlap is "nana". In the second case, the overlap is "na". When we start from the end - of the sequence and work backwards, we cannot know in advance which overlap length, if any, will lead to a valid - match, and therefore we must test all possible overlap lengths. - - Therefore, for the stop string "banana" and the token "nana", we store two valid end overlap lengths: 2 and 4. - We then perform the above algorithm, starting from each value, and test whether each results in a match. - Thanks to vectorization, we can run these tests in parallel (in fact, we can run the test for every possible - overlap length and all stop strings in parallel). - - The second detail is how we handle cases when the token sequence has an overhang off the start of the stop string, - as in the case of ["las", "top"], since we do not store "start overlaps" in the same way we do for end overlaps. - Instead, we simply store (in the valid_positions vector) that the token "las" is valid before "top", in the same - way that the token "s" is. Therefore, the total length it computes in the case of ["las", "top"] is 6 rather than 4, - because it doesn't truncate the match to the length of the stop string. However, since the algorithm concludes by - checking that the maximum match length is equal to or greater than the length of the stop string, this does not - affect the correctness of its final answer; both ["las", "top"] with a total length of 6, and ["s", "top"] with a - total length of 4, will be correctly identified as matches, because both are >= 4. + How do we do that? Let's use ["s", "to", "pped"] as an example. We know that the final token, "pped", has an + overlap of 1 with the stop string, "stop". We then go back to the previous token, "to". Since we have already + matched 1 character from the stop string, the remainder to check is "sto". We check that the next token "to" + matches the end of the remainder, which it does. We have now matched 3 characters from the stop string, and the + remainder to match is "s". We go back to the previous token again, which is also "s". This is a match, and so + we have matched the entire stop string. + + How does it work when the tokens run off the start of the stop string, though? Let's consider the example of + ["las", "topper"]. The final token, "topper", has an overlap of 3 with the stop string, "stop". Therefore, + the remaining stop string to match is "s". We go back to the previous token, "las". Because the remainder to + match is just "s", with length 1, we consider only the final 1 character from the token, which is "s". This + matches the stop string, and so the entire string is matched. + + How do we compute these matches with tensor operations, though? Simply: we efficiently precompute the necessary + information for all tokens! For every token, we compute: + - Its overlap with the end of the stop string, if any + - The positions inside the stop string where the token matches, including matches that run off the start. + - The total length of the token + + For example, for the token "pped", we would compute an end overlap of 1, no internal matching positions, + and a length of 4. For the token "to", we would compute no end overlap, a single internal matching position + of 1 (counting from the end), and a length of 2. For the token "s", we would compute no end overlap, + a single internal matching position of 3 (again counting from the end) and a length of 1. + + As long as we have this information, we can execute the algorithm above without any string comparison + operations. We simply perform the following steps: + - Check if the final token has an end-overlap with the start string + - Continue backwards, keeping track of how much of the stop string we've matched so far + - At each point, check if the next token has the current position as one of its valid positions + - Continue until either a match fails, or we completely match the whole stop string + + Again, consider ["s", "to", "pped"] as an example. "pped" has an end overlap of 1, so we can begin a match. + We have matched 1 character so far, so we check that the next token "to", has 1 as a valid position (again, + counting from the end). It does, so we add the length of "to" to our position tracker. We have now matched + 3 characters, so we check that the next token "s" has 3 as a valid position. It does, so we add its length + to the position tracker. The position tracker is now 4, which is the length of the stop string. We have matched the + entire stop string. + + In the second case, ["las", "topper"], "topper" has an end overlap of 3, so we can begin a match. We have + matched 3 characters so far, so we check that the next token "las" has 3 as a valid position. It does, because we + allow tokens to match positions that run off the start of the stop string. We add its length to the position + tracker. The position tracker is now 6, which is greater than the length of the stop string! Don't panic, though - + this also counts as a match of the stop string. We have matched the entire stop string. + Args: tokenizer (`PreTrainedTokenizer`): From 0aa201cbe6d8c5e7e29cb9595c84dc0dfaee1659 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 11 Apr 2024 18:58:40 +0100 Subject: [PATCH 68/68] Update tests to be shorter/clearer and test specific cases --- tests/generation/test_stopping_criteria.py | 103 ++++++--------------- 1 file changed, 28 insertions(+), 75 deletions(-) diff --git a/tests/generation/test_stopping_criteria.py b/tests/generation/test_stopping_criteria.py index 7bf8d54c0c07e4..1a22491b9aa0f6 100644 --- a/tests/generation/test_stopping_criteria.py +++ b/tests/generation/test_stopping_criteria.py @@ -175,85 +175,38 @@ def test_stop_string_criteria(self): self.assertFalse(criteria(false_input_ids["input_ids"][i : i + 1], scores)) def test_stop_string_matching_positions(self): - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - token_list, token_indices = criteria.clean_tokenizer_vocab(tokenizer) - idx_to_token = {v: k for k, v in tokenizer.get_vocab().items()} - all_token_valid_positions, all_token_end_overlaps = criteria._stop_string_get_matching_positions( - token_list=token_list, token_indices=token_indices, stop_strings=criteria.stop_strings + stop_string = "stop" + token_list = ["last", "top", "topper", "s", "p"] + token_indices = list(range(len(token_list))) + all_token_valid_positions, all_token_end_overlaps = StopStringCriteria._stop_string_get_matching_positions( + token_list=token_list, token_indices=token_indices, stop_strings=[stop_string] ) - for stop_string in stop_strings: - token_valid_positions = all_token_valid_positions[stop_string] - token_end_overlaps = all_token_end_overlaps[stop_string] - for token_idx, valid_positions in token_valid_positions.items(): - token = idx_to_token[token_idx].replace("▁", " ").replace("Ġ", " ") - for position in valid_positions: - trim_length = position + len(token) - len(stop_string) - if trim_length > 0: - # This token runs off the start of the string - self.assertTrue(stop_string.startswith(token[trim_length:])) - else: - self.assertTrue(stop_string[-position - len(token) :].startswith(token)) - for token_idx, end_overlaps in token_end_overlaps.items(): - token = idx_to_token[token_idx].replace("▁", " ").replace("Ġ", " ") - for overlap in end_overlaps: - # Either this token runs off the end of the string, - # or the entire stop string is a substring of the token - self.assertTrue( - ( - stop_string.endswith(token[:overlap]) - or (stop_string in token and overlap == len(stop_string)) - ) - ) + valid_positions = { + token_list[idx]: positions for idx, positions in all_token_valid_positions[stop_string].items() + } + end_overlaps = {token_list[idx]: overlaps for idx, overlaps in all_token_end_overlaps[stop_string].items()} + self.assertEqual(valid_positions, {"s": [3], "last": [2]}) + self.assertEqual(end_overlaps, {"top": [3], "topper": [3], "p": [1]}) def test_stop_string_embedding_vecs(self): - tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") - stop_strings = ["aaaaaaa", "assdfiugsdf", "stop"] - criteria = StopStringCriteria(tokenizer=tokenizer, stop_strings=stop_strings) - token_list, token_indices = criteria.clean_tokenizer_vocab(tokenizer) - embedding_vec, max_valid_positions, max_valid_end_lens = criteria._stop_string_create_embedding_vec( - token_list=token_list, token_indices=token_indices, stop_strings=criteria.stop_strings - ) - valid_positions_vec = embedding_vec[:, : max_valid_positions * len(stop_strings)].unflatten( - -1, (len(stop_strings), -1) + stop_string = "stop" + token_list = ["last", "top", "topper", "s", "p"] + token_indices = list(range(len(token_list))) + embedding_vec, max_valid_positions, max_valid_end_lens = StopStringCriteria._stop_string_create_embedding_vec( + token_list=token_list, token_indices=token_indices, stop_strings=[stop_string] ) - end_overlaps_vec = embedding_vec[:, max_valid_positions * len(stop_strings) : -1].unflatten( - -1, (len(stop_strings), -1) - ) - token_lengths = embedding_vec[:, -1] - - for i, stop_string in enumerate(stop_strings): - for token, token_idx in zip(token_list, token_indices): - # The embedding contains packed valid positions, end overlap lengths, and the total token length - token = token.replace("▁", " ").replace("Ġ", " ") - - token_valid_positions = valid_positions_vec[token_idx, i].tolist() - for position in token_valid_positions: - if position == -1: - continue # Padding value - trim_length = position + len(token) - len(stop_string) - if trim_length > 0: - # This token runs off the start of the string - self.assertTrue(stop_string.startswith(token[trim_length:])) - else: - self.assertTrue(stop_string[-position - len(token) : -position] == token) - - token_end_overlaps = end_overlaps_vec[token_idx, i].tolist() - for overlap in token_end_overlaps: - if overlap == -1: - continue # Padding value - # Either this token runs off the end of the string, - # or the entire stop string is a substring of the token - self.assertTrue( - ( - stop_string.endswith(token[:overlap]) - or (stop_string in token and overlap == len(stop_string)) - ) - ) - - token_length = token_lengths[token_idx].item() - self.assertTrue(len(token) == token_length) + + # Positions inside the stop string where the token matches (excluding end overlaps) + valid_positions = embedding_vec[:, 0].tolist() + self.assertEqual(valid_positions, [2, -1, -1, 3, -1]) + + # Overlap lengths between end of stop string and start of token + end_overlaps = embedding_vec[:, 1].tolist() + self.assertEqual(end_overlaps, [-1, 3, 3, -1, 1]) + + # Length of each token + token_lengths = embedding_vec[:, 2].tolist() + self.assertEqual(token_lengths, [len(token) for token in token_list]) def test_criterias_per_row(self): text = "They completed the challenging puzzle, revealing the hidden image at the end"