From 65cd4d2bf43c7b3d1c3473a2655aa349e968d9f9 Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 14 Feb 2024 13:37:15 +0100 Subject: [PATCH 1/8] make compatible with torch.compile --- src/transformers/generation/logits_process.py | 152 ++++++++++-------- 1 file changed, 82 insertions(+), 70 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 04120e39fbd27c..7cc8c003a141be 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -45,7 +45,7 @@ class LogitsProcessor: """Abstract base class for all logit processors that can be applied during generation.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) @@ -55,7 +55,7 @@ class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) @@ -68,7 +68,9 @@ class LogitsProcessorList(list): [`LogitsProcessor`] or [`LogitsWarper`] to the inputs. """ - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int, **kwargs + ) -> torch.FloatTensor: r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -92,9 +94,9 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa f"Make sure that all the required parameters: {list(function_args.keys())} for " f"{processor.__class__} are passed to the logits processor." 
) - scores = processor(input_ids, scores, **kwargs) + scores = processor(input_ids, scores, cur_len, **kwargs) else: - scores = processor(input_ids, scores) + scores = processor(input_ids, scores, cur_len) return scores @@ -149,11 +151,12 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) + eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) + eos_token_mask = torch.isin(vocab_tensor, eos_token_id) if cur_len < self.min_length: - for i in self.eos_token_id: - scores[:, i] = -float("inf") + scores = torch.where(eos_token_mask, -math.inf, scores) return scores @@ -210,12 +213,13 @@ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) + eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) + eos_token_mask = torch.isin(vocab_tensor, eos_token_id) + new_tokens_length = cur_len - self.prompt_length_to_skip if new_tokens_length < self.min_new_tokens: - for i in self.eos_token_id: - scores[:, i] = -float("inf") - + scores = torch.where(eos_token_mask, -math.inf, scores) return scores @@ -280,7 +284,7 @@ def __init__(self, temperature: float): self.temperature = temperature @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: scores = scores / self.temperature return scores @@ -329,13 +333,13 @@ def __init__(self, penalty: float): self.penalty = penalty @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores.scatter_(1, input_ids, score) + scores = scores.scatter(1, input_ids, score) return scores @@ -384,13 +388,13 @@ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor): self.encoder_input_ids = encoder_input_ids @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: score = torch.gather(scores, 1, self.encoder_input_ids) # if score < 0 then hallucination penalty has to be multiplied to increase the token 
probabilities score = torch.where(score < 0, score * self.penalty, score / self.penalty) - scores.scatter_(1, self.encoder_input_ids, score) + scores = scores.scatter(1, self.encoder_input_ids, score) return scores @@ -444,7 +448,7 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: sorted_logits, sorted_indices = torch.sort(scores, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) @@ -504,7 +508,7 @@ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_t self.filter_value = filter_value @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: top_k = min(self.top_k, scores.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] @@ -577,7 +581,7 @@ def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_t self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # calculate entropy normalized = torch.nn.functional.log_softmax(scores, dim=-1) p = torch.exp(normalized) @@ -654,7 +658,7 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_toke self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # Determine which indices to remove probabilities = scores.softmax(dim=-1) indices_to_remove = probabilities < self.epsilon @@ -726,16 +730,17 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_toke f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" ) - self.epsilon = torch.tensor(epsilon) + self.epsilon = epsilon self.filter_value = filter_value self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # Calculate the adaptive cutoff + epsilon = torch.tensor(self.epsilon, device=scores.device) probabilities = scores.softmax(dim=-1) entropy = torch.distributions.Categorical(logits=scores).entropy() - eta = torch.min(self.epsilon, torch.sqrt(self.epsilon) * torch.exp(-entropy))[..., None] + eta = torch.min(epsilon, torch.sqrt(epsilon) * torch.exp(-entropy))[..., None] indices_to_remove = probabilities < eta # Keep the words with the 'min_tokens_to_keep'-highest probabilities @@ -861,7 +866,7 @@ def 
__init__(self, ngram_size: int): self.ngram_size = ngram_size @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: num_batch_hypotheses = scores.shape[0] cur_len = input_ids.shape[-1] banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) @@ -921,7 +926,7 @@ def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor) self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # B x num_beams num_hypos = scores.shape[0] num_beams = num_hypos // self.batch_size @@ -1011,7 +1016,7 @@ def __init__(self, sequence_bias: Dict[Tuple[int], float]): self.prepared_bias_variables = False @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called. if not self.prepared_bias_variables: self._prepare_bias_variables(scores) @@ -1226,7 +1231,7 @@ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[ self._num_beams = num_beams @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: mask = torch.full_like(scores, -math.inf) for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): for beam_id, sent in enumerate(beam_sent): @@ -1335,6 +1340,7 @@ def __call__( scores: torch.FloatTensor, current_tokens: torch.LongTensor, beam_group_idx: int, + cur_len: int, ) -> torch.FloatTensor: r""" Args: @@ -1411,12 +1417,11 @@ def __init__(self, bos_token_id: int): self.bos_token_id = bos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: if cur_len == 1: - num_tokens = scores.shape[1] - scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") - scores[:, self.bos_token_id] = 0 + mask = torch.full_like(scores, -math.inf) + mask[:, self.bos_token_id] = 0 + scores = scores + mask return scores @@ -1460,13 +1465,11 @@ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: if cur_len == self.max_length - 1: - num_tokens = scores.shape[1] - 
scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = -float("inf") - for i in self.eos_token_id: - scores[:, i] = 0 + mask = torch.full_like(scores, -math.inf) + mask[:, self.eos_token_id] = 0 + scores = scores + mask return scores @@ -1480,13 +1483,17 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # set all nan values to 0.0 - scores[scores != scores] = 0.0 + zeros = torch.zeros_like(scores) + scores = scores.masked_scatter(scores != scores, zeros) + + maxs = torch.ones_like(scores) * torch.finfo(scores.dtype).max + mins = torch.ones_like(scores) * torch.finfo(scores.dtype).min # set all +/-inf values to max/min possible value - scores[scores == float("inf")] = torch.finfo(scores.dtype).max - scores[scores == float("-inf")] = torch.finfo(scores.dtype).min + scores = scores.masked_scatter(scores == float("inf"), maxs) + scores = scores.masked_scatter(scores == -float("inf"), mins) return scores @@ -1572,14 +1579,15 @@ def __init__( self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - cur_len = input_ids.shape[-1] + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + penalties = torch.zeros_like(scores) if cur_len > self.regulation_start: for i in self.eos_token_id: penalty_idx = cur_len - self.regulation_start # To support negative logits we compute the penalty of the absolute value and add to the original logit - scores[:, i] = scores[:, i] + torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1) - return scores + penalty = torch.abs(scores[:, i]) * (pow(self.regulation_factor, penalty_idx) - 1) + penalties[:, i] = penalty + return scores + penalties class LogitNormalization(LogitsProcessor, LogitsWarper): @@ -1614,7 +1622,7 @@ class LogitNormalization(LogitsProcessor, LogitsWarper): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: scores = scores.log_softmax(dim=-1) return scores @@ -1662,9 +1670,12 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - if input_ids.shape[1] == self.begin_index: - scores[:, self.begin_suppress_tokens] = -float("inf") + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) + begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device) + suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens) + if cur_len == self.begin_index: + scores = torch.where(suppress_token_mask, -float("inf"), scores) return scores @@ -1702,8 +1713,11 @@ def __init__(self, suppress_tokens): self.suppress_tokens = list(suppress_tokens) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, 
input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - scores[:, self.suppress_tokens] = -float("inf") + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) + suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device) + suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens) + scores = torch.where(suppress_token_mask, -float("inf"), scores) return scores @@ -1749,12 +1763,12 @@ def __init__(self, force_token_map: List[List[int]]): self.force_token_map = dict(force_token_map) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - generation_idx = input_ids.shape[-1] - current_token = self.force_token_map.get(generation_idx, None) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + current_token = self.force_token_map.get(cur_len, None) if current_token is not None: - scores[:, :] = -float("inf") - scores[:, current_token] = 0 + mask = torch.full_like(scores, -float("inf")) + mask[:, current_token] = 0 + scores = scores + mask return scores @@ -1841,7 +1855,7 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # suppress <|notimestamps|> which is handled by without_timestamps scores[:, self.no_timestamps_token_id] = -float("inf") @@ -1923,7 +1937,7 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: if self.start_of_trans_offset > 1: with torch.no_grad(): @@ -1993,7 +2007,7 @@ def __init__(self, guidance_scale): ) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # simple check to make sure we have compatible batch sizes between our # logits scores (cond + uncond) and input ids (cond only) if scores.shape[0] != 2 * input_ids.shape[0]: @@ -2037,11 +2051,9 @@ def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size self.semantic_vocab_size = semantic_vocab_size self.codebook_size = codebook_size - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: - curr_len = input_ids.shape[-1] - + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # even -> first codebook, odd -> second codebook - is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0 + is_first_codebook = ((cur_len - self.input_start_len) % 2) == 0 if is_first_codebook: scores[:, : self.semantic_vocab_size] = -float("inf") @@ -2122,7 +2134,7 @@ def __init__( "first_pass": True, } - def get_unconditional_logits(self, input_ids): + def 
get_unconditional_logits(self, input_ids: torch.LongTensor) -> torch.FloatTensor: if self.unconditional_context["first_pass"]: if self.unconditional_context["input_ids"] is None: self.unconditional_context["input_ids"] = input_ids[:, -1:] @@ -2158,7 +2170,7 @@ def get_unconditional_logits(self, input_ids): return out.logits - def __call__(self, input_ids, scores): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: scores = torch.nn.functional.log_softmax(scores, dim=-1) if self.guidance_scale == 1: return scores @@ -2196,7 +2208,7 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): self.min_eos_p = min_eos_p @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: if self.min_eos_p: probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id From a24088159a596daae7003ff5ac77edbc6bc2a45f Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 19 Feb 2024 10:57:51 +0100 Subject: [PATCH 2/8] fixes and warnings --- src/transformers/generation/logits_process.py | 62 ++++++++++++++++--- .../generation/stopping_criteria.py | 4 ++ src/transformers/generation/utils.py | 5 ++ 3 files changed, 61 insertions(+), 10 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 7cc8c003a141be..0f49b1d3869e8c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -34,6 +34,9 @@ scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search + cur_len: (`int`): The current sequence length of generated text. For compatibility with `torch.compile`, `input_ids` + sequence length is the maxiumum length that can be generated, and `cur_len` indicates the length that was actually + generated. Return: `torch.FloatTensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores. @@ -78,6 +81,9 @@ def __call__( scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`): Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search + cur_len: (`int`): The current sequence length of generated text. For compatibility with `torch.compile`, `input_ids` + sequence length is the maxiumum length that can be generated, and `cur_len` indicates the length that was actually + generated. kwargs (`Dict[str, Any]`, *optional*): Additional kwargs that are specific to a logits processor. @@ -146,6 +152,10 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): eos_token_id = [eos_token_id] if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." 
+ "do not add `min_length` in generation config for efficient compilation" + ) self.min_length = min_length self.eos_token_id = eos_token_id @@ -207,6 +217,10 @@ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id eos_token_id = [eos_token_id] if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "do not add `min_new_tokens` in generation config for efficient compilation" + ) self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens @@ -863,6 +877,10 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor): def __init__(self, ngram_size: int): if not isinstance(ngram_size, int) or ngram_size <= 0: raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + logger.warning( + f"{self.__class__.__name__} will cannot be used with `torch.compile`" + "To compile generation, do not add `no_repeat_ngram_size`to generation config" + ) self.ngram_size = ngram_size @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -919,6 +937,10 @@ def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor) raise ValueError( f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" ) + logger.warning( + f"{self.__class__.__name__} will cannot be used with `torch.compile`" + "To compile generation, do not add `encoder_no_repeat_ngram_size`to generation config" + ) self.ngram_size = encoder_ngram_size if len(encoder_input_ids.shape) == 1: encoder_input_ids = encoder_input_ids.unsqueeze(0) @@ -1227,6 +1249,10 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor): """ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): + logger.warning( + f"{self.__class__.__name__} cannot be used with `torch.compile`" + "To compile generation, do not add `prefix_allowed_tokens_fn`to generation config" + ) self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn self._num_beams = num_beams @@ -1414,6 +1440,10 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor): """ def __init__(self, bos_token_id: int): + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "do not add `forced_bos_token_id` in generation config for efficient compilation" + ) self.bos_token_id = bos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) @@ -1421,7 +1451,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_l if cur_len == 1: mask = torch.full_like(scores, -math.inf) mask[:, self.bos_token_id] = 0 - scores = scores + mask + scores = mask return scores @@ -1459,6 +1489,10 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): """ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." 
+ "do not add `forced_eos_token_id` in generation config for efficient compilation" + ) self.max_length = max_length if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] @@ -1469,7 +1503,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_l if cur_len == self.max_length - 1: mask = torch.full_like(scores, -math.inf) mask[:, self.eos_token_id] = 0 - scores = scores + mask + scores = mask return scores @@ -1485,15 +1519,11 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: # set all nan values to 0.0 - zeros = torch.zeros_like(scores) - scores = scores.masked_scatter(scores != scores, zeros) - - maxs = torch.ones_like(scores) * torch.finfo(scores.dtype).max - mins = torch.ones_like(scores) * torch.finfo(scores.dtype).min + scores = torch.where(scores != scores, 0.0, scores) # set all +/-inf values to max/min possible value - scores = scores.masked_scatter(scores == float("inf"), maxs) - scores = scores.masked_scatter(scores == -float("inf"), mins) + scores = torch.where(scores == float("inf"), torch.finfo(scores.dtype).max, scores) + scores = torch.where(scores == -float("inf"), torch.finfo(scores.dtype).min, scores) return scores @@ -1572,6 +1602,10 @@ def __init__( eos_token_id: Union[int, List[int]], input_ids_seq_length: int, ): + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "do not add `exponential_decay_length_penalty` in generation config for efficient compilation" + ) self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] if isinstance(eos_token_id, int): @@ -1663,6 +1697,10 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): """ def __init__(self, begin_suppress_tokens, begin_index): + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." 
+ "do not add `begin_suppress_tokens` in generation config for efficient compilation" + ) self.begin_suppress_tokens = list(begin_suppress_tokens) self.begin_index = begin_index @@ -1768,7 +1806,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_l if current_token is not None: mask = torch.full_like(scores, -float("inf")) mask[:, current_token] = 0 - scores = scores + mask + scores = mask return scores @@ -2133,6 +2171,10 @@ def __init__( "past_key_values": None, "first_pass": True, } + logger.warning( + f"{self.__class__.__name__} cannot be used with `torch.compile`" + "To compile generation, do not add `guidance_scale`to generation config" + ) def get_unconditional_logits(self, input_ids: torch.LongTensor) -> torch.FloatTensor: if self.unconditional_context["first_pass"]: diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index ca3e8509644081..e85a709a14a133 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -118,6 +118,10 @@ class MaxTimeCriteria(StoppingCriteria): """ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): + logger.warning( + f"{self.__class__.__name__} cannot be used with `torch.compile`" + "To compile generation, do not add `max_time`to generation config" + ) self.max_time = max_time self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0bbdd643421996..eccf4f6f8afa79 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -836,6 +836,11 @@ def _get_logits_processor( EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) ) if generation_config.bad_words_ids is not None: + if generation_config.bias is not None: + logger.warning( + "If using `torch.compile`,'NoBadWordsLogitsProcessor' cannot be used together with `SequenceBiasLogitsProcessor` " + "To compile generation, add only one to generation config: `bias` or 'bad_words_ids'" + ) processors.append( NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) ) From bfb5b1839a0f3da1afff9b038deb1a5bb22b19fd Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 20 Feb 2024 18:16:17 +0100 Subject: [PATCH 3/8] run tests --- src/transformers/generation/logits_process.py | 6 +- src/transformers/generation/utils.py | 40 +++--- tests/generation/test_logits_process.py | 116 +++++++++--------- 3 files changed, 85 insertions(+), 77 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 0f49b1d3869e8c..a4f095d57daaf0 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -94,8 +94,8 @@ def __call__( """ for processor in self: function_args = inspect.signature(processor.__call__).parameters - if len(function_args) > 2: - if not all(arg in kwargs for arg in list(function_args.keys())[2:]): + if len(function_args) > 3: + if not all(arg in kwargs for arg in list(function_args.keys())[3:]): raise ValueError( f"Make sure that all the required parameters: {list(function_args.keys())} for " f"{processor.__class__} are passed to the logits processor." 
@@ -1364,9 +1364,9 @@ def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, + cur_len: int, current_tokens: torch.LongTensor, beam_group_idx: int, - cur_len: int, ) -> torch.FloatTensor: r""" Args: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index eccf4f6f8afa79..ecceaec9a29c1f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -836,7 +836,7 @@ def _get_logits_processor( EncoderNoRepeatNGramLogitsProcessor(generation_config.encoder_no_repeat_ngram_size, encoder_input_ids) ) if generation_config.bad_words_ids is not None: - if generation_config.bias is not None: + if generation_config.sequence_bias is not None: logger.warning( "If using `torch.compile`,'NoBadWordsLogitsProcessor' cannot be used together with `SequenceBiasLogitsProcessor` " "To compile generation, add only one to generation config: `bias` or 'bad_words_ids'" @@ -1895,7 +1895,7 @@ def contrastive_search( unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) this_peer_finished = False # used by synced_gpus only - batch_size = input_ids.shape[0] + batch_size, cur_len = input_ids.shape while True: if synced_gpus: @@ -1961,8 +1961,8 @@ def contrastive_search( # contrastive_search main logic start: # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by # degeneration penalty - logit_for_next_step = logits_processor(input_ids, logit_for_next_step) - logit_for_next_step = logits_warper(input_ids, logit_for_next_step) + logit_for_next_step = logits_processor(input_ids, logit_for_next_step, cur_len=cur_len) + logit_for_next_step = logits_warper(input_ids, logit_for_next_step, cur_len=cur_len) next_probs = nn.functional.softmax(logit_for_next_step, dim=-1) top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k) @@ -2140,6 +2140,8 @@ def contrastive_search( if unfinished_sequences.max() == 0: this_peer_finished = True + cur_len += 1 + # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True @@ -2335,6 +2337,7 @@ def greedy_search( # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + _, cur_len = input_ids.shape this_peer_finished = False # used by synced_gpus only while True: @@ -2365,7 +2368,7 @@ def greedy_search( next_token_logits = outputs.logits[:, -1, :] # pre-process distribution - next_tokens_scores = logits_processor(input_ids, next_token_logits) + next_tokens_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -2412,6 +2415,8 @@ def greedy_search( if unfinished_sequences.max() == 0: this_peer_finished = True + cur_len += 1 + # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True @@ -2616,6 +2621,7 @@ def sample( # keep track of which sequences are already finished unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) + _, cur_len = input_ids.shape this_peer_finished = False # used by synced_gpus only # auto-regressive generation @@ -2647,8 +2653,8 @@ def sample( next_token_logits = outputs.logits[:, -1, :] # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) + 
next_token_scores = logits_processor(input_ids, next_token_logits, cur_len=cur_len) + next_token_scores = logits_warper(input_ids, next_token_scores, cur_len=cur_len) # Store scores, attentions and hidden_states when required if return_dict_in_generate: @@ -2696,6 +2702,8 @@ def sample( if unfinished_sequences.max() == 0: this_peer_finished = True + cur_len += 1 + # stop if we exceed the maximum length if stopping_criteria(input_ids, scores): this_peer_finished = True @@ -3003,7 +3011,7 @@ def beam_search( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) - next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores_processed = logits_processor(input_ids, next_token_scores, cur_len=cur_len) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) @@ -3338,8 +3346,8 @@ def beam_sample( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) - next_token_scores_processed = logits_processor(input_ids, next_token_scores) - next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed) + next_token_scores_processed = logits_processor(input_ids, next_token_scores, cur_len=cur_len) + next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed, cur_len=cur_len) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed ) @@ -3707,7 +3715,7 @@ def group_beam_search( vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx + group_input_ids, next_token_scores, cur_len=cur_len, current_tokens=current_tokens, beam_group_idx=beam_group_idx ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) @@ -4067,7 +4075,7 @@ def constrained_beam_search( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) - next_token_scores_processed = logits_processor(input_ids, next_token_scores) + next_token_scores_processed = logits_processor(input_ids, next_token_scores, cur_len=cur_len) next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as( next_token_scores_processed @@ -4419,10 +4427,10 @@ def assisted_decoding( new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :], cur_len=cur_len + i) if len(logits_warper) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :]) + new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :], cur_len=cur_len + i) # 3. Select the accepted tokens. 
There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) @@ -4685,12 +4693,12 @@ def top_k_top_p_filtering( if top_k > 0: logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( - None, logits + None, logits, None ) if 0 <= top_p <= 1.0: logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)( - None, logits + None, logits, None ) return logits diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py index 95150a9c33cd36..7d3d5b7e5ea6ed 100644 --- a/tests/generation/test_logits_process.py +++ b/tests/generation/test_logits_process.py @@ -72,13 +72,13 @@ def test_min_length_dist_processor(self): # check that min length is applied at length 5 input_ids = ids_tensor((batch_size, 5), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = min_dist_processor(input_ids, scores) + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=5) self.assertListEqual(scores_before_min_length[:, eos_token_id].tolist(), 4 * [-float("inf")]) # check that min length is not applied anymore at length 15 input_ids = ids_tensor((batch_size, 15), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = min_dist_processor(input_ids, scores) + scores_before_min_length = min_dist_processor(input_ids, scores, cur_len=15) self.assertFalse(torch.isinf(scores_before_min_length).any()) @parameterized.expand([(0,), ([0, 18],)]) @@ -97,7 +97,7 @@ def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]] expected_eos_scores_before_min_length *= len(eos_token_id) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=5) self.assertListEqual( scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length ) @@ -108,7 +108,7 @@ def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]] # check that min length is applied at length 2 input_ids = ids_tensor((batch_size, 2), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=2) self.assertListEqual( scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length ) @@ -116,7 +116,7 @@ def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]] # check that min new length is applied at length 6 (because it has only 1 new token) input_ids = ids_tensor((batch_size, 6), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=6) self.assertListEqual( scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length ) @@ -124,7 +124,7 @@ def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]] # check that min new length is applied at length 7 (because it has only 2 new tokens) input_ids = ids_tensor((batch_size, 7), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - 
scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=7) self.assertListEqual( scores_before_min_length[:, eos_token_id].flatten().tolist(), expected_eos_scores_before_min_length ) @@ -132,13 +132,13 @@ def test_new_min_length_dist_processor(self, eos_token_id: Union[int, List[int]] # check that min new length is not applied anymore at length 8 input_ids = ids_tensor((batch_size, 8), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=8) self.assertFalse(torch.isinf(scores_before_min_length).any()) # check that min new length is not applied anymore at length 15 input_ids = ids_tensor((batch_size, 15), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores_before_min_length = new_min_dist_processor(input_ids, scores) + scores_before_min_length = new_min_dist_processor(input_ids, scores, cur_len=15) self.assertFalse(torch.isinf(scores_before_min_length).any()) def test_temperature_dist_warper(self): @@ -157,8 +157,8 @@ def test_temperature_dist_warper(self): temp_dist_warper_sharper = TemperatureLogitsWarper(temperature=0.5) temp_dist_warper_smoother = TemperatureLogitsWarper(temperature=1.3) - warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone()), dim=-1) - warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone()), dim=-1) + warped_prob_sharp = nn.functional.softmax(temp_dist_warper_sharper(input_ids, scores.clone(), None), dim=-1) + warped_prob_smooth = nn.functional.softmax(temp_dist_warper_smoother(input_ids, scores.clone(), None), dim=-1) # uniform distribution stays uniform self.assertTrue(torch.allclose(probs[0, :], warped_prob_sharp[0, :], atol=1e-3)) @@ -184,7 +184,7 @@ def test_repetition_penalty_dist_process(self): rep_penalty_proc = RepetitionPenaltyLogitsProcessor(penalty=2.0) - scores = rep_penalty_proc(input_ids, scores.clone()) + scores = rep_penalty_proc(input_ids, scores.clone(), None) # check that values were correctly changed self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) * 2) @@ -205,7 +205,7 @@ def test_encoder_repetition_penalty_dist_process(self): rep_penalty_proc = EncoderRepetitionPenaltyLogitsProcessor(penalty=2.0, encoder_input_ids=input_ids) - scores = rep_penalty_proc(input_ids, scores.clone()) + scores = rep_penalty_proc(input_ids, scores.clone(), None) # check that values were correctly changed self.assertAlmostEqual(scores[0, 0].item(), -(1 / vocab_size) / 2) @@ -231,7 +231,7 @@ def test_top_k_dist_warper(self): top_k_warp = TopKLogitsWarper(3) - scores = top_k_warp(input_ids, ramp_logits) + scores = top_k_warp(input_ids, ramp_logits, None) # check that correct tokens are filtered self.assertListEqual(torch.isinf(scores[0]).tolist(), 7 * [True] + 3 * [False]) @@ -243,12 +243,12 @@ def test_top_k_dist_warper(self): logits = self._get_uniform_logits(batch_size=batch_size, length=length) top_k_warp_safety_check = TopKLogitsWarper(top_k=1, filter_value=0.0, min_tokens_to_keep=3) - scores = top_k_warp_safety_check(input_ids, logits) + scores = top_k_warp_safety_check(input_ids, logits, None) # uniform dist is not changed self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0]) ramp_logits = torch.arange(length, device=torch_device, 
dtype=torch.float).unsqueeze(0).repeat(batch_size, 1) - scores = top_k_warp_safety_check(input_ids, ramp_logits) + scores = top_k_warp_safety_check(input_ids, ramp_logits, None) # min_tokens overwrites k: 3 tokens are kept => 2 tokens are nullified self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) @@ -264,7 +264,7 @@ def test_top_p_dist_warper(self): ) top_p_warp = TopPLogitsWarper(0.8) - filtered_dist = torch.exp(top_p_warp(input_ids, dist)) + filtered_dist = torch.exp(top_p_warp(input_ids, dist, None)) # dist should be filtered to keep min num values so that sum is >= top_p # exp (-inf) => 0 @@ -283,7 +283,7 @@ def test_top_p_dist_warper(self): # make sure at least 2 tokens are kept top_p_warp = TopPLogitsWarper(0.9, min_tokens_to_keep=2, filter_value=0.0) - filtered_dist = top_p_warp(input_ids, ramp_logits) + filtered_dist = top_p_warp(input_ids, ramp_logits, None) # first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2]) @@ -299,7 +299,7 @@ def test_typical_dist_warper(self): ) typical_warp = TypicalLogitsWarper(0.5) - filtered_dist = torch.exp(typical_warp(input_ids, dist)) + filtered_dist = torch.exp(typical_warp(input_ids, dist, None)) # dist should be filtered to keep min num values so that sum is >= 0.7 # exp (-inf) => 0 @@ -314,7 +314,7 @@ def test_typical_dist_warper(self): logits = self._get_uniform_logits(batch_size=batch_size, length=length) typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3) - scores = typical_warp_safety_check(input_ids, logits) + scores = typical_warp_safety_check(input_ids, logits, None) # uniform dist is not changed self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0]) @@ -328,7 +328,7 @@ def test_typical_dist_warper(self): # make sure at least 2 tokens are kept typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0) - filtered_dist = typical_warp(input_ids, ramp_logits) + filtered_dist = typical_warp(input_ids, ramp_logits, None) # first batch should keep two tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) @@ -346,7 +346,7 @@ def test_epsilon_dist_warper(self): ) epsilon_warp = EpsilonLogitsWarper(0.1) - filtered_dist = torch.exp(epsilon_warp(input_ids, dist)) + filtered_dist = torch.exp(epsilon_warp(input_ids, dist, None)) # dist should be filtered to only keep values with proba >= 0.1 # exp (-inf) => 0 @@ -365,7 +365,7 @@ def test_epsilon_dist_warper(self): # make sure at least 2 tokens are kept epsilon_warp = EpsilonLogitsWarper(5e-2, min_tokens_to_keep=2, filter_value=0.0) - filtered_dist = epsilon_warp(input_ids, ramp_logits) + filtered_dist = epsilon_warp(input_ids, ramp_logits, None) # first batch should keep 3 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. 
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2]) @@ -381,7 +381,7 @@ def test_eta_dist_warper(self): ) eta_warp = EtaLogitsWarper(0.0625) - filtered_dist = torch.exp(eta_warp(input_ids, dist)) + filtered_dist = torch.exp(eta_warp(input_ids, dist, None)) # dist should be filtered to only keep values with proba >= min(0.0625, sqrt(0.0625) * e^-H(p)) # min(0.0625, 0.1320) is the cutoff for the first row and min(0.0625, 0.1644) is for the second @@ -402,7 +402,7 @@ def test_eta_dist_warper(self): # make sure at least 2 tokens are kept eta_warp = EtaLogitsWarper(0.1, min_tokens_to_keep=2, filter_value=0.0) - filtered_dist = eta_warp(input_ids, ramp_logits) + filtered_dist = eta_warp(input_ids, ramp_logits, None) # first batch should keep 2 tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2. self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2]) @@ -417,8 +417,8 @@ def test_no_repeat_ngram_dist_processor(self): no_repeat_proc_2_gram = NoRepeatNGramLogitsProcessor(2) no_repeat_proc_3_gram = NoRepeatNGramLogitsProcessor(3) - filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone()) - filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone()) + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone(), None) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone(), None) # 2-gram would forbid 2nd and 3rd token (1,2) at 1st batch and 1st token (0) at 2nd batch self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [True, False, False]]) @@ -441,8 +441,8 @@ def test_encoder_no_repeat_ngram_dist_processor(self): no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids) no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids) - filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone()) - filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone()) + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone(), None) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone(), None) # 2-gram would forbid 1st and 2nd token at 1st beam and 1st token (0) at 2nd beam self.assertListEqual(torch.isinf(filtered_scores_2_gram).tolist(), [[False, True, True], [False, True, False]]) @@ -464,8 +464,8 @@ def test_encoder_no_repeat_ngram_dist_processor(self): no_repeat_proc_2_gram = EncoderNoRepeatNGramLogitsProcessor(2, encoder_input_ids=encoder_input_ids) no_repeat_proc_3_gram = EncoderNoRepeatNGramLogitsProcessor(3, encoder_input_ids=encoder_input_ids) - filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone()) - filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone()) + filtered_scores_2_gram = no_repeat_proc_2_gram(input_ids, scores.clone(), None) + filtered_scores_3_gram = no_repeat_proc_3_gram(input_ids, scores.clone(), None) # 2gram # Batch 1 @@ -501,7 +501,7 @@ def test_no_bad_words_dist_processor(self): no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=bad_word_tokens, eos_token_id=eos_token_id) - filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) + filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone(), None) # batch 1: 1st, 2nd, and 4th (0, 1, 3) token are forbidden # batch 2: 1st, 2nd, and 3rd (0, 1, 2) token are forbidden @@ -512,7 +512,7 @@ def test_no_bad_words_dist_processor(self): # 
check edge case no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[4]], eos_token_id=eos_token_id) - filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone()) + filtered_scores = no_bad_words_dist_proc(input_ids, scores.clone(), None) self.assertTrue(torch.allclose(scores, filtered_scores, atol=1e-3)) def test_bias_dist_processor(self): @@ -531,7 +531,7 @@ def test_bias_dist_processor(self): scores = torch.zeros((batch_size, vocab_size), dtype=torch.float, device=torch_device) bias_dist_proc = SequenceBiasLogitsProcessor(sequence_bias=sequence_bias) - filtered_scores = bias_dist_proc(input_ids, scores.clone()) + filtered_scores = bias_dist_proc(input_ids, scores.clone(), None) # batch 1: positive bias: tokens (1, 4); negative bias: tokens (0, 3); neutral: tokens (2) # batch 2: positive bias: tokens (1, 4); negative bias: tokens (0, 2); neutral: tokens (3) @@ -562,13 +562,13 @@ def test_processor_list(self): no_bad_words_dist_proc = NoBadWordsLogitsProcessor(bad_words_ids=[[1]], eos_token_id=eos_token_id) # no processor list - scores = min_dist_proc(input_ids, scores) - scores = temp_dist_warp(input_ids, scores) - scores = rep_penalty_proc(input_ids, scores) - scores = top_k_warp(input_ids, scores) - scores = top_p_warp(input_ids, scores) - scores = no_repeat_proc(input_ids, scores) - scores = no_bad_words_dist_proc(input_ids, scores) + scores = min_dist_proc(input_ids, scores, sequence_length) + scores = temp_dist_warp(input_ids, scores, sequence_length) + scores = rep_penalty_proc(input_ids, scores, sequence_length) + scores = top_k_warp(input_ids, scores, sequence_length) + scores = top_p_warp(input_ids, scores, sequence_length) + scores = no_repeat_proc(input_ids, scores, sequence_length) + scores = no_bad_words_dist_proc(input_ids, scores, sequence_length) # with processor list processor = LogitsProcessorList( @@ -582,7 +582,7 @@ def test_processor_list(self): no_bad_words_dist_proc, ] ) - scores_comp = processor(input_ids, scores_comp) + scores_comp = processor(input_ids, scores_comp, sequence_length) # scores should be equal self.assertTrue(torch.allclose(scores, scores_comp, atol=1e-3)) @@ -602,7 +602,7 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids): prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(prefix_allowed_tokens_fn, 1) - filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone()) + filtered_scores = prefix_constrained_logits_proc(input_ids, scores.clone(), None) # batch 1: 1st, 2nd (0, 1) token are allowed # batch 2: 3rd, 4th (2, 3) token are allowed @@ -615,7 +615,7 @@ def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids): prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1) - self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone()) + self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone(), None) def test_hamming_diversity(self): vocab_size = 4 @@ -631,7 +631,7 @@ def test_hamming_diversity(self): diversity_penalty=1.0, num_beams=num_beams, num_beam_groups=num_beam_groups ) - processed_scores = diversity_logits_processor(None, scores, current_tokens, 1) + processed_scores = diversity_logits_processor(None, scores, None, current_tokens, 1) self.assertTrue( torch.allclose( @@ -654,14 +654,14 @@ def test_forced_bos_token_logits_processor(self): # check that all scores are -inf except the bos_token_id score input_ids = ids_tensor((batch_size, 1), vocab_size=20) scores = 
self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) + scores = logits_processor(input_ids, scores, cur_len=1) self.assertTrue(torch.isneginf(scores[:, bos_token_id + 1 :]).all()) self.assertListEqual(scores[:, bos_token_id].tolist(), 4 * [0]) # score for bos_token_id shold be zero # check that bos_token_id is not forced if current length is greater than 1 input_ids = ids_tensor((batch_size, 4), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) + scores = logits_processor(input_ids, scores, cur_len=4) self.assertFalse(torch.isinf(scores).any()) def test_forced_eos_token_logits_processor(self): @@ -675,14 +675,14 @@ def test_forced_eos_token_logits_processor(self): # check that all scores are -inf except the eos_token_id when max_length-1 is reached input_ids = ids_tensor((batch_size, 4), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) + scores = logits_processor(input_ids, scores, cur_len=4) self.assertTrue(torch.isneginf(scores[:, eos_token_id + 1 :]).all()) self.assertListEqual(scores[:, eos_token_id].tolist(), 4 * [0]) # score for eos_token_id should be zero # check that eos_token_id is not forced if max_length-1 is not reached input_ids = ids_tensor((batch_size, 3), vocab_size=20) scores = self._get_uniform_logits(batch_size, vocab_size) - scores = logits_processor(input_ids, scores) + scores = logits_processor(input_ids, scores, cur_len=3) self.assertFalse(torch.isinf(scores).any()) def test_remove_nan_inf_logits_processor(self): @@ -693,7 +693,7 @@ def test_remove_nan_inf_logits_processor(self): logits_processor = InfNanRemoveLogitsProcessor() - scores = logits_processor(input_ids, scores) + scores = logits_processor(input_ids, scores, None) self.assertTrue( torch.allclose( @@ -726,21 +726,21 @@ def test_exponential_decay_length_penalty(self): # check that penalty is not applied before start scores = self._get_uniform_logits(batch_size, vocab_size) scores_before_start = torch.clone(scores) # clone scores as precessor updates them inplace - scores_before_start = length_decay_processor(input_ids, scores_before_start) + scores_before_start = length_decay_processor(input_ids, scores_before_start, cur_len=2) self.assertListEqual(scores_before_start[:, eos_token_id].tolist(), scores[:, eos_token_id].tolist()) # check that penalty is applied after start input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size) scores = self._get_uniform_logits(batch_size, vocab_size) scores_after_start = torch.clone(scores) # clone scores as precessor updates them inplace - scores_after_start = length_decay_processor(input_ids, scores_after_start) + scores_after_start = length_decay_processor(input_ids, scores_after_start, cur_len=20) self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all()) # check the penalty increases negative scores input_ids = ids_tensor((batch_size, 20), vocab_size=vocab_size) scores = torch.neg(self._get_uniform_logits(batch_size, vocab_size)) scores_after_start = torch.clone(scores) # clone scores as precessor updates them inplace - scores_after_start = length_decay_processor(input_ids, scores_after_start) + scores_after_start = length_decay_processor(input_ids, scores_after_start, cur_len=20) self.assertTrue(torch.gt(scores_after_start[:, eos_token_id], scores[:, eos_token_id]).all()) def test_normalization(self): @@ -751,7 +751,7 @@ def 
test_normalization(self): ) logit_normalization = LogitNormalization() - normalized_scores = logit_normalization(input_ids, scores).exp() + normalized_scores = logit_normalization(input_ids, scores, None).exp() ones = torch.ones(scores.shape[0], device=torch_device, dtype=torch.float) self.assertTrue(normalized_scores.sum(dim=-1).allclose(ones)) @@ -779,7 +779,7 @@ def lsm(x): cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor( 1.5, dummy_model, input_ids, torch.ones_like(input_ids, dtype=torch.long) ) - out = cfg(input_ids, logits_cond)[0, -1] + out = cfg(input_ids, logits_cond, None)[0, -1] res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1] @@ -790,7 +790,7 @@ def lsm(x): # explicit unconditional prompt input_ids = torch.LongTensor([[0]]) cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model, input_ids) - out = cfg(input_ids, logits_cond)[0, -1] + out = cfg(input_ids, logits_cond, None)[0, -1] res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1] @@ -801,7 +801,7 @@ def lsm(x): # all implicit input_ids = torch.LongTensor([[0]]) cfg = UnbatchedClassifierFreeGuidanceLogitsProcessor(1.5, dummy_model) - out = cfg(input_ids, logits_cond)[0, -1] + out = cfg(input_ids, logits_cond, None)[0, -1] res = (lsm(logits_uncond) + 1.5 * (lsm(logits_cond) - lsm(logits_uncond)))[0, -1] @@ -818,7 +818,7 @@ def test_early_stop_processor(self): scores[0][eos_token_id] = -6 ## less than log(min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) - actual_scores = esp(input_ids, scores) + actual_scores = esp(input_ids, scores, None) expected_scores_list = [ scores[0].tolist(), [float("-inf"), float("-inf"), scores[0][0], float("-inf")], @@ -834,7 +834,7 @@ def test_early_stop_processor_multi_eos(self): scores[0][eos_token_id] = -6 ## less than log(min_eos_p) esp = BarkEosPrioritizerLogitsProcessor(eos_token_id=eos_token_id, min_eos_p=min_eos_p) - actual_scores = esp(input_ids, scores) + actual_scores = esp(input_ids, scores, None) expected_scores_list = [ scores[0].tolist(), [float("-inf"), float("-inf"), scores[0][0], scores[0][0]], From 78ab061584adba1385f136a772e3212097e8c796 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 20 Feb 2024 18:50:29 +0100 Subject: [PATCH 4/8] make style --- src/transformers/generation/utils.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5eaa544dc8a752..f11ba288c5ab24 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3819,7 +3819,11 @@ def group_beam_search( vocab_size = next_token_scores.shape[-1] next_token_scores_processed = logits_processor( - group_input_ids, next_token_scores, cur_len=cur_len, current_tokens=current_tokens, beam_group_idx=beam_group_idx + group_input_ids, + next_token_scores, + cur_len=cur_len, + current_tokens=current_tokens, + beam_group_idx=beam_group_idx, ) next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) next_token_scores = next_token_scores.expand_as(next_token_scores_processed) @@ -4552,10 +4556,14 @@ def assisted_decoding( next_token_logits = new_logits.clone() if len(logits_processor) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :], cur_len=cur_len + i) + new_logits[:, i, :] = logits_processor( + candidate_input_ids[:, : 
cur_len + i], new_logits[:, i, :], cur_len=cur_len + i + ) if len(logits_warper) > 0: for i in range(candidate_length + 1): - new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :], cur_len=cur_len + i) + new_logits[:, i, :] = logits_warper( + candidate_input_ids[:, : cur_len + i], new_logits[:, i, :], cur_len=cur_len + i + ) # 3. Select the accepted tokens. There are two possible cases: # Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding) From 3750dd31e337417d67f4c48089a3e1a0ed379f3e Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 26 Feb 2024 20:45:22 +0500 Subject: [PATCH 5/8] Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 35c04765ffd8fb..a8850b66c4aa3e 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -35,7 +35,7 @@ Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search cur_len: (`int`): The current sequence length of generated text. For compatibility with `torch.compile`, `input_ids` - sequence length is the maxiumum length that can be generated, and `cur_len` indicates the length that was actually + sequence length is the maximum length that can be generated, and `cur_len` indicates the length that was actually generated. Return: From 402bad613e8e7fbd1d59ff6e98bed01b4a0cbe7a Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 26 Feb 2024 20:45:29 +0500 Subject: [PATCH 6/8] Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index a8850b66c4aa3e..8947c196c4fc53 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -82,7 +82,7 @@ def __call__( Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam search or log softmax for each vocabulary token when using beam search cur_len: (`int`): The current sequence length of generated text. For compatibility with `torch.compile`, `input_ids` - sequence length is the maxiumum length that can be generated, and `cur_len` indicates the length that was actually + sequence length is the maximum length that can be generated, and `cur_len` indicates the length that was actually generated. kwargs (`Dict[str, Any]`, *optional*): Additional kwargs that are specific to a logits processor. 
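
For readers skimming the series, here is a minimal, self-contained sketch of the calling convention the docstring change above describes. It is not the Transformers implementation; the function name `mask_eos_before_min_length` and its arguments are made up for illustration. The point is that `scores` keeps a static `(batch, vocab)` shape, `cur_len` is a plain Python int, and EOS masking is done with `torch.isin`/`torch.where` rather than per-token Python loops, which is what keeps the traced graph friendly to `torch.compile`.

import math

import torch


def mask_eos_before_min_length(scores, cur_len, min_length, eos_token_id):
    # scores: (batch, vocab) with a static shape; cur_len: plain Python int.
    vocab = torch.arange(scores.shape[-1], device=scores.device)
    eos = torch.tensor(eos_token_id, device=scores.device)
    eos_mask = torch.isin(vocab, eos)  # True at EOS positions, shape (vocab,)
    if cur_len < min_length:  # branches on an int, not on tensor data
        scores = torch.where(eos_mask, -math.inf, scores)
    return scores


scores = torch.randn(2, 16)
out = mask_eos_before_min_length(scores, cur_len=3, min_length=5, eos_token_id=[0, 1])
assert torch.isinf(out[:, :2]).all()
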
From f6d659c0b225932b72d90b4bba257a5d1f0f26c5 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Mon, 26 Feb 2024 20:45:34 +0500 Subject: [PATCH 7/8] Update src/transformers/generation/logits_process.py Co-authored-by: Joao Gante --- src/transformers/generation/logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 8947c196c4fc53..05f6e683187604 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -154,7 +154,7 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") logger.warning( f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "do not add `min_length` in generation config for efficient compilation" + "Do not add `min_length` in generation config for efficient compilation" ) self.min_length = min_length From 35bd688981a8f85ae1241886ea18deae521d15be Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 26 Feb 2024 21:02:54 +0100 Subject: [PATCH 8/8] temporary changes, will not compile --- src/transformers/generation/logits_process.py | 288 +++++++++++++----- .../generation/stopping_criteria.py | 2 +- 2 files changed, 212 insertions(+), 78 deletions(-) diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 05f6e683187604..19585403deae0f 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -15,12 +15,14 @@ import inspect import math +import warnings from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch from ..utils import add_start_docstrings +from ..utils.import_utils import is_torchdynamo_compiling from ..utils.logging import get_logger @@ -48,7 +50,9 @@ class LogitsProcessor: """Abstract base class for all logit processors that can be applied during generation.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." ) @@ -58,7 +62,9 @@ class LogitsWarper: """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: raise NotImplementedError( f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 
) @@ -72,7 +78,7 @@ class LogitsProcessorList(list): """ def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int, **kwargs + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None, **kwargs ) -> torch.FloatTensor: r""" Args: @@ -152,16 +158,26 @@ def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): eos_token_id = [eos_token_id] if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "Do not add `min_length` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "Do not add `min_length` in generation config for efficient compilation" + ) self.min_length = min_length self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) eos_token_mask = torch.isin(vocab_tensor, eos_token_id) @@ -217,17 +233,27 @@ def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id eos_token_id = [eos_token_id] if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): logger.warning(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "do not add `min_new_tokens` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "Do not add `min_new_tokens` in generation config for efficient compilation" + ) self.prompt_length_to_skip = prompt_length_to_skip self.min_new_tokens = min_new_tokens self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. 
Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) eos_token_id = torch.tensor(self.eos_token_id, device=scores.device) eos_token_mask = torch.isin(vocab_tensor, eos_token_id) @@ -298,7 +324,9 @@ def __init__(self, temperature: float): self.temperature = temperature @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: scores = scores / self.temperature return scores @@ -347,7 +375,9 @@ def __init__(self, penalty: float): self.penalty = penalty @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: score = torch.gather(scores, 1, input_ids) # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities @@ -402,7 +432,9 @@ def __init__(self, penalty: float, encoder_input_ids: torch.LongTensor): self.encoder_input_ids = encoder_input_ids @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: score = torch.gather(scores, 1, self.encoder_input_ids) # if score < 0 then hallucination penalty has to be multiplied to increase the token probabilities @@ -462,7 +494,9 @@ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: sorted_logits, sorted_indices = torch.sort(scores, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) @@ -522,7 +556,9 @@ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_t self.filter_value = filter_value @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: top_k = min(self.top_k, scores.size(-1)) # Safety check # Remove all tokens with a probability less than the last token of the top-k indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None] @@ -595,7 +631,9 @@ def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_t self.min_tokens_to_keep = min_tokens_to_keep @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: # calculate entropy normalized = 
torch.nn.functional.log_softmax(scores, dim=-1)
         p = torch.exp(normalized)
@@ -672,7 +710,9 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_toke
         self.min_tokens_to_keep = min_tokens_to_keep

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None
+    ) -> torch.FloatTensor:
         # Determine which indices to remove
         probabilities = scores.softmax(dim=-1)
         indices_to_remove = probabilities < self.epsilon
@@ -749,7 +789,9 @@ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_toke
         self.min_tokens_to_keep = min_tokens_to_keep

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None
+    ) -> torch.FloatTensor:
         # Calculate the adaptive cutoff
         epsilon = torch.tensor(self.epsilon, device=scores.device)
         probabilities = scores.softmax(dim=-1)
@@ -818,7 +860,7 @@ def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):


 def _calc_banned_ngram_tokens(
-    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int
+    ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int = None
 ) -> List[Iterable[int]]:
     """Copied from fairseq for no_repeat_ngram in beam_search"""
     if cur_len + 1 < ngram_size:
@@ -877,16 +919,25 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor):
     def __init__(self, ngram_size: int):
         if not isinstance(ngram_size, int) or ngram_size <= 0:
             raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
-        logger.warning(
-            f"{self.__class__.__name__} will cannot be used with `torch.compile`"
-            "To compile generation, do not add `no_repeat_ngram_size`to generation config"
-        )
+        if is_torchdynamo_compiling():
+            logger.warning_once(
+                f"{self.__class__.__name__} cannot be used with `torch.compile`. "
+                "To compile generation, do not add `no_repeat_ngram_size` to generation config"
+            )
         self.ngram_size = ngram_size

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None
+    ) -> torch.FloatTensor:
+        if cur_len is None:
+            warnings.warn(
+                "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will "
+                "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.",
+                FutureWarning,
+            )
+            cur_len = input_ids.shape[1]
         num_batch_hypotheses = scores.shape[0]
-        cur_len = input_ids.shape[-1]
         banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
         for i, banned_tokens in enumerate(banned_batch_tokens):
             scores[i, banned_tokens] = -float("inf")

         return scores
@@ -937,10 +988,11 @@ def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor)
             raise ValueError(
                 f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
             )
-        logger.warning(
-            f"{self.__class__.__name__} will cannot be used with `torch.compile`"
-            "To compile generation, do not add `encoder_no_repeat_ngram_size`to generation config"
-        )
+        if is_torchdynamo_compiling():
+            logger.warning_once(
+                f"{self.__class__.__name__} cannot be used with `torch.compile`. "
+                "To compile generation, do not add `encoder_no_repeat_ngram_size` to generation config"
+            )
         self.ngram_size = encoder_ngram_size
         if len(encoder_input_ids.shape) == 1:
             encoder_input_ids = encoder_input_ids.unsqueeze(0)
@@ -948,11 +1000,20 @@ def __init__(self, encoder_ngram_size: int, encoder_input_ids: torch.LongTensor)
         self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None
+    ) -> torch.FloatTensor:
+        if cur_len is None:
+            warnings.warn(
+                "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will "
+                "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.",
+                FutureWarning,
+            )
+            cur_len = input_ids.shape[1]
+
         # B x num_beams
         num_hypos = scores.shape[0]
         num_beams = num_hypos // self.batch_size
-        cur_len = input_ids.shape[-1]
         banned_batch_tokens = [
             _get_generated_ngrams(
                 self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
@@ -1038,7 +1099,9 @@ def __init__(self, sequence_bias: Dict[Tuple[int], float]):
         self.prepared_bias_variables = False

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor:
+    def __call__(
+        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None
+    ) -> torch.FloatTensor:
         # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
if not self.prepared_bias_variables: self._prepare_bias_variables(scores) @@ -1249,15 +1312,18 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor): """ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], List[int]], num_beams: int): - logger.warning( - f"{self.__class__.__name__} cannot be used with `torch.compile`" - "To compile generation, do not add `prefix_allowed_tokens_fn`to generation config" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} cannot be used with `torch.compile`" + "To compile generation, Do not add `prefix_allowed_tokens_fn`to generation config" + ) self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn self._num_beams = num_beams @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: mask = torch.full_like(scores, -math.inf) for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): for beam_id, sent in enumerate(beam_sent): @@ -1364,9 +1430,9 @@ def __call__( self, input_ids: torch.LongTensor, scores: torch.FloatTensor, - cur_len: int, - current_tokens: torch.LongTensor, - beam_group_idx: int, + cur_len: int = None, + current_tokens: torch.LongTensor = None, + beam_group_idx: int = None, ) -> torch.FloatTensor: r""" Args: @@ -1440,14 +1506,24 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor): """ def __init__(self, bos_token_id: int): - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "do not add `forced_bos_token_id` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "Do not add `forced_bos_token_id` in generation config for efficient compilation" + ) self.bos_token_id = bos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] if cur_len == 1: mask = torch.full_like(scores, -math.inf) mask[:, self.bos_token_id] = 0 @@ -1489,17 +1565,27 @@ class ForcedEOSTokenLogitsProcessor(LogitsProcessor): """ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "do not add `forced_eos_token_id` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." 
+ "Do not add `forced_eos_token_id` in generation config for efficient compilation" + ) self.max_length = max_length if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] if cur_len == self.max_length - 1: mask = torch.full_like(scores, -math.inf) mask[:, self.eos_token_id] = 0 @@ -1517,7 +1603,9 @@ class InfNanRemoveLogitsProcessor(LogitsProcessor): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: # set all nan values to 0.0 scores = torch.where(scores != scores, 0.0, scores) @@ -1602,10 +1690,11 @@ def __init__( eos_token_id: Union[int, List[int]], input_ids_seq_length: int, ): - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." - "do not add `exponential_decay_length_penalty` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "Do not add `exponential_decay_length_penalty` in generation config for efficient compilation" + ) self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length self.regulation_factor = exponential_decay_length_penalty[1] if isinstance(eos_token_id, int): @@ -1613,7 +1702,16 @@ def __init__( self.eos_token_id = eos_token_id @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] penalties = torch.zeros_like(scores) if cur_len > self.regulation_start: for i in self.eos_token_id: @@ -1656,7 +1754,9 @@ class LogitNormalization(LogitsProcessor, LogitsWarper): """ @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: scores = scores.log_softmax(dim=-1) return scores @@ -1697,10 +1797,11 @@ class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): """ def __init__(self, begin_suppress_tokens, begin_index): - logger.warning( - f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." 
- "do not add `begin_suppress_tokens` in generation config for efficient compilation" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} will cause recompilations for every new token if used with `torch.compile`." + "Do not add `begin_suppress_tokens` in generation config for efficient compilation" + ) self.begin_suppress_tokens = list(begin_suppress_tokens) self.begin_index = begin_index @@ -1708,7 +1809,16 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) begin_suppress_tokens = torch.tensor(self.begin_suppress_tokens, device=scores.device) suppress_token_mask = torch.isin(vocab_tensor, begin_suppress_tokens) @@ -1751,7 +1861,9 @@ def __init__(self, suppress_tokens): self.suppress_tokens = list(suppress_tokens) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: vocab_tensor = torch.arange(scores.shape[-1], device=scores.device) suppress_tokens = torch.tensor(self.suppress_tokens, device=scores.device) suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens) @@ -1801,7 +1913,9 @@ def __init__(self, force_token_map: List[List[int]]): self.force_token_map = dict(force_token_map) @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: current_token = self.force_token_map.get(cur_len, None) if current_token is not None: mask = torch.full_like(scores, -float("inf")) @@ -1893,7 +2007,9 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: # suppress <|notimestamps|> which is handled by without_timestamps scores[:, self.no_timestamps_token_id] = -float("inf") @@ -1975,7 +2091,9 @@ def set_begin_index(self, begin_index): self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: if self.start_of_trans_offset > 1: with torch.no_grad(): @@ -2045,7 +2163,9 @@ def __init__(self, guidance_scale): ) 
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: # simple check to make sure we have compatible batch sizes between our # logits scores (cond + uncond) and input ids (cond only) if scores.shape[0] != 2 * input_ids.shape[0]: @@ -2089,8 +2209,17 @@ def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size self.semantic_vocab_size = semantic_vocab_size self.codebook_size = codebook_size - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: # even -> first codebook, odd -> second codebook + if cur_len is None: + warnings.warn( + "Using `input_ids.shape[-1]` as generated tokens length is deprecated and will " + "be removed in Transformers v4.43. Pass `cur_len` when calling `LogitsProcessor` instead.", + FutureWarning, + ) + cur_len = input_ids.shape[1] is_first_codebook = ((cur_len - self.input_start_len) % 2) == 0 if is_first_codebook: @@ -2171,10 +2300,11 @@ def __init__( "past_key_values": None, "first_pass": True, } - logger.warning( - f"{self.__class__.__name__} cannot be used with `torch.compile`" - "To compile generation, do not add `guidance_scale`to generation config" - ) + if is_torchdynamo_compiling(): + logger.warning_once( + f"{self.__class__.__name__} cannot be used with `torch.compile`" + "To compile generation, Do not add `guidance_scale`to generation config" + ) def get_unconditional_logits(self, input_ids: torch.LongTensor) -> torch.FloatTensor: if self.unconditional_context["first_pass"]: @@ -2212,7 +2342,9 @@ def get_unconditional_logits(self, input_ids: torch.LongTensor) -> torch.FloatTe return out.logits - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: scores = torch.nn.functional.log_softmax(scores, dim=-1) if self.guidance_scale == 1: return scores @@ -2250,7 +2382,9 @@ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): self.min_eos_p = min_eos_p @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int) -> torch.FloatTensor: + def __call__( + self, input_ids: torch.LongTensor, scores: torch.FloatTensor, cur_len: int = None + ) -> torch.FloatTensor: if self.min_eos_p: probs = torch.nn.functional.softmax(scores.float(), dim=-1) # create scores full of -inf except for the eos_token_id diff --git a/src/transformers/generation/stopping_criteria.py b/src/transformers/generation/stopping_criteria.py index e85a709a14a133..64855c7f8200b8 100644 --- a/src/transformers/generation/stopping_criteria.py +++ b/src/transformers/generation/stopping_criteria.py @@ -118,7 +118,7 @@ class MaxTimeCriteria(StoppingCriteria): """ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None): - logger.warning( + logger.warning_once( f"{self.__class__.__name__} cannot be used with `torch.compile`" "To compile generation, do not add `max_time`to generation config" )
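
The `cur_len=None` fallback repeated throughout the last commit follows a single backward-compatibility pattern. The sketch below is a simplified, hypothetical stand-in (the helper `resolve_cur_len` is not part of the library) showing how legacy call sites keep working while a `FutureWarning` steers callers toward the explicit, compile-friendly argument.

import warnings

import torch


def resolve_cur_len(input_ids, cur_len=None):
    # Legacy call sites pass only input_ids; derive the length for them, but
    # warn that this path is deprecated in favour of an explicit integer.
    if cur_len is None:
        warnings.warn(
            "Deriving the generated length from `input_ids.shape[-1]` is deprecated; "
            "pass `cur_len` explicitly.",
            FutureWarning,
        )
        cur_len = input_ids.shape[-1]
    return cur_len


input_ids = torch.ones(1, 7, dtype=torch.long)
assert resolve_cur_len(input_ids) == 7             # legacy path, emits FutureWarning
assert resolve_cur_len(input_ids, cur_len=3) == 3  # new compile-friendly path
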