Commit

fix cache positions
gante committed Jul 9, 2024
1 parent 2dbdd16 commit c58f24e
Showing 3 changed files with 26 additions and 19 deletions.
1 change: 0 additions & 1 deletion src/transformers/cache_utils.py
@@ -398,7 +398,6 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
     def crop(self, max_length: int):
         """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
         negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
-
         # In case it is negative
         if max_length < 0:
             max_length = self.get_seq_length() - abs(max_length)
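
Note (not part of the diff): a minimal standalone sketch of the negative `max_length` convention that `crop` documents above; the helper name is hypothetical.

def resolve_crop_length(seq_length: int, max_length: int) -> int:
    # Mirrors the branch above: a negative value removes that many tokens from the end.
    if max_length < 0:
        max_length = seq_length - abs(max_length)
    return max_length

assert resolve_crop_length(seq_length=10, max_length=-2) == 8  # drop the last 2 cached tokens
assert resolve_crop_length(seq_length=10, max_length=5) == 5   # keep only the first 5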
39 changes: 22 additions & 17 deletions src/transformers/generation/utils.py
@@ -1392,21 +1392,28 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
             model_kwargs["cache_position"] = None
             return model_kwargs
 
-        past_length = 0
+        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from
+        # `past_length` to `input_length`
+        if "inputs_embeds" in model_kwargs:
+            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
+        else:
+            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
+
+        # doesn't have cache -> no need to slice `cache_position`
+        # has cache -> keep values from `past_length` to `input_length`
         if model_kwargs.get("past_key_values") is not None:
             cache = model_kwargs["past_key_values"]
             past_length = 0
             if not isinstance(cache, Cache):
                 past_length = cache[0][0].shape[2]
             elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
                 past_length = cache.get_seq_length()
 
-        # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from
-        # `past_length` to `past_length + input_length`
-        if "inputs_embeds" in model_kwargs:
-            cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0)
-        else:
-            cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0)
-        model_kwargs["cache_position"] = cache_position + past_length - 1
+            # TODO(joao): this is not torch.compile-friendly, find a work-around
+            if not is_torchdynamo_compiling():
+                cache_position = cache_position[past_length:]
+
+        model_kwargs["cache_position"] = cache_position
         return model_kwargs
 
     def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
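
Note (not part of the diff): a standalone sketch, with made-up values, of the pattern the new code relies on -- building positions from `ones_like(...).cumsum(0) - 1` instead of `torch.arange` so the length comes from the input's shape, then slicing off the cached prefix.

import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])  # batch of 1, 5 tokens (made-up values)
past_length = 2                              # pretend 2 tokens are already in the cache

cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
assert torch.equal(cache_position, torch.arange(5))  # [0, 1, 2, 3, 4]

# With an existing cache, keep only the positions of the tokens still to be processed,
# as in the `cache_position[past_length:]` slice above.
assert torch.equal(cache_position[past_length:], torch.tensor([2, 3, 4]))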
@@ -1530,21 +1537,19 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
             pad_token_id = eos_token_id[0]
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
 
-        # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
-            if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
-                logger.warning_once(
-                    "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
-                    "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
-                    "to obtain reliable results."
-                )
-
         # Sanity checks/warnings
         if self.config.is_encoder_decoder and decoder_start_token_id is None:
             raise ValueError(
                 "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
             )
+        if not is_torchdynamo_compiling():  # Checks that depend on tensor-dependent control flow
+            if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
+                if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
+                    logger.warning_once(
+                        "The attention mask is not set and cannot be inferred from input because pad token is same as "
+                        "eos token. As a consequence, you may observe unexpected behavior. Please pass your input's "
+                        "`attention_mask` to obtain reliable results."
+                    )
         if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()):
             logger.warning(
                 f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will "
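
Note (not part of the diff): a small sketch of the kind of tensor-dependent control flow the new `is_torchdynamo_compiling()` guard refers to; the token values are made up.

import torch

eos_token_id = torch.tensor([2])
pad_token_id = torch.tensor(2)  # pad deliberately set to the eos token

# The branch depends on tensor values (`.any()` must become a Python bool), which is why
# the warning is only evaluated when not compiling.
if torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
    print("pad token equals an eos token -> pass an explicit attention_mask")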
5 changes: 4 additions & 1 deletion src/transformers/testing_utils.py
@@ -512,7 +512,10 @@ def require_read_token(fn):
 
     @wraps(fn)
     def _inner(*args, **kwargs):
-        with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+        if token is not None:
+            with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+                return fn(*args, **kwargs)
+        else:  # Allow running locally with the default token env variable
             return fn(*args, **kwargs)
 
     return _inner
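
Note (not part of the diff): a hypothetical usage sketch of the decorator patched above. With a configured read token, `huggingface_hub` sees that token inside the test; without one, the new `else` branch falls back to whatever token is already stored locally.

from transformers.testing_utils import require_read_token


@require_read_token
def test_gated_checkpoint_loads():
    # Downloads from gated repos inside this test authenticate with the injected token,
    # or with the local default token when none is injected.
    ...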
