From c58f24e393dc234500bec4cab5742db830c997bd Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Tue, 9 Jul 2024 08:32:51 +0000
Subject: [PATCH] fix cache positions

---
 src/transformers/cache_utils.py      |  1 -
 src/transformers/generation/utils.py | 39 ++++++++++++++++------------
 src/transformers/testing_utils.py    |  5 +++-
 3 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 1f5a164815aaed..67d8e150d13c61 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -398,7 +398,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
     def crop(self, max_length: int):
         """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
         negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
-        # In case it is negative
         if max_length < 0:
             max_length = self.get_seq_length() - abs(max_length)
 
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b76b22cfeac207..702a116dd60e6e 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1392,21 +1392,28 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
             model_kwargs["cache_position"] = None
             return model_kwargs
 
-        past_length = 0
+        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from
+        # `past_length` to `input_length`
+        if "inputs_embeds" in model_kwargs:
+            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+        else:
+            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+
+        # doesn't have cache -> no need to slice `cache_position`
+        # has cache -> keep values from `past_length` to `input_length`
         if model_kwargs.get("past_key_values") is not None:
             cache = model_kwargs["past_key_values"]
+            past_length = 0
             if not isinstance(cache, Cache):
                 past_length = cache[0][0].shape[2]
             elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                 past_length = cache.get_seq_length()
 
-        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from
-        # `past_length` to `past_length + input_length`
-        if "inputs_embeds" in model_kwargs:
-            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0)
-        else:
-            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0)
-        model_kwargs["cache_position"] = cache_position + past_length - 1
+            # TODO(joao): this is not torch.compile-friendly, find a work-around
+            if not is_torchdynamo_compiling():
+                cache_position = cache_position[past_length:]
+
+        model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
     def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
@@ -1530,21 +1537,19 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
                 pad_token_id = eos_token_id[0]
                 logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
-        # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
-            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
-                logger.warning_once(
-                    "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
-                    "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
-                    "to obtain reliable results."
-                )
-
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
                 "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
             )
         if not is_torchdynamo_compiling():  # Checks that depend on tensor-dependent control flow
+            if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+                if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                    logger.warning_once(
+                        "The attention mask is not set and cannot be inferred from input because pad token is same as "
+                        "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
+                        "`attention_mask` to obtain reliable results."
+                    )
             if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
                 logger.warning(
                     f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will "
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 60ff7815a971ae..5cd01e390be8c4 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -512,7 +512,10 @@ def require_read_token(fn):
 
     @wraps(fn)
     def _inner(*args, **kwargs):
-        with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+        if token is not None:
+            with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+                return fn(*args, **kwargs)
+        else:  # Allow running locally with the default token env variable
             return fn(*args, **kwargs)
 
     return _inner
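
Note: the `_get_initial_cache_position` hunk above builds `cache_position` with a `cumsum` over `torch.ones_like` instead of calling `torch.arange` on Python integers, then slices off the first `past_length` positions when a cache is present. The snippet below is a minimal sketch of that equivalence, not part of the patch; the tensor shape and `past_length` value are made up for illustration.

import torch

input_ids = torch.randint(0, 100, (1, 7))  # hypothetical batch with 7 tokens
past_length = 3                            # hypothetical number of already-cached tokens

# `torch.compile`-friendly: positions are derived from the tensor itself,
# avoiding data-dependent Python integers
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1

# Equivalent to a plain arange over the input length...
assert torch.equal(cache_position, torch.arange(input_ids.shape[1]))

# ...and, with a cache, slicing keeps only the positions of the new tokens
assert torch.equal(cache_position[past_length:], torch.arange(past_length, input_ids.shape[1]))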