From d7fcb07a927b2e2b1bc71155925207e04b49dd9f Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 4 Apr 2024 15:39:00 +0200 Subject: [PATCH 01/20] prepare position ids in modeling utils --- src/transformers/generation/utils.py | 5 + src/transformers/modeling_utils.py | 14 +++ .../models/codegen/modeling_codegen.py | 20 ++-- .../models/cohere/modeling_cohere.py | 20 ++-- .../models/falcon/modeling_falcon.py | 19 +-- src/transformers/models/fuyu/modeling_fuyu.py | 37 ++++-- .../models/gemma/modeling_gemma.py | 18 +-- src/transformers/models/gpt2/modeling_gpt2.py | 22 ++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 29 +++-- .../models/gpt_neo/modeling_gpt_neo.py | 21 ++-- .../models/gpt_neox/modeling_gpt_neox.py | 20 ++-- src/transformers/models/gptj/modeling_gptj.py | 21 ++-- .../models/idefics/modeling_idefics.py | 109 +++++++++--------- .../models/imagegpt/modeling_imagegpt.py | 23 ++-- .../models/llama/modeling_llama.py | 18 +-- .../models/llava/modeling_llava.py | 15 ++- .../models/llava_next/modeling_llava_next.py | 15 ++- .../models/mistral/modeling_mistral.py | 20 ++-- .../models/mixtral/modeling_mixtral.py | 20 ++-- .../models/persimmon/modeling_persimmon.py | 20 ++-- src/transformers/models/phi/modeling_phi.py | 20 ++-- .../models/qwen2/modeling_qwen2.py | 20 ++-- .../models/qwen2_moe/modeling_qwen2_moe.py | 20 ++-- .../models/stablelm/modeling_stablelm.py | 20 ++-- .../models/starcoder2/modeling_starcoder2.py | 27 ++--- .../models/vipllava/modeling_vipllava.py | 15 ++- tests/generation/test_utils.py | 50 ++++++++ 27 files changed, 399 insertions(+), 259 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index cb3ac0ff1d121c..87d610aee19e7d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -667,6 +667,11 @@ def _update_model_kwargs_for_generation( if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=-1) + print(model_kwargs["position_ids"]) + return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index fd0afa521a1453..d7c888dcc7b954 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4360,6 +4360,20 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): logger.warning_once(warn_string) + def get_position_ids_from_attention_mask(self, attention_mask, past_length, seq_length, device): + """ + Tries to infer position ids given attention mask and past kv cache length. All instances when + `position_ids=None` should call this method. 
+ """ + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[..., -seq_length:].view(-1, seq_length) + else: + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + return position_ids + @property def _is_quantized_training_enabled(self): warnings.warn( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 41f23900c29a2c..f79f21e38cd69e 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -467,8 +467,10 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # Attention mask. if attention_mask is not None: @@ -597,6 +599,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -613,13 +616,12 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_ids.shape[-1], device=input_ids.device + ) + else: + position_ids = position_ids[:, -input_ids.shape[-1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index e949bc14482e74..147677da34d5c9 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -884,7 +884,11 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + device = input_ids.device if input_ids is not None else inputs_embeds.device + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=seq_length, device=device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) @@ -1178,12 +1182,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - 
position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index f1cff3f181ac56..395c12801a6e20 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1077,10 +1077,9 @@ def forward( alibi = None if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -1217,6 +1216,7 @@ def prepare_inputs_for_generation( position_ids: Optional[torch.Tensor] = None, **kwargs, ) -> dict: + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -1230,12 +1230,13 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. 
- if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if not self.transformer.use_alibi: + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_ids.shape[-1], device=input_ids.device + ) + else: + position_ids = position_ids[:, -input_ids.shape[-1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index f94bac569fc9bb..75bddf671d2d71 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -281,10 +281,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) @@ -325,16 +324,32 @@ def prepare_inputs_for_generation( image_patches_indices=None, **kwargs, ): - if past_key_values: - input_ids = input_ids[:, -1:] + # Omit tokens covered by past_key_values + past_length = 0 + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2d93c43425f99a..f4b1fefe0022c6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -886,7 +886,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) @@ -1182,12 +1184,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 1409a3fc3f0fcb..15e51f39aa248a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1021,9 +1021,12 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + + seq_length = input_shape[-1] if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # Attention mask. 
if attention_mask is not None: @@ -1227,6 +1230,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1244,14 +1248,14 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 4e3b8498480c9e..886bbc9daab423 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -979,15 +979,11 @@ def forward( else: past_length = past_key_values[0].size(-2) - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + if position_ids is None: + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # Self-attention mask. 
query_length = input_shape[-1] @@ -1163,6 +1159,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: if self.config.multi_query: past_length = past_key_values[0].shape[1] @@ -1182,15 +1179,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 2fbf4677ca6f44..4ad299f4be4bcf 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -784,8 +784,10 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + seq_length = input_shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -900,6 +902,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -916,13 +919,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if 
inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 83c99202ac9379..eb801dc22d88be 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -859,8 +859,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # Attention mask. if attention_mask is not None: @@ -1074,6 +1075,7 @@ def prepare_inputs_for_generation( ): input_shape = input_ids.shape # cut decoder_input_ids if past is used + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -1087,12 +1089,14 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 3c6ddac4ecf4ca..6775783fe7b637 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -870,8 +870,10 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + seq_length = input_shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) if not self._use_flash_attention_2: # Attention mask. 
@@ -1052,6 +1054,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1068,13 +1071,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 47024d24e60623..47c9fadf60e0b0 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -156,10 +156,6 @@ def expand_inputs_for_generation( model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None) model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None) - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) - if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) @@ -184,45 +180,6 @@ def expand_inputs_for_generation( return input_ids, model_kwargs -def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - pixel_values = kwargs.get("pixel_values", None) - image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) - perceiver_embeddings = kwargs.get("perceiver_embeddings", None) - image_attention_mask = kwargs.get("image_attention_mask", None) - interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - 
"token_type_ids": token_type_ids, - "pixel_values": pixel_values, - "image_encoder_embeddings": image_encoder_embeddings, - "perceiver_embeddings": perceiver_embeddings, - "image_attention_mask": image_attention_mask, - "interpolate_pos_encoding": interpolate_pos_encoding, - } - - def freeze_model(model, module_exceptions=[]): mapping = { "LayerNorm": nn.LayerNorm, @@ -1155,15 +1112,11 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - elif position_ids is None: - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2: raise ValueError( @@ -1527,7 +1480,7 @@ def forward( image_hidden_states=outputs.image_hidden_states, ) - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): image_hidden_states = kwargs.pop("image_hidden_states", None) if image_hidden_states is not None: if self.config.use_resampler: @@ -1535,11 +1488,53 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): else: kwargs["image_encoder_embeddings"] = image_hidden_states kwargs["pixel_values"] = None - inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) - unwanted_kwargs = ["token_type_ids"] - for kwarg in unwanted_kwargs: - inputs.pop(kwarg, None) - return inputs + + # only last token for inputs_ids if past is defined in kwargs + attention_mask = kwargs.get("attention_mask", None) + past_length = 0 + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
+ + position_ids = kwargs.get("position_ids", None) + + seq_length = input_ids.shape[-1] + if position_ids is None: + device = input_ids.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] + + pixel_values = kwargs.get("pixel_values", None) + image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) + perceiver_embeddings = kwargs.get("perceiver_embeddings", None) + image_attention_mask = kwargs.get("image_attention_mask", None) + interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_encoder_embeddings": image_encoder_embeddings, + "perceiver_embeddings": perceiver_embeddings, + "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, + } @staticmethod def _expand_inputs_for_generation( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 3b9be17246e81e..27c8dcddbdbbdb 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -727,9 +727,12 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + seq_length = input_shape[-1] + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) # ImageGPTAttention mask. 
if attention_mask is not None: @@ -899,6 +902,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -916,14 +920,15 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] + if position_ids is None: + device = input_ids.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] + return { "input_ids": input_ids, "past_key_values": past_key_values, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 8d0baf63c7b3fe..78e44f45d3e091 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -985,7 +985,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) @@ -1279,12 +1281,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index f195c1140be86b..1ffa7f84f1a572 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -517,6 +517,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -543,12 +544,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + 
input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 155d9e3e6abf40..1afa7a98122a6a 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -642,6 +642,7 @@ def prepare_inputs_for_generation( attention_mask=None, **kwargs, ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -668,12 +669,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index e219271e8ee5c3..587b08979fafea 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -976,10 +976,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -1199,6 +1198,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = 
past_key_values.get_seq_length() @@ -1229,12 +1229,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e9e801bb71670b..5cf991b6398560 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1165,10 +1165,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -1429,6 +1428,7 @@ def prepare_inputs_for_generation( **kwargs, ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1459,12 +1459,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c83ba413952b16..02a00bb7a9d0db 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -624,10 +624,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, 
seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -826,6 +825,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -856,12 +856,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 13719166edf9d9..40df5214376e43 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -997,10 +997,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1211,6 +1210,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1241,12 +1241,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want 
to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 7ca32c37685c3c..2182236bcfbdb5 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -987,10 +987,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -1210,6 +1209,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1240,12 +1240,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index e921af9232dd25..1f41eeccaf740d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1157,10 +1157,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -1414,6 +1413,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1444,12 +1444,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and 
position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 76aca7bae91d18..dec10d2d3c47ca 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -945,10 +945,9 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1154,6 +1153,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1184,12 +1184,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 85a76f87b8d6e5..06fc8d9d7f17a3 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -963,18 +963,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = 
torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1190,6 +1188,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1220,12 +1219,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 1b20353410c895..f4574b64f5676c 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -512,6 +512,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -538,12 +539,14 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not 
None and past_key_values is None: diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b346b745d8bcbe..13e4b899e6c8b4 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1606,6 +1606,56 @@ def test_generate_continue_from_past_key_values(self): ) ) + def test_generate_with_and_without_position_ids(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + out_wo_positions = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=5) + + # infer position ids from attn mask and generate again + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) + out_w_positions = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 + ) + + # The two sets of generated sequences must match, if generate can infer position ids correctly + # and can continue adding new ids to the already passed position ids + self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist()) + + def test_generate_with_and_without_position_ids_inputs_embeds(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): + self.skipTest("This model doesn't use `inputs_embeds`") + + inputs_embeds = model.get_input_embeddings()(input_ids) + out_wo_positions = model.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=5 + ) + + # infer position ids from attn mask and generate again + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) + out_w_positions = model.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 + ) + + # The two sets of generated sequences must match, if generate can infer position ids correctly + # and can continue adding new ids to the already passed position ids + self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist()) + @parameterized.expand([(1, False), (1, True), (4, False)]) def test_new_cache_format(self, num_beams, do_sample): # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). 
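For reference, the behavior of the new `get_position_ids_from_attention_mask` helper introduced in this patch can be sketched in isolation. The snippet below is a minimal, self-contained illustration of the cumsum-based inference added to `modeling_utils.py`; the toy 2x5 left-padded mask, the `past_length` values and the printed tensors are illustrative only and are not part of the patch itself:

import torch

def get_position_ids_from_attention_mask(attention_mask, past_length, seq_length, device):
    # Each token's position is the count of non-padding tokens before it;
    # padding positions are filled with 1 so they stay inert once masked out.
    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids[..., -seq_length:].view(-1, seq_length)
    else:
        position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0)
    return position_ids

# Prefill over a left-padded batch of two prompts (0 = padding).
mask = torch.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])
print(get_position_ids_from_attention_mask(mask, past_length=0, seq_length=5, device="cpu"))
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# One decoding step with 5 tokens already in the cache: only the last column
# is kept, which is what the updated prepare_inputs_for_generation passes on.
step_mask = torch.cat([mask, torch.ones(2, 1, dtype=torch.long)], dim=-1)
print(get_position_ids_from_attention_mask(step_mask, past_length=5, seq_length=1, device="cpu"))
# tensor([[3],
#         [5]])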
From eaecb4ff853ecd6f884f6bfe3a8a06ced3b60c65 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 4 Apr 2024 19:04:49 +0200 Subject: [PATCH 02/20] fix seq length when inputs embeds --- .../models/codegen/modeling_codegen.py | 2 +- .../models/cohere/modeling_cohere.py | 6 ++++-- src/transformers/models/fuyu/modeling_fuyu.py | 4 +++- src/transformers/models/gemma/modeling_gemma.py | 8 +++++--- src/transformers/models/gpt2/modeling_gpt2.py | 17 +++++++++-------- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 6 ++++-- .../models/gpt_neo/modeling_gpt_neo.py | 4 +++- .../models/gpt_neox/modeling_gpt_neox.py | 4 +++- src/transformers/models/gptj/modeling_gptj.py | 4 +++- src/transformers/models/llama/modeling_llama.py | 6 ++++-- src/transformers/models/llava/modeling_llava.py | 4 +++- .../models/llava_next/modeling_llava_next.py | 4 +++- .../models/mistral/modeling_mistral.py | 4 +++- .../models/mixtral/modeling_mixtral.py | 4 +++- .../models/persimmon/modeling_persimmon.py | 4 +++- src/transformers/models/phi/modeling_phi.py | 4 +++- src/transformers/models/qwen2/modeling_qwen2.py | 4 +++- .../models/qwen2_moe/modeling_qwen2_moe.py | 4 +++- .../models/stablelm/modeling_stablelm.py | 4 +++- .../models/starcoder2/modeling_starcoder2.py | 4 +++- .../models/vipllava/modeling_vipllava.py | 4 +++- 21 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index f79f21e38cd69e..00ce7905d7fc6a 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -467,7 +467,7 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( attention_mask, past_length, seq_length=seq_length, device=device ) diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 147677da34d5c9..89c532bcee3c74 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -885,7 +885,7 @@ def forward( if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( attention_mask, past_seen_tokens, seq_length=seq_length, device=device ) @@ -1182,7 +1182,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 75bddf671d2d71..543cfb1382d1b3 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ 
-342,7 +342,9 @@ def prepare_inputs_for_generation( # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index f4b1fefe0022c6..ada4760e6572ed 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -887,7 +887,7 @@ def forward( if position_ids is None: position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) @@ -1184,7 +1184,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( @@ -1202,7 +1204,7 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[1] if cache_position is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 15e51f39aa248a..b5d1b088238966 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1248,7 +1248,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( @@ -1437,6 +1439,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1454,14 +1457,12 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_ids.shape[1], device=input_ids.device + ) else: - position_ids = None + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 886bbc9daab423..781dc3d404b44f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -980,7 +980,7 @@ def forward( past_length = past_key_values[0].size(-2) if position_ids is None: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( attention_mask, past_length, seq_length=seq_length, device=device ) @@ -1179,7 +1179,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else 
inputs_embeds.device diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 4ad299f4be4bcf..7efd05d8f97d87 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -919,7 +919,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index eb801dc22d88be..cebc582b188289 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -1089,7 +1089,9 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 6775783fe7b637..f272d2df400a4d 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -1071,7 +1071,9 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 78e44f45d3e091..5fbe14020fa780 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -986,7 +986,7 @@ def forward( if position_ids is None: position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) @@ -1281,7 +1281,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device 
position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 1ffa7f84f1a572..ba191c8ee7adba 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -544,7 +544,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 1afa7a98122a6a..0989713d6f70fa 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -669,7 +669,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 587b08979fafea..183ee2e391912e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1229,7 +1229,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5cf991b6398560..9e1767724b2805 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1459,7 +1459,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 02a00bb7a9d0db..adc6d3d9c667ec 100644 --- 
a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -856,7 +856,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 40df5214376e43..cfe0b0b00ba7ab 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1241,7 +1241,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 2182236bcfbdb5..13e1a2e5cc002e 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1240,7 +1240,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 1f41eeccaf740d..da1926fcee5dab 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1444,7 +1444,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index dec10d2d3c47ca..ce1cac6f19da15 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1184,7 +1184,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if 
input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 06fc8d9d7f17a3..5458e6694fff58 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1219,7 +1219,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index f4574b64f5676c..edc2181a6bc4bf 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -539,7 +539,9 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-1] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( From cf66b56b494efa2af17a6b2e4e4f2d0b67ac576f Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 5 Apr 2024 09:24:09 +0200 Subject: [PATCH 03/20] forgot to fix starcoder2 --- src/transformers/models/starcoder2/modeling_starcoder2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5458e6694fff58..eccabd099d05f3 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -968,7 +968,7 @@ def forward( if position_ids is None: position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=inputs_embeds.shape[-1], device=inputs_embeds.device + attention_mask, past_key_values_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) else: position_ids = position_ids.view(-1, seq_length).long() From ecba33fb73fb2611342cfde8fbd4768cc1bbdc8b Mon Sep 17 00:00:00 2001 From: raushan Date: Fri, 5 Apr 2024 09:30:32 +0200 Subject: [PATCH 04/20] fix copies --- src/transformers/models/gemma/modeling_gemma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ada4760e6572ed..465bd765cb075a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1204,7 +1204,7 @@ def 
prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. model_inputs = {"input_ids": input_ids.contiguous()} - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[1] + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] if cache_position is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) else: From b769074e8392ce090e93adb8b2b9a1274957cbe2 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 9 Apr 2024 10:59:55 +0200 Subject: [PATCH 05/20] remove that print :) --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 87d610aee19e7d..f2daa626f61c01 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -670,7 +670,6 @@ def _update_model_kwargs_for_generation( if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: position_ids = model_kwargs["position_ids"] model_kwargs["position_ids"] = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=-1) - print(model_kwargs["position_ids"]) return model_kwargs From b96725f60e12a5305489c3296b407b0cd70b25df Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 9 Apr 2024 11:16:02 +0200 Subject: [PATCH 06/20] lets add same for assisted decoding --- .../generation/candidate_generator.py | 16 +++++ src/transformers/generation/utils.py | 2 + tests/generation/test_utils.py | 60 +++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 08590219561531..ec084ee1413b7b 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -182,6 +182,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + self.assistant_kwargs = _prepare_position_ids(self.assistant_kwargs, new_cur_len) # 2. Forecast next N tokens using the assistant model. 
assistant_generation_kwargs = { @@ -408,3 +409,18 @@ def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Di token_type_copies = final_token_type.repeat(1, type_length_diff) model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) return model_kwargs + + +def _prepare_position_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + position_ids = model_kwargs.get("position_ids") + if position_ids is None: + return model_kwargs + + # we assume batch_size=1 for assisted decoding (needs rework if bs > 1) + length_diff = new_length - position_ids[0, -1] + if length_diff < 0: + position_ids = position_ids[:, :length_diff] + elif length_diff > 0: + new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0) + model_kwargs["position_ids"] = torch.cat([model_kwargs["position_ids"], new_position_ids], dim=-1) + return model_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f2daa626f61c01..e4fa7be573aed0 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -43,6 +43,7 @@ PromptLookupCandidateGenerator, _crop_past_key_values, _prepare_attention_mask, + _prepare_position_ids, _prepare_token_type_ids, ) from .configuration_utils import GenerationConfig, GenerationMode @@ -4613,6 +4614,7 @@ def _assisted_decoding( candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) + candidate_kwargs = _prepare_position_ids(candidate_kwargs, candidate_input_ids.shape[1]) if "cache_position" in candidate_kwargs: candidate_kwargs["cache_position"] = torch.cat( ( diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 13e4b899e6c8b4..8562b254f65375 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1188,6 +1188,66 @@ def test_assisted_decoding_matches_greedy_search(self): for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) + @is_flaky() + def test_assisted_decoding_position_ids(self): + """ + Similar test to `test_assisted_decoding_matches_greedy_search` but passes in position ids to check if + assisted decoding can correctly expand/crop them while generating + """ + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + self.skipTest("Won't fix: old model with different cache format") + if any( + model_name in model_class.__name__.lower() + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] + ): + self.skipTest("May fix in the future: need model-specific fixes") + + config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + if not hasattr(config, "use_cache"): + self.skipTest("This model doesn't support caching") + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + generation_kwargs = { + "eos_token_id": -1, + "max_new_tokens": 4, + "num_beams": 1, + "do_sample": False, + } + + position_ids = attention_mask.long().cumsum(-1) - 1 +
position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) + output_greedy = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs + ) + + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generation_kwargs.update({"assistant_model": assistant_model}) + output_assisted = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs + ) + + # The two outputs must match and their shape must be as expected + self.assertListEqual(output_greedy.tolist(), output_assisted.tolist()) + @is_flaky() def test_prompt_lookup_decoding_matches_greedy_search(self): # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. From ff4e4240a86fe07c326f83fc24b9411e55c5f6dd Mon Sep 17 00:00:00 2001 From: raushan Date: Wed, 17 Apr 2024 12:07:06 +0200 Subject: [PATCH 07/20] framework equivalence? --- tests/models/gpt2/test_modeling_flax_gpt2.py | 10 +++------- tests/test_modeling_common.py | 5 +++++ tests/test_modeling_flax_common.py | 5 +++++ tests/test_modeling_tf_common.py | 5 +++++ 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index fbf2d6c333fd8a..cbdc3196cdeabf 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -273,13 +273,9 @@ def test_equivalence_pt_to_flax(self): pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) - batch_size, seq_length = pt_inputs["input_ids"].shape - rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) - for batch_idx, start_index in enumerate(rnd_start_indices): - pt_inputs["attention_mask"][batch_idx, :start_index] = 0 - pt_inputs["attention_mask"][batch_idx, start_index:] = 1 - prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 - prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 + pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) + prepared_inputs_dict["attention_mask"] = jnp.ones_like(prepared_inputs_dict["input_ids"]) + pt_model = pt_model_class(config).eval() fx_model = model_class(config, dtype=jnp.float32) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a396ac752d324a..6c303f5c638579 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2600,6 +2600,11 @@ def test_equivalence_pt_to_flax(self): # convert inputs to Flax fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} + # Flax and torch calculate position ids differently, which is noticed only when attn mask + # does not follow expected pattern where zeros are only on one side (left or right) + pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) + fx_inputs["attention_mask"] = jnp.ones_like(fx_inputs["input_ids"]) + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 22d6b241f0048c..020c0a4c9c198f 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -289,6 
+289,11 @@ def test_equivalence_pt_to_flax(self): prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + # Flax and torch calculate position ids differently, which is noticed only when attn mask + # does not follow expected pattern where zeros are only on one side (left or right) + pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) + prepared_inputs_dict["attention_mask"] = jnp.ones_like(prepared_inputs_dict["input_ids"]) + # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index f396875570c98d..76dc47f130ce96 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -600,6 +600,11 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items() } + # TF and torch calculate position ids differently, which is noticed only when attn mask + # does not follow expected pattern where zeros are only on one side (left or right) + pt_inputs_dict["attention_mask"] = torch.ones_like(pt_inputs_dict["input_ids"]) + tf_inputs_dict["attention_mask"] = tf.ones_like(tf_inputs_dict["input_ids"]) + # send pytorch model to the correct device pt_model.to(torch_device) From 55311184a459112b53467288ad831bd0d39b6bfc Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 18 Apr 2024 16:19:31 +0200 Subject: [PATCH 08/20] final solution, lets make all frameworks same --- src/transformers/modeling_flax_utils.py | 12 ++++++++++++ src/transformers/models/ctrl/modeling_ctrl.py | 17 +++++++++++++++-- .../models/ctrl/modeling_tf_ctrl.py | 3 ++- .../models/gemma/modeling_flax_gemma.py | 6 ++---- .../models/gpt2/modeling_flax_gpt2.py | 6 ++---- .../models/gpt2/modeling_tf_gpt2.py | 11 +++++++++-- .../models/gpt_neo/modeling_flax_gpt_neo.py | 6 ++---- .../models/gptj/modeling_flax_gptj.py | 6 ++---- .../models/gptj/modeling_tf_gptj.py | 11 +++++++++-- .../models/llama/modeling_flax_llama.py | 6 ++---- .../models/mistral/modeling_flax_mistral.py | 6 ++---- tests/models/gpt2/test_modeling_flax_gpt2.py | 9 +++++++-- tests/test_modeling_common.py | 5 ----- tests/test_modeling_flax_common.py | 5 ----- tests/test_modeling_tf_common.py | 5 ----- 15 files changed, 66 insertions(+), 48 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index da373603420ba2..8c213685648069 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1248,6 +1248,18 @@ def register_for_auto_class(cls, auto_class="FlaxAutoModel"): cls._auto_class = auto_class + def get_position_ids_from_attention_mask(self, attention_mask, batch_size, seq_length): + """ + Tries to infer position ids given attention mask and past kv cache length. All instances when + `position_ids=None` should call this method. + """ + if attention_mask is not None: + position_ids = jnp.cumsum(attention_mask, axis=-1) - 1 + position_ids = jnp.where(attention_mask == 0, 1, position_ids) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length)[None, :], (batch_size, seq_length)) + return position_ids + # To update the docstring, we need to copy the method, otherwise we change the original docstring. 
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub) diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 7534a0e50c9a23..b30f41732046b7 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -413,8 +413,10 @@ def forward( else: past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[1], device=device + ) # Attention mask. if attention_mask is not None: @@ -525,6 +527,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): # only last tokens for inputs_ids if past is defined in kwargs + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -537,6 +540,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac input_ids = input_ids[:, remove_prefix_length:] + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_ids.shape[1], device=input_ids.device + ) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 6569b9e7d7b788..15da9f07f72f6f 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -702,7 +702,8 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/gemma/modeling_flax_gemma.py b/src/transformers/models/gemma/modeling_flax_gemma.py index 235f65680fad3e..d2e527fce0fef9 100644 --- a/src/transformers/models/gemma/modeling_flax_gemma.py +++ b/src/transformers/models/gemma/modeling_flax_gemma.py @@ -504,7 +504,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -747,10 +747,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = 
jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index c3ef377642a3c5..5813f1249923f3 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -485,7 +485,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -752,12 +752,10 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice( extended_attention_mask, attention_mask.astype("i4"), (0, 0) ) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 26a4e7a398ae8d..8711865712f98e 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -430,7 +430,13 @@ def call( past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if attention_mask is not None: + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. 
@@ -860,7 +866,8 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index 5639ca50f166a2..5233439f819800 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -424,7 +424,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -664,10 +664,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 9f0d4d6e860003..5e690127dc25de 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -458,7 +458,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -693,10 +693,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index 5c315b5b66f049..b3e5635856dafa 100644 --- 
a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -446,7 +446,13 @@ def call( past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if attention_mask is not None: + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. @@ -771,7 +777,8 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py index 1e7fb5d1a1cd94..06509f361a0601 100644 --- a/src/transformers/models/llama/modeling_flax_llama.py +++ b/src/transformers/models/llama/modeling_flax_llama.py @@ -492,7 +492,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -723,11 +723,9 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py index 3480fc7214a84a..4bdbb6c0d1a66a 100644 --- a/src/transformers/models/mistral/modeling_flax_mistral.py +++ b/src/transformers/models/mistral/modeling_flax_mistral.py @@ -480,7 +480,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -715,10 +715,8 @@ def 
prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index cbdc3196cdeabf..820b23bd70cfc5 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -273,8 +273,13 @@ def test_equivalence_pt_to_flax(self): pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning pt_model_class = getattr(transformers, pt_model_class_name) - pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) - prepared_inputs_dict["attention_mask"] = jnp.ones_like(prepared_inputs_dict["input_ids"]) + batch_size, seq_length = pt_inputs["input_ids"].shape + rnd_start_indices = np.random.randint(0, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + pt_inputs["attention_mask"][batch_idx, :start_index] = 0 + pt_inputs["attention_mask"][batch_idx, start_index:] = 1 + prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 + prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 pt_model = pt_model_class(config).eval() fx_model = model_class(config, dtype=jnp.float32) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6c303f5c638579..a396ac752d324a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2600,11 +2600,6 @@ def test_equivalence_pt_to_flax(self): # convert inputs to Flax fx_inputs = {k: np.array(v.to("cpu")) for k, v in pt_inputs.items() if torch.is_tensor(v)} - # Flax and torch calculate position ids differently, which is noticed only when attn mask - # does not follow expected pattern where zeros are only on one side (left or right) - pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) - fx_inputs["attention_mask"] = jnp.ones_like(fx_inputs["input_ids"]) - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 020c0a4c9c198f..22d6b241f0048c 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -289,11 +289,6 @@ def test_equivalence_pt_to_flax(self): prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - # Flax and torch calculate position ids differently, which is noticed only when attn mask - # does not follow expected pattern where zeros are only on one side (left or right) - pt_inputs["attention_mask"] = torch.ones_like(pt_inputs["input_ids"]) - prepared_inputs_dict["attention_mask"] = jnp.ones_like(prepared_inputs_dict["input_ids"]) - # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the 
beginning pt_model_class = getattr(transformers, pt_model_class_name) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 76dc47f130ce96..f396875570c98d 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -600,11 +600,6 @@ def check_pt_tf_models(self, tf_model, pt_model, tf_inputs_dict): k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items() } - # TF and torch calculate position ids differently, which is noticed only when attn mask - # does not follow expected pattern where zeros are only on one side (left or right) - pt_inputs_dict["attention_mask"] = torch.ones_like(pt_inputs_dict["input_ids"]) - tf_inputs_dict["attention_mask"] = tf.ones_like(tf_inputs_dict["input_ids"]) - # send pytorch model to the correct device pt_model.to(torch_device) From ce3774248c17b3186d3168f8d1d10c08a8a2d4b3 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 22 Apr 2024 10:24:42 +0200 Subject: [PATCH 09/20] new models --- src/transformers/generation/utils.py | 2 +- src/transformers/models/dbrx/modeling_dbrx.py | 21 ++++++++++++------- src/transformers/models/olmo/modeling_olmo.py | 20 +++++++++++------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 377c8f873e37a3..a6a701506f48aa 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4672,7 +4672,7 @@ def _assisted_decoding( model_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) model_kwargs = _prepare_token_type_ids(model_kwargs, candidate_input_ids.shape[1]) - candidate_kwargs = _prepare_position_ids(model_kwargs, candidate_input_ids.shape[1]) + model_kwargs = _prepare_position_ids(model_kwargs, candidate_input_ids.shape[1]) if "cache_position" in model_kwargs: model_kwargs["cache_position"] = torch.cat( ( diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 99b865c773f81d..2569f23c9028e3 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1145,7 +1145,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # embed positions @@ -1470,12 +1473,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] if self.generation_config.cache_implementation == "static": # generation with static cache diff 
--git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 83637536a12531..3e5794e80dd41b 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -969,7 +969,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) @@ -1279,12 +1281,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: From e28e551d62a1eceb625bb2415fd3f2a145d2cbcd Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 22 Apr 2024 10:31:38 +0200 Subject: [PATCH 10/20] tf fix cast --- src/transformers/models/ctrl/modeling_tf_ctrl.py | 2 +- src/transformers/models/gpt2/modeling_tf_gpt2.py | 2 +- src/transformers/models/gptj/modeling_tf_gptj.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 15da9f07f72f6f..f8babe031c1950 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -702,7 +702,7 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.cumsum(attention_mask, axis=-1) - 1 position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 8711865712f98e..d19f8ef2f165e9 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -866,7 +866,7 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.cumsum(attention_mask, axis=-1) - 1 position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git 
a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index b3e5635856dafa..826cd9d187c86d 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -447,7 +447,7 @@ def call( if position_ids is None: if attention_mask is not None: - position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.cumsum(attention_mask, axis=-1) - 1 position_ids = tf.where(attention_mask == 0, 1, position_ids) position_ids = position_ids[..., -input_shape[-1] :] position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) @@ -777,7 +777,7 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + position_ids = tf.cumsum(attention_mask, axis=-1) - 1 position_ids = tf.where(attention_mask == 0, 1, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) From 7e5e3bf9150fd13b0b459a031c1c48867f887754 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 22 Apr 2024 13:12:58 +0200 Subject: [PATCH 11/20] tf equivalence --- src/transformers/models/ctrl/modeling_tf_ctrl.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index f8babe031c1950..2a3d003fb8f692 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -344,8 +344,13 @@ def call( else: past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0) - position_ids = tf.tile(position_ids, [input_shape[0], 1]) + if attention_mask is not None: + position_ids = tf.cumsum(attention_mask, axis=-1) - 1 + position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) # Attention mask. 
if attention_mask is not None: From 15ad877deb5c34a6c8f4192bf7ddb9ad78f2722a Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 23 Apr 2024 12:37:27 +0200 Subject: [PATCH 12/20] remove extra if conditions --- src/transformers/models/codegen/modeling_codegen.py | 9 +++------ src/transformers/models/cohere/modeling_cohere.py | 4 +--- src/transformers/models/ctrl/modeling_ctrl.py | 2 +- .../encoder_decoder/modeling_flax_encoder_decoder.py | 12 +++++------- src/transformers/models/falcon/modeling_falcon.py | 3 +-- src/transformers/models/fuyu/modeling_fuyu.py | 11 +++++------ src/transformers/models/gpt2/modeling_gpt2.py | 3 +-- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 3 +-- src/transformers/models/gpt_neo/modeling_gpt_neo.py | 3 +-- .../models/gpt_neox/modeling_gpt_neox.py | 9 +++------ src/transformers/models/gptj/modeling_gptj.py | 3 +-- src/transformers/models/idefics/modeling_idefics.py | 1 - .../models/imagegpt/modeling_imagegpt.py | 3 +-- src/transformers/models/mistral/modeling_mistral.py | 9 ++++----- src/transformers/models/mixtral/modeling_mixtral.py | 9 ++++----- .../models/persimmon/modeling_persimmon.py | 8 ++++---- src/transformers/models/phi/modeling_phi.py | 12 +++++------- src/transformers/models/qwen2/modeling_qwen2.py | 9 ++++----- .../models/qwen2_moe/modeling_qwen2_moe.py | 9 ++++----- .../models/stablelm/modeling_stablelm.py | 9 ++++----- 20 files changed, 53 insertions(+), 78 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 00ce7905d7fc6a..45fe703ea3a4cc 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -455,7 +455,8 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) @@ -467,9 +468,8 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) # Attention mask. 
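For context on what these hunks delegate to, here is a standalone sketch that mirrors the cumsum-based position id inference used throughout this series (the function name and example values are illustrative, not part of the library API):

import torch

def infer_position_ids(attention_mask, past_length, seq_length, device):
    # Real tokens get 0-based positions from the cumulative sum of the mask;
    # padded slots are filled with a dummy value of 1, as in the patched helper.
    if attention_mask is not None:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids = position_ids.masked_fill(attention_mask == 0, 1)
        return position_ids[..., -seq_length:].view(-1, seq_length)
    # Without a mask, fall back to an arange offset by the cached length.
    return torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device).unsqueeze(0)

mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])  # left-padded batch
print(infer_position_ids(mask, past_length=0, seq_length=4, device="cpu"))
# tensor([[1, 1, 0, 1],
#         [0, 1, 2, 3]])

This is why the models above can drop their per-call `torch.arange` fallbacks: padded slots no longer shift the positions of real tokens.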
@@ -498,9 +498,6 @@ def forward( # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds if token_type_ids is not None: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index e44831741e9bc2..711c125d848999 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -909,10 +909,8 @@ def forward( ) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_seen_tokens, seq_length=seq_length, device=device + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index b30f41732046b7..11d95adfaee19a 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -412,8 +412,8 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( attention_mask, past_length, seq_length=input_shape[1], device=device ) diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index beecd080328e16..2ae1b2ccac8ea5 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -560,8 +560,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed @@ -737,12 +737,10 @@ def prepare_inputs_for_generation( # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) - ) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, seq_length + ) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1bef6eb96a04e2..61c53c8118cb62 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1077,9 
+1077,8 @@ def forward( else: alibi = None if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) if self._use_flash_attention_2: diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 8f809ed3d3f9f5..2a6e1754db293d 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -279,12 +279,6 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device - ) - if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: @@ -300,6 +294,11 @@ def forward( image_patch_input_indices=image_patches_indices, ) + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device + ) + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index b5d1b088238966..898173d0cc5f8a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1022,10 +1022,9 @@ def forward( else: past_length = past_key_values[0][0].size(-2) - seq_length = input_shape[-1] if position_ids is None: position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=input_shape[-1], device=device ) # Attention mask. diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 781dc3d404b44f..615728102ab73f 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -980,9 +980,8 @@ def forward( past_length = past_key_values[0].size(-2) if position_ids is None: - seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=input_shape[-1], device=device ) # Self-attention mask. 
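As a quick decode-time sanity check of the same inference (a hedged sketch with made-up values, not library code): with one pad slot, three cached tokens and one new token, keeping only the trailing `seq_length` entries yields just the new token's position.

import torch

mask = torch.tensor([[0, 1, 1, 1, 1]])   # 1 pad slot, 3 cached tokens, 1 new token
positions = mask.long().cumsum(-1) - 1   # [[-1, 0, 1, 2, 3]]
positions = positions.masked_fill(mask == 0, 1)
print(positions[..., -1:])               # tensor([[3]]) -> position id for the new token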
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7efd05d8f97d87..2a9c60cf610182 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -784,9 +784,8 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - seq_length = input_shape[-1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=input_shape[-1], device=device ) # Prepare head mask if needed diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index cebc582b188289..19db8f7ccdf29c 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -850,6 +850,8 @@ def forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) if past_key_values is None: past_length = 0 @@ -858,9 +860,8 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=seq_length, device=inputs_embeds.device ) # Attention mask. @@ -891,10 +892,6 @@ def forward( # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - hidden_states = self.emb_dropout(inputs_embeds) if self.gradient_checkpointing and self.training: diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index f272d2df400a4d..6469fc6b345994 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -870,9 +870,8 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - seq_length = input_shape[-1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=input_shape[-1], device=device ) if not self._use_flash_attention_2: diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 1333347693497d..c30035ad632ddb 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1113,7 +1113,6 @@ def forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( attention_mask, past_key_values_length, seq_length=seq_length, device=device ) diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 27c8dcddbdbbdb..f91794af8bc127 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ 
b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -729,9 +729,8 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - seq_length = input_shape[-1] position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=seq_length, device=device + attention_mask, past_length, seq_length=input_shape[-1], device=device ) # ImageGPTAttention mask. diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 72c322bb393299..c27439ebdfb64e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -974,17 +974,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 20b51337974273..aab843a9dfa392 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1156,17 +1156,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index adc6d3d9c667ec..a044b0638e9baf 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -622,14 +622,14 @@ def forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, 
seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask = torch.ones( diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index cfe0b0b00ba7ab..b4188425e70c14 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -995,17 +995,15 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device - ) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self.embed_dropout(inputs_embeds) + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device + ) + # Attention mask. if self._use_flash_attention_2: # 2d mask is passed through the layers diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6f45034df14344..d847aae360cefb 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -985,17 +985,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index ac229592863dfd..c4edd35c2dc271 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1148,17 +1148,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if 
attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 8ce2d97462bf43..9dd08235e59fa4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -987,15 +987,14 @@ def forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_key_values_length, seq_length=seq_length, device=device + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None From 0d9ee9f6e3b34063ea57a82b945599577be8d2c7 Mon Sep 17 00:00:00 2001 From: raushan Date: Tue, 23 Apr 2024 12:52:05 +0200 Subject: [PATCH 13/20] make test parameterized --- tests/generation/test_utils.py | 103 +++++++++++---------------------- 1 file changed, 33 insertions(+), 70 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index b7c7280976db46..cb8f58b16c22ed 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1118,8 +1118,9 @@ def test_beam_search_low_memory(self): ) self.assertListEqual(low_output.tolist(), high_output.tolist()) + @parameterized.expand([(True,), (False,)]) @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. - def test_assisted_decoding_matches_greedy_search(self): + def test_assisted_decoding_matches_greedy_search(self, use_position_ids): # This test ensures that the assisted generation does not introduce output changes over greedy search. # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul # shape differences -- and it may result in a different output. 
The input shape difference happens in the @@ -1176,79 +1177,41 @@ def test_assisted_decoding_matches_greedy_search(self): "output_attentions": True, "return_dict_in_generate": True, } - output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) - generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + # test that the output is correct if user pass in position ids into `generate` vs it's calculated internally + if use_position_ids: + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) + output_greedy = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs + ) + output_assisted = model.generate( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + assistant_model=assistant_model, + **generation_kwargs, + ) + else: + output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + output_assisted = model.generate( + input_ids, attention_mask=attention_mask, assistant_model=assistant_model, **generation_kwargs + ) # The two outputs must match and their shape must be as expected self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) - @is_flaky() - def test_assisted_decoding_position_ids(self): - """ - Similar test to `test_assisted_decoding_matches_greedy_search` but passes in position ids to check if - assisted decoding can correctly expand/crop it while generating - """ - for model_class in self.all_generative_model_classes: - if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): - self.skipTest("Won't fix: old model with different cache format") - if any( - model_name in model_class.__name__.lower() - for model_name in [ - "bigbirdpegasus", - "led", - "mega", - "speech2text", - "git", - "prophetnet", - "seamlessm4t", - "clvp", - ] - ): - self.skipTest("May fix in the future: need model-specific fixes") - - config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) - if not hasattr(config, "use_cache"): - self.skipTest("This model doesn't support caching") - - config.use_cache = True - config.is_decoder = True - model = model_class(config).to(torch_device).eval() - model_forward_args = inspect.signature(model.forward).parameters - if "position_ids" not in model_forward_args: - self.skipTest("This model doesn't use `position_ids`") - - generation_kwargs = { - "eos_token_id": -1, - "max_new_tokens": 4, - "num_beams": 1, - "do_sample": False, - } - - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) - output_greedy = model.generate( - input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs - ) - - assistant_model = model 
- assistant_model.generation_config.num_assistant_tokens = 2 - assistant_model.generation_config.num_assistant_tokens_schedule = "constant" - generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate( - input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs - ) - - # The two outputs must match and their shape must be as expected - self.assertListEqual(output_greedy.tolist(), output_assisted.tolist()) - @is_flaky() def test_prompt_lookup_decoding_matches_greedy_search(self): # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. @@ -1678,9 +1641,9 @@ def test_generate_with_and_without_position_ids(self): out_wo_positions = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=5) # infer position ids from attn mask and generate again - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) out_w_positions = model.generate( input_ids, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 ) @@ -1706,9 +1669,9 @@ def test_generate_with_and_without_position_ids_inputs_embeds(self): ) # infer position ids from attn mask and generate again - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - position_ids = position_ids[..., -input_ids.shape[-1] :].view(-1, input_ids.shape[-1]) + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) out_w_positions = model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 ) From 0f20d92703afcc81831b8a6357c4d0ce9cb6a42f Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 25 Apr 2024 16:51:15 +0200 Subject: [PATCH 14/20] fix failing flax cases --- src/transformers/generation/utils.py | 1 + src/transformers/modeling_flax_utils.py | 1 + src/transformers/modeling_utils.py | 2 +- .../encoder_decoder/modeling_encoder_decoder.py | 1 + .../modeling_flax_encoder_decoder.py | 4 ++-- src/transformers/models/phi3/modeling_phi3.py | 17 +++++++++++------ .../modeling_flax_speech_encoder_decoder.py | 8 ++++---- .../modeling_flax_vision_encoder_decoder.py | 8 ++++---- tests/generation/test_utils.py | 2 +- .../test_modeling_flax_encoder_decoder.py | 1 + tests/models/gemma/test_modeling_flax_gemma.py | 3 +++ tests/models/gpt2/test_modeling_flax_gpt2.py | 3 +++ .../gpt_neo/test_modeling_flax_gpt_neo.py | 3 +++ tests/models/gptj/test_modeling_flax_gptj.py | 3 +++ .../mistral/test_modeling_flax_mistral.py | 3 +++ 15 files changed, 42 insertions(+), 18 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index da41732614f67f..ba6b77512f7546 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4690,6 +4690,7 @@ def _assisted_decoding( candidate_kwargs = _prepare_attention_mask( candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) + candidate_kwargs = _prepare_position_ids(candidate_kwargs, candidate_input_ids.shape[1]) candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) 
if "cache_position" in candidate_kwargs: candidate_kwargs["cache_position"] = torch.cat( diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 8c213685648069..d08c080c0ce418 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1256,6 +1256,7 @@ def get_position_ids_from_attention_mask(self, attention_mask, batch_size, seq_l if attention_mask is not None: position_ids = jnp.cumsum(attention_mask, axis=-1) - 1 position_ids = jnp.where(attention_mask == 0, 1, position_ids) + position_ids = position_ids[..., -seq_length:] else: position_ids = jnp.broadcast_to(jnp.arange(seq_length)[None, :], (batch_size, seq_length)) return position_ids diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 6613df79d1b2d7..f5f80e9675490f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4375,7 +4375,7 @@ def get_position_ids_from_attention_mask(self, attention_mask, past_length, seq_ """ if attention_mask is not None: position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.masked_fill(attention_mask == 0, 1) position_ids = position_ids[..., -seq_length:].view(-1, seq_length) else: position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 16248fee64ce59..cc1283004b82d9 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -593,6 +593,7 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } + print(self.encoder) if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index 2ae1b2ccac8ea5..ad0f99116d180c 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -697,8 +697,8 @@ def __call__( decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f9364d130b7e6c..e9ee2f95296d82 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1317,6 +1317,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1347,12 +1348,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if 
attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index e3bbd86266ea11..8063c30c92c07b 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -592,8 +592,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed @@ -719,8 +719,8 @@ def __call__( decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 987c9a1afa3d19..0e0091329411ce 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -531,8 +531,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed @@ -665,8 +665,8 @@ def __call__( decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d459d2514f8588..530a45025079af 100644 
--- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1160,7 +1160,7 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type, use_posit assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) - + # test that the output is correct if user pass in position ids into `generate` vs it's calculated internally if use_position_ids: model_forward_args = inspect.signature(model.forward).parameters diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index c8f76a144be703..eb1eef1b56d4ed 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -305,6 +305,7 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): + print("PASSED") self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) # PT -> Flax diff --git a/tests/models/gemma/test_modeling_flax_gemma.py b/tests/models/gemma/test_modeling_flax_gemma.py index 0f3c5df4f13622..d0bd7d2c6a8cd5 100644 --- a/tests/models/gemma/test_modeling_flax_gemma.py +++ b/tests/models/gemma/test_modeling_flax_gemma.py @@ -143,6 +143,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index 820b23bd70cfc5..c724a1a8a7d74f 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -158,6 +158,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py index ca41495a842c77..6d934ee0ea16d5 100644 --- a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py +++ b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py @@ -150,6 +150,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/gptj/test_modeling_flax_gptj.py b/tests/models/gptj/test_modeling_flax_gptj.py index 
aa3b7a99aa0fdf..5ae55357eb837b 100644 --- a/tests/models/gptj/test_modeling_flax_gptj.py +++ b/tests/models/gptj/test_modeling_flax_gptj.py @@ -147,6 +147,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/mistral/test_modeling_flax_mistral.py b/tests/models/mistral/test_modeling_flax_mistral.py index 047bf4c6d433d0..6bb69eadf40e1b 100644 --- a/tests/models/mistral/test_modeling_flax_mistral.py +++ b/tests/models/mistral/test_modeling_flax_mistral.py @@ -154,6 +154,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, From e749080edb8030670d62d7aabfe1193bf4864510 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 25 Apr 2024 18:52:49 +0200 Subject: [PATCH 15/20] torch tests fail due to merge conflicts? --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 530a45025079af..fab11db8a32f08 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1611,7 +1611,7 @@ def test_generate_continue_from_past_key_values(self): def test_generate_with_and_without_position_ids(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() model_forward_args = inspect.signature(model.forward).parameters if "position_ids" not in model_forward_args: @@ -1633,7 +1633,7 @@ def test_generate_with_and_without_position_ids(self): def test_generate_with_and_without_position_ids_inputs_embeds(self): for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, _ = self._get_input_ids_and_config() + config, input_ids, attention_mask = self._get_input_ids_and_config() model = model_class(config).to(torch_device).eval() model_forward_args = inspect.signature(model.forward).parameters if "position_ids" not in model_forward_args: From ef2494ef57b33fc13db2a70baf2b0cd8acf68a40 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Apr 2024 10:33:20 +0200 Subject: [PATCH 16/20] let the tests pass --- src/transformers/models/codegen/modeling_codegen.py | 9 +++++++-- src/transformers/models/ctrl/modeling_tf_ctrl.py | 11 +++++++---- src/transformers/models/falcon/modeling_falcon.py | 8 ++++++-- src/transformers/models/gpt2/modeling_gpt2.py | 8 ++++++-- src/transformers/models/gpt2/modeling_tf_gpt2.py | 9 ++++++--- src/transformers/models/gptj/modeling_tf_gptj.py | 11 +++++++---- .../test_modeling_flax_encoder_decoder.py | 10 +++++++++- 7 files changed, 48 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/codegen/modeling_codegen.py 
b/src/transformers/models/codegen/modeling_codegen.py index 56f14f2581a06f..d3e61286dc4461 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -613,12 +613,17 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=input_ids.shape[-1], device=input_ids.device + attention_mask, past_length, seq_length=seq_length, device=device ) else: - position_ids = position_ids[:, -input_ids.shape[-1] :] + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 2a3d003fb8f692..339f1e36e4a83e 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -345,8 +345,10 @@ def call( past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: if attention_mask is not None: - position_ids = tf.cumsum(attention_mask, axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) position_ids = position_ids[..., -input_shape[-1] :] position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) else: @@ -707,8 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(attention_mask, axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index dd4afd98287492..f8c3c63d20a296 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1228,12 +1228,16 @@ def prepare_inputs_for_generation( # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. 
if not self.transformer.use_alibi: + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=input_ids.shape[-1], device=input_ids.device + attention_mask, past_length, seq_length=seq_length, device=device ) else: - position_ids = position_ids[:, -input_ids.shape[-1] :] + position_ids = position_ids[:, -seq_length:] if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 95d1e92909be22..7f26b42b703c79 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1456,12 +1456,16 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = self.get_position_ids_from_attention_mask( - attention_mask, past_length, seq_length=input_ids.shape[1], device=input_ids.device + attention_mask, past_length, seq_length=seq_length, device=device ) else: - position_ids = position_ids[:, -input_ids.shape[1] :] + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index d19f8ef2f165e9..de6b3da1335331 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -432,7 +432,9 @@ def call( if position_ids is None: if attention_mask is not None: position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) position_ids = position_ids[..., -input_shape[-1] :] position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) else: @@ -866,8 +868,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(attention_mask, axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index 826cd9d187c86d..4b259492ae2d0c 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -447,8 
+447,10 @@ def call( if position_ids is None: if attention_mask is not None: - position_ids = tf.cumsum(attention_mask, axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) position_ids = position_ids[..., -input_shape[-1] :] position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) else: @@ -777,8 +779,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.cumsum(attention_mask, axis=-1) - 1 - position_ids = tf.where(attention_mask == 0, 1, position_ids) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index eb1eef1b56d4ed..fbbe68db32f04a 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -17,6 +17,7 @@ import tempfile import unittest +import jax.numpy as jnp import numpy as np from transformers import is_flax_available, is_torch_available @@ -299,13 +300,20 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): flax_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + # when model weights are random init masking with attn_mask still leads to logits + # mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models + # when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch) + # and as arange in flax. 
That's why we init attn mask with all `1` + if "decoder_attention_mask" in pt_inputs: + pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["attention_mask"]) + inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["attention_mask"]) + with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() fx_outputs = fx_model(**inputs_dict).to_tuple() self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") for fx_output, pt_output in zip(fx_outputs, pt_outputs): - print("PASSED") self.assert_almost_equals(fx_output, pt_output.numpy(), 1e-5) # PT -> Flax From d5f5989d71b6e2a6c8afb28cc7b79572d3416967 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Apr 2024 12:06:36 +0200 Subject: [PATCH 17/20] import if available --- .../encoder_decoder/test_modeling_flax_encoder_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index fbbe68db32f04a..a68f9d5defe8a5 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -17,7 +17,6 @@ import tempfile import unittest -import jax.numpy as jnp import numpy as np from transformers import is_flax_available, is_torch_available @@ -43,6 +42,7 @@ convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, ) + import jax.numpy as jnp if is_torch_available(): import torch From cd10c7369502789e2327c823d51bb80c23a77a31 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Apr 2024 13:04:45 +0200 Subject: [PATCH 18/20] fixes --- .../encoder_decoder/modeling_flax_encoder_decoder.py | 8 ++++---- .../modeling_flax_speech_encoder_decoder.py | 8 ++++---- .../encoder_decoder/test_modeling_flax_encoder_decoder.py | 3 ++- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index ad0f99116d180c..eae2feef268a79 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -560,8 +560,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = self.get_position_ids_from_attention_mask( - decoder_attention_mask, batch_size, sequence_length + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed @@ -697,8 +697,8 @@ def __call__( decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = self.get_position_ids_from_attention_mask( - decoder_attention_mask, batch_size, sequence_length + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed diff --git a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py index 8063c30c92c07b..e3bbd86266ea11 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py +++ 
b/src/transformers/models/speech_encoder_decoder/modeling_flax_speech_encoder_decoder.py @@ -592,8 +592,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = self.get_position_ids_from_attention_mask( - decoder_attention_mask, batch_size, sequence_length + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed @@ -719,8 +719,8 @@ def __call__( decoder_attention_mask = jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = self.get_position_ids_from_attention_mask( - decoder_attention_mask, batch_size, sequence_length + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) # Handle any PRNG if needed diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index a68f9d5defe8a5..2322f553777b36 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -29,6 +29,8 @@ if is_flax_available(): + import jax.numpy as jnp + from transformers import ( AutoTokenizer, EncoderDecoderConfig, @@ -42,7 +44,6 @@ convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, ) - import jax.numpy as jnp if is_torch_available(): import torch From 0f1997c201ffe421293959234e6717b42f2957a3 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Apr 2024 13:32:17 +0200 Subject: [PATCH 19/20] encoder-decoder models --- .../encoder_decoder/test_modeling_flax_encoder_decoder.py | 4 ++-- .../test_modeling_flax_speech_encoder_decoder.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index 2322f553777b36..84fd65711e34b4 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -306,8 +306,8 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): # when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch) # and as arange in flax. 
That's why we init attn mask with all `1` if "decoder_attention_mask" in pt_inputs: - pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["attention_mask"]) - inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["attention_mask"]) + pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) + inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() diff --git a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 62ce0d660a0abc..5e66beb215e85f 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -414,6 +414,14 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): flax_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + # when model weights are random init masking with attn_mask still leads to logits + # mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models + # when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch) + # and as arange in flax. That's why we init attn mask with all `1` + if "decoder_attention_mask" in pt_inputs: + pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) + inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) + with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() From 87befb7f6ee11b20add35b5966bebd7c8fbd80a0 Mon Sep 17 00:00:00 2001 From: raushan Date: Mon, 29 Apr 2024 14:11:52 +0200 Subject: [PATCH 20/20] fix llama flax --- tests/models/llama/test_modeling_flax_llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/models/llama/test_modeling_flax_llama.py b/tests/models/llama/test_modeling_flax_llama.py index a81398786b43b6..f7d80e3d197ac9 100644 --- a/tests/models/llama/test_modeling_flax_llama.py +++ b/tests/models/llama/test_modeling_flax_llama.py @@ -149,8 +149,8 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input ) past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) - position_ids = jnp.broadcast_to( - jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) + position_ids = model.get_position_ids_from_attention_mask( + attention_mask_cache, input_ids.shape[0], input_ids.shape[-1] - 1 ) outputs_cache = model( @@ -159,7 +159,6 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input past_key_values=past_key_values, position_ids=position_ids, ) - position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") outputs_cache_next = model( input_ids[:, -1:], past_key_values=outputs_cache.past_key_values,
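
The PT/Flax mismatch these tests keep working around comes down to two different ways of building decoder position ids: the Torch/TF paths derive them from the attention mask with a cumsum (see the tf.cumsum/tf.where change earlier in this series), while the Flax paths broadcast a plain arange and ignore the mask. A minimal sketch of that difference, using hypothetical helper names rather than the library's actual functions:

```python
import torch

def positions_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Mask-aware scheme: cumulative sum of the mask minus one, with padded
    # positions clamped to 1 (mirrors the cumsum/where pattern in the diffs).
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

def positions_from_arange(attention_mask: torch.Tensor) -> torch.Tensor:
    # Flax-style scheme: a plain arange broadcast over the batch; the mask is ignored.
    batch_size, seq_length = attention_mask.shape
    return torch.arange(seq_length).unsqueeze(0).expand(batch_size, seq_length)

# With padding present the two schemes disagree, so randomly initialized
# encoder-decoder models produce different logits across frameworks.
padded = torch.tensor([[0, 0, 1, 1, 1]])
print(positions_from_mask(padded))    # tensor([[1, 1, 0, 1, 2]])
print(positions_from_arange(padded))  # tensor([[0, 1, 2, 3, 4]])

# With an all-ones mask they coincide, which is why the equivalence tests
# above force decoder_attention_mask to ones before comparing PT and Flax.
ones = torch.ones(1, 5, dtype=torch.long)
assert torch.equal(positions_from_mask(ones), positions_from_arange(ones))
```

The two schemes only agree when the mask is all ones, which is why the equivalence tests above force `decoder_attention_mask` to ones, and why the final patch points the Llama Flax cache test at the mask-aware helper instead of a hand-built arange.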