From c681feef056425e7db1b872aedfc25913edf7fce Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 13 May 2024 17:01:05 +0000 Subject: [PATCH] mvp --- src/transformers/generation/utils.py | 81 +++++++++++-------- .../models/cohere/modeling_cohere.py | 74 +++++++++-------- src/transformers/models/dbrx/modeling_dbrx.py | 74 +++++++++-------- .../models/gemma/modeling_gemma.py | 74 +++++++++-------- .../models/llama/modeling_llama.py | 74 +++++++++-------- src/transformers/models/olmo/modeling_olmo.py | 74 +++++++++-------- .../models/stablelm/modeling_stablelm.py | 1 - 7 files changed, 252 insertions(+), 200 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1c90fdd30753e5..5f03af11359df8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1276,20 +1276,13 @@ def _prepare_generation_config( generation_config = self.generation_config # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config` - # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled. - if is_torchdynamo_compiling(): - model_kwargs = kwargs - generate_attributes_in_kwargs = [ - key for key, value in kwargs.items() if getattr(generation_config, key, None) != value - ] - if len(generate_attributes_in_kwargs) > 0: - raise ValueError( - "`torch.compile` exception: all generation configuration attributes must be passed within a " - f"`generation_config` instance passed to `generate` (found: {generate_attributes_in_kwargs})." - ) - else: + # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an + # exception will be raised in `_validate_model_kwargs` + if not is_torchdynamo_compiling(): generation_config = copy.deepcopy(generation_config) model_kwargs = generation_config.update(**kwargs) + else: + model_kwargs = kwargs return generation_config, model_kwargs @@ -1305,11 +1298,14 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = model_kwargs["past_key_values"].get_seq_length() else: past_length = model_kwargs["past_key_values"][0][0].shape[2] + + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` from + # `past_length` to `past_length + input_length` if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) else: - cur_len = input_ids.shape[-1] - model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) + cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) + model_kwargs["cache_position"] = cache_position + past_length - 1 return model_kwargs def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache: @@ -1389,11 +1385,12 @@ def _tensor_or_none(token, device=None): raise ValueError( "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." ) - if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()): - logger.warning( - f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will not " - "stop until the maximum length is reached. Depending on other flags, it may even crash." - ) + if not is_torchdynamo_compiling(): # Checks that depend on tensor-dependent control flow + if eos_token_id is not None and (torch.is_floating_point(eos_token_id) or (eos_token_id < 0).any()): + logger.warning( + f"`eos_token_id` should consist of positive integers, but is {eos_token_id}. Your generation will " + "not stop until the maximum length is reached. Depending on other flags, it may even crash." + ) # Update generation config with the updated special tokens tensors generation_config.bos_token_id = bos_token_id @@ -1863,23 +1860,35 @@ def typeerror(): return result - def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: + def _has_unfinished_sequences( + self, + this_peer_finished: bool, + synced_gpus: bool, + device: torch.device, + cur_len: Optional[int] = None, + max_length: Optional[int] = None, + ) -> bool: """ Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is fed through `this_peer_finished`. ZeRO stage 3-friendly. """ - if synced_gpus: - # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. - # The following logic allows an early break if all peers finished generating their sequence - this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) - # send 0.0 if we finished, 1.0 otherwise - dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) - # did all peers finish? the reduced sum will be 0.0 then - if this_peer_finished_flag.item() == 0.0: + # torch.compile does not support data-dependent control flow. This is a workaround to allow torch.compile, + # although we lose the ability to stop when all sequences return an EOS token (and other stopping criteria) + if is_torchdynamo_compiling(): + return cur_len < max_length + else: + if synced_gpus: + # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. + # The following logic allows an early break if all peers finished generating their sequence + this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device) + # send 0.0 if we finished, 1.0 otherwise + dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) + # did all peers finish? the reduced sum will be 0.0 then + if this_peer_finished_flag.item() == 0.0: + return False + elif this_peer_finished: return False - elif this_peer_finished: - return False - return True + return True @torch.no_grad() def _contrastive_search( @@ -2319,6 +2328,7 @@ def _sample( output_scores = generation_config.output_scores output_logits = generation_config.output_logits return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) do_sample = generation_config.do_sample if do_sample is True and not isinstance(logits_warper, LogitsProcessorList): @@ -2342,12 +2352,14 @@ def _sample( ) # keep track of which sequences are already finished - batch_size = input_ids.shape[0] + batch_size, cur_len = input_ids.shape this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2412,6 +2424,7 @@ def _sample( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 if streamer is not None: streamer.end() diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index b25528dfe73e2e..729c365256f4e8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -46,6 +46,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1164,39 +1165,46 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens) + if is_torchdynamo_compiling(): + input_ids = input_ids[:, cache_position] + + else: + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 38c1fc814b1cad..b4d03011dd08b1 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -32,6 +32,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1429,39 +1430,46 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens) + if is_torchdynamo_compiling(): + input_ids = input_ids[:, cache_position] + + else: + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 12d01a6ea04d3e..790c2ba1429752 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -39,6 +39,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1151,39 +1152,46 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens) + if is_torchdynamo_compiling(): + input_ids = input_ids[:, cache_position] + + else: + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c6da59fcfb3edc..a384b4bc90b7a1 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -45,6 +45,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1242,39 +1243,46 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens) + if is_torchdynamo_compiling(): + input_ids = input_ids[:, cache_position] + + else: + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 6a7b2f748fcf03..ae4c96c0c9bbcf 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -43,6 +43,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -1219,39 +1220,46 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - past_length = 0 - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] + # with torch.compile, we only support the basic case where we assume `input_ids` has all tokens (prompt + generated tokens) + if is_torchdynamo_compiling(): + input_ids = input_ids[:, cache_position] + + else: + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = ( + past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + ) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index bc133ffb3d7329..bff27af6b60d51 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -820,7 +820,6 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_cache_class = True _supports_sdpa = True def _init_weights(self, module):