diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 6bd55c5f6b5109..c8967732d055c7 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -190,6 +190,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder ) self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + self.assistant_kwargs = _prepare_position_ids(self.assistant_kwargs, new_cur_len) # 2. Forecast next N tokens using the assistant model. assistant_generation_kwargs = { @@ -423,3 +424,18 @@ def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Di token_type_copies = final_token_type.repeat(1, type_length_diff) model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) return model_kwargs + + +def _prepare_position_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + position_ids = model_kwargs.get("position_ids") + if position_ids is None: + return model_kwargs + + # we assume batch_size=1 for assisted decoding (needs rework if bs > 1) + length_diff = new_length - position_ids[0, -1] + if length_diff < 0: + model_kwargs["position_ids"] = position_ids[:, :length_diff] + elif length_diff > 0: + new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0) + model_kwargs["position_ids"] = torch.cat([model_kwargs["position_ids"], new_position_ids], dim=-1) + return model_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e6a58d3e5a560..ba6b77512f7546 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -43,6 +43,7 @@ PromptLookupCandidateGenerator, _crop_past_key_values, _prepare_attention_mask, + _prepare_position_ids, _prepare_token_type_ids, ) from .configuration_utils import GenerationConfig, GenerationMode @@ -674,6 +675,10 @@ def _update_model_kwargs_for_generation( if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=-1) + return model_kwargs def _reorder_cache(self, past_key_values, beam_idx): @@ -4685,6 +4690,7 @@ def _assisted_decoding( candidate_kwargs = _prepare_attention_mask( candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder ) + candidate_kwargs = _prepare_position_ids(candidate_kwargs, candidate_input_ids.shape[1]) candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) if "cache_position" in candidate_kwargs: candidate_kwargs["cache_position"] = torch.cat( diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index da373603420ba2..d08c080c0ce418 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -1248,6 +1248,19 @@ def register_for_auto_class(cls, auto_class="FlaxAutoModel"): cls._auto_class = auto_class + def get_position_ids_from_attention_mask(self, attention_mask, batch_size, seq_length): + """ + Tries to infer position ids given attention mask and
past kv cache length. All instances when + `position_ids=None` should call this method. + """ + if attention_mask is not None: + position_ids = jnp.cumsum(attention_mask, axis=-1) - 1 + position_ids = jnp.where(attention_mask == 0, 1, position_ids) + position_ids = position_ids[..., -seq_length:] + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length)[None, :], (batch_size, seq_length)) + return position_ids + # To update the docstring, we need to copy the method, otherwise we change the original docstring. FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index be164e8e2c0c00..f5f80e9675490f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4368,6 +4368,20 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): logger.warning_once(warn_string) + def get_position_ids_from_attention_mask(self, attention_mask, past_length, seq_length, device): + """ + Tries to infer position ids given attention mask and past kv cache length. All instances when + `position_ids=None` should call this method. + """ + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids = position_ids.masked_fill(attention_mask == 0, 1) + position_ids = position_ids[..., -seq_length:].view(-1, seq_length) + else: + position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + return position_ids + @property def _is_quantized_training_enabled(self): warnings.warn( diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index c14e33bd1261dd..d3e61286dc4461 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -455,7 +455,8 @@ def forward( else: raise ValueError("You have to specify either input_ids or inputs_embeds") - device = input_ids.device if input_ids is not None else inputs_embeds.device + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) @@ -467,8 +468,9 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) # Attention mask. 
if attention_mask is not None: @@ -496,9 +498,6 @@ def forward( # head_mask has shape n_layer x batch x num_attention_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - hidden_states = inputs_embeds if token_type_ids is not None: @@ -597,6 +596,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -614,12 +614,16 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bb4c0516928c..60487b538a3c9f 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -909,7 +909,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) @@ -1227,12 +1229,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 7534a0e50c9a23..11d95adfaee19a 100644 
--- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -412,9 +412,11 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[1], device=device + ) # Attention mask. if attention_mask is not None: @@ -525,6 +527,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): # only last tokens for inputs_ids if past is defined in kwargs + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -537,6 +540,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac input_ids = input_ids[:, remove_prefix_length:] + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_ids.shape[1], device=input_ids.device + ) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} @add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING) diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 6569b9e7d7b788..339f1e36e4a83e 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -344,8 +344,15 @@ def call( else: past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0) - position_ids = tf.tile(position_ids, [input_shape[0], 1]) + if attention_mask is not None: + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) # Attention mask. 
if attention_mask is not None: @@ -702,7 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 99b865c773f81d..2569f23c9028e3 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1145,7 +1145,10 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # embed positions @@ -1470,12 +1473,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] if self.generation_config.cache_implementation == "static": # generation with static cache diff --git a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py index beecd080328e16..eae2feef268a79 100644 --- a/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_flax_encoder_decoder.py @@ -737,12 +737,10 @@ def prepare_inputs_for_generation( # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 extended_attention_mask =
lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) - else: - decoder_position_ids = jnp.broadcast_to( - jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) - ) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, seq_length + ) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 1f4fd41afa2e89..f8c3c63d20a296 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1077,11 +1077,9 @@ def forward( else: alibi = None if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0) if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -1215,6 +1213,7 @@ def prepare_inputs_for_generation( inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> dict: + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -1228,12 +1227,17 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. - if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if not self.transformer.use_alibi: + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 8e9a41954aee9c..2a6e1754db293d 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -279,13 +279,6 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: @@ -301,6 +294,11 @@ def forward( image_patch_input_indices=image_patches_indices, ) + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, 
past_key_values_length, seq_length=seq_length, device=inputs_embeds.device + ) + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, @@ -325,16 +323,34 @@ def prepare_inputs_for_generation( image_patches_indices=None, **kwargs, ): - if past_key_values: - input_ids = input_ids[:, -1:] + # Omit tokens covered by past_key_values + past_length = 0 + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gemma/modeling_flax_gemma.py b/src/transformers/models/gemma/modeling_flax_gemma.py index 235f65680fad3e..d2e527fce0fef9 100644 --- a/src/transformers/models/gemma/modeling_flax_gemma.py +++ b/src/transformers/models/gemma/modeling_flax_gemma.py @@ -504,7 +504,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -747,10 +747,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": 
past_key_values, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a53..45a088f9a0d676 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -889,7 +889,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) @@ -1209,12 +1211,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index c3ef377642a3c5..5813f1249923f3 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -485,7 +485,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -752,12 +752,10 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice( extended_attention_mask, attention_mask.astype("i4"), (0, 0) ) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index c44d27a23c5d05..7f26b42b703c79 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1021,9 +1021,11 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if position_ids is None: - position_ids 
= torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[-1], device=device + ) # Attention mask. if attention_mask is not None: @@ -1227,6 +1229,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1244,14 +1247,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -1433,6 +1438,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1450,14 +1456,16 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index 26a4e7a398ae8d..de6b3da1335331 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -430,7 +430,15 @@ def call( past_length = shape_list(past_key_values[0][0])[-2] if 
position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if attention_mask is not None: + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. @@ -860,7 +868,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d61877cb1f1e7e..98b3ebd6e5463c 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -979,15 +979,10 @@ def forward( else: past_length = past_key_values[0].size(-2) - if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_length > 0: - position_ids = position_ids[:, past_length : input_shape[-1] + past_length :] - elif position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[-1], device=device + ) # Self-attention mask. 
query_length = input_shape[-1] @@ -1163,6 +1158,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: if self.config.multi_query: past_length = past_key_values[0].shape[1] @@ -1182,15 +1178,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py index 5639ca50f166a2..5233439f819800 100644 --- a/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_flax_gpt_neo.py @@ -424,7 +424,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -664,10 +664,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 2fbf4677ca6f44..2a9c60cf610182 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -784,8 +784,9 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + 
attention_mask, past_length, seq_length=input_shape[-1], device=device + ) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head @@ -900,6 +901,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -916,13 +918,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 83c99202ac9379..19db8f7ccdf29c 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -850,6 +850,8 @@ def forward( raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape + if inputs_embeds is None: + inputs_embeds = self.embed_in(input_ids) if past_key_values is None: past_length = 0 @@ -858,9 +860,9 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=inputs_embeds.device + ) # Attention mask. 
if attention_mask is not None: @@ -890,10 +892,6 @@ def forward( # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - if inputs_embeds is None: - inputs_embeds = self.embed_in(input_ids) - hidden_states = self.emb_dropout(inputs_embeds) if self.gradient_checkpointing and self.training: @@ -1074,6 +1072,7 @@ def prepare_inputs_for_generation( ): input_shape = input_ids.shape # cut decoder_input_ids if past is used + past_length = 0 if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -1087,12 +1086,16 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: diff --git a/src/transformers/models/gptj/modeling_flax_gptj.py b/src/transformers/models/gptj/modeling_flax_gptj.py index 9f0d4d6e860003..5e690127dc25de 100644 --- a/src/transformers/models/gptj/modeling_flax_gptj.py +++ b/src/transformers/models/gptj/modeling_flax_gptj.py @@ -458,7 +458,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -693,10 +693,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 3c6ddac4ecf4ca..6469fc6b345994 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -870,8 +870,9 @@ def forward( past_length = past_key_values[0][0].size(-2) if position_ids is None: - position_ids = 
torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[-1], device=device + ) if not self._use_flash_attention_2: # Attention mask. @@ -1052,6 +1053,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -1068,13 +1070,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_ attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gptj/modeling_tf_gptj.py b/src/transformers/models/gptj/modeling_tf_gptj.py index 5c315b5b66f049..4b259492ae2d0c 100644 --- a/src/transformers/models/gptj/modeling_tf_gptj.py +++ b/src/transformers/models/gptj/modeling_tf_gptj.py @@ -446,7 +446,15 @@ def call( past_length = shape_list(past_key_values[0][0])[-2] if position_ids is None: - position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) + if attention_mask is not None: + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + # create ones tensor to match dtypes, otherwise we get errors + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) + position_ids = position_ids[..., -input_shape[-1] :] + position_ids = tf.reshape(position_ids, (-1, input_shape[-1])) + else: + position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0) if attention_mask is not None: # We create a 3D attention mask from a 2D tensor mask. 
@@ -771,7 +779,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache= attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None and position_ids is None: - position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True) + position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1 + ones_tensor = tf.ones_like(position_ids, dtype=tf.int64) + position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids) if past_key_values: position_ids = tf.expand_dims(position_ids[:, -1], -1) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index a01c2279c15586..c30035ad632ddb 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -156,10 +156,6 @@ def expand_inputs_for_generation( model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None) model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None) - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) - if attention_mask is not None: model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) @@ -184,45 +180,6 @@ def expand_inputs_for_generation( return input_ids, model_kwargs -def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - pixel_values = kwargs.get("pixel_values", None) - image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) - perceiver_embeddings = kwargs.get("perceiver_embeddings", None) - image_attention_mask = kwargs.get("image_attention_mask", None) - interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - "pixel_values": pixel_values, - "image_encoder_embeddings": image_encoder_embeddings, - "perceiver_embeddings": perceiver_embeddings, - "image_attention_mask": image_attention_mask, - "interpolate_pos_encoding": interpolate_pos_encoding, - } - - def freeze_model(model, module_exceptions=[]): mapping = { "LayerNorm": nn.LayerNorm, @@ -1155,15 +1112,10 @@ def forward( past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - elif position_ids is None: - position_ids = torch.arange( - 
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=device ) - position_ids = position_ids.unsqueeze(0) if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2: raise ValueError( @@ -1536,7 +1488,7 @@ def forward( image_hidden_states=outputs.image_hidden_states, ) - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): image_hidden_states = kwargs.pop("image_hidden_states", None) if image_hidden_states is not None: if self.config.use_resampler: @@ -1544,11 +1496,53 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): else: kwargs["image_encoder_embeddings"] = image_hidden_states kwargs["pixel_values"] = None - inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) - unwanted_kwargs = ["token_type_ids"] - for kwarg in unwanted_kwargs: - inputs.pop(kwarg, None) - return inputs + + # only last token for inputs_ids if past is defined in kwargs + attention_mask = kwargs.get("attention_mask", None) + past_length = 0 + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 
+ + position_ids = kwargs.get("position_ids", None) + + seq_length = input_ids.shape[-1] + if position_ids is None: + device = input_ids.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] + + pixel_values = kwargs.get("pixel_values", None) + image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) + perceiver_embeddings = kwargs.get("perceiver_embeddings", None) + image_attention_mask = kwargs.get("image_attention_mask", None) + interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_encoder_embeddings": image_encoder_embeddings, + "perceiver_embeddings": perceiver_embeddings, + "image_attention_mask": image_attention_mask, + "interpolate_pos_encoding": interpolate_pos_encoding, + } @staticmethod def _expand_inputs_for_generation( diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 81b41078633aa9..1db535d28fc742 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -728,9 +728,11 @@ def forward( past_key_values = tuple([None] * len(self.h)) else: past_length = past_key_values[0][0].size(-2) + if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=input_shape[-1], device=device + ) # ImageGPTAttention mask. 
if attention_mask is not None: @@ -900,6 +902,7 @@ def set_output_embeddings(self, new_embeddings): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # Omit tokens covered by past_key_values + past_length = 0 if past_key_values: past_length = past_key_values[0][0].shape[2] @@ -917,14 +920,15 @@ def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = input_ids.shape[-1] + if position_ids is None: + device = input_ids.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) else: - position_ids = None + position_ids = position_ids[:, -seq_length:] + return { "input_ids": input_ids, "past_key_values": past_key_values, diff --git a/src/transformers/models/llama/modeling_flax_llama.py b/src/transformers/models/llama/modeling_flax_llama.py index 1e7fb5d1a1cd94..06509f361a0601 100644 --- a/src/transformers/models/llama/modeling_flax_llama.py +++ b/src/transformers/models/llama/modeling_flax_llama.py @@ -492,7 +492,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -723,11 +723,9 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, "attention_mask": extended_attention_mask, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a63d..ec83d294caea2e 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -987,7 +987,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) @@ -1305,12 +1307,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if 
attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 4cf5d98f77f114..ad94a8fdf166de 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -518,6 +518,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -544,12 +545,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 155d9e3e6abf40..0989713d6f70fa 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -642,6 +642,7 @@ def prepare_inputs_for_generation( attention_mask=None, **kwargs, ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -668,12 +669,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None 
and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mistral/modeling_flax_mistral.py b/src/transformers/models/mistral/modeling_flax_mistral.py index 3480fc7214a84a..4bdbb6c0d1a66a 100644 --- a/src/transformers/models/mistral/modeling_flax_mistral.py +++ b/src/transformers/models/mistral/modeling_flax_mistral.py @@ -480,7 +480,7 @@ def __call__( if past_key_values is not None: raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, sequence_length) if attention_mask is None: attention_mask = jnp.ones((batch_size, sequence_length)) @@ -715,10 +715,8 @@ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: O # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if attention_mask is not None: - position_ids = attention_mask.cumsum(axis=-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) - else: - position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + position_ids = self.get_position_ids_from_attention_mask(attention_mask, batch_size, seq_length) return { "past_key_values": past_key_values, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c013967c78f116..c27439ebdfb64e 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -974,18 +974,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1200,6 +1198,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = 
past_key_values.get_seq_length() @@ -1230,12 +1229,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c78e907d5fdbb9..aab843a9dfa392 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1156,18 +1156,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1423,6 +1421,7 @@ def prepare_inputs_for_generation( **kwargs, ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1453,12 +1452,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 
1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index e3b0e05127c52d..6f8a6b2dfd5e31 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -969,7 +969,9 @@ def forward( ) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device + ) causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) @@ -1286,12 +1288,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c83ba413952b16..a044b0638e9baf 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -622,15 +622,14 @@ def forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) # embed positions if attention_mask is None: attention_mask = torch.ones( @@ -826,6 +825,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -856,12 +856,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - 
position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 13719166edf9d9..b4188425e70c14 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -995,18 +995,15 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self.embed_dropout(inputs_embeds) + if position_ids is None: + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device + ) + # Attention mask. if self._use_flash_attention_2: # 2d mask is passed through the layers @@ -1211,6 +1208,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1241,12 +1239,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f9364d130b7e6c..e9ee2f95296d82 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1317,6 +1317,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): 
cache_length = past_key_values.get_seq_length() @@ -1347,12 +1348,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index b5a1370ae1fc8f..d847aae360cefb 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -985,18 +985,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1211,6 +1209,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1241,12 +1240,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, 
-seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 70072c91720a57..c4edd35c2dc271 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1148,18 +1148,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1408,6 +1406,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1438,12 +1437,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 3262f2cd3c6117..9dd08235e59fa4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -987,16 +987,14 @@ def forward( past_key_values_length = past_key_values.get_usable_length(seq_length) seq_length_with_past = seq_length_with_past + past_key_values_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + 
past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None @@ -1198,6 +1196,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1228,12 +1227,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ca4c8af23304f9..f680fbf233cee0 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -963,18 +963,16 @@ def forward( past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_usable_length(seq_length) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_key_values_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1191,6 +1189,7 @@ def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): # Omit tokens covered by past_key_values + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -1221,12 +1220,16 @@ def 
prepare_inputs_for_generation( attention_mask = attention_mask[:, -max_cache_length:] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 1b20353410c895..edc2181a6bc4bf 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -512,6 +512,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs ): + past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() @@ -538,12 +539,16 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + seq_length = ( + inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1] + ) + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = self.get_position_ids_from_attention_mask( + attention_mask, past_length, seq_length=seq_length, device=device + ) + else: + position_ids = position_ids[:, -seq_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 987c9a1afa3d19..0e0091329411ce 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -531,8 +531,8 @@ def decode( if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed @@ -665,8 +665,8 @@ def __call__( decoder_attention_mask = 
jnp.ones_like(decoder_input_ids) if decoder_position_ids is None: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + decoder_position_ids = self.get_position_ids_from_attention_mask( + decoder_attention_mask, batch_size, sequence_length ) # Handle any PRNG if needed diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index eacba9ebc6f4a5..fab11db8a32f08 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1091,9 +1091,9 @@ def test_beam_search_low_memory(self): ) self.assertListEqual(low_output.tolist(), high_output.tolist()) - @parameterized.expand([("random",), ("same",)]) + @parameterized.expand([("random", True), ("same", True), ("random", False), ("same", False)]) @is_flaky() # Read NOTE (1) below. If there are API issues, all attempts will fail. - def test_assisted_decoding_matches_greedy_search(self, assistant_type): + def test_assisted_decoding_matches_greedy_search(self, assistant_type, use_position_ids): # This test ensures that the assisted generation does not introduce output changes over greedy search. # NOTE (1): The sentence above is true most of the time, there is a tiny difference in the logits due to matmul # shape differences -- and it may result in a different output. The input shape difference happens in the @@ -1150,7 +1150,6 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): "output_attentions": True, "return_dict_in_generate": True, } - output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) # test with the same assistant model or randomly init one # in the first case all candidate tokens are accepted, in the second none is accepted @@ -1161,8 +1160,31 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): assistant_model = model assistant_model.generation_config.num_assistant_tokens = 2 # see b) assistant_model.generation_config.num_assistant_tokens_schedule = "constant" # see b) - generation_kwargs.update({"assistant_model": assistant_model}) - output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + # test that the output is correct whether the user passes position ids into `generate` or they are computed internally + if use_position_ids: + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) + output_greedy = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, **generation_kwargs + ) + output_assisted = model.generate( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + assistant_model=assistant_model, + **generation_kwargs, + ) + else: + output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + output_assisted = model.generate( + input_ids, attention_mask=attention_mask, assistant_model=assistant_model, **generation_kwargs + ) # The two outputs must match and their shape must be as expected self.assertListEqual(output_greedy.sequences.tolist(), output_assisted.sequences.tolist()) @@ -1587,6 +1609,56 @@ def test_generate_continue_from_past_key_values(self): ) ) + def 
test_generate_with_and_without_position_ids(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + out_wo_positions = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=5) + + # infer position ids from attn mask and generate again + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) + out_w_positions = model.generate( + input_ids, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 + ) + + # The two sets of generated sequences must match if generate can infer position ids correctly + and can continue adding new ids to the already passed position ids + self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist()) + + def test_generate_with_and_without_position_ids_inputs_embeds(self): + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask = self._get_input_ids_and_config() + model = model_class(config).to(torch_device).eval() + model_forward_args = inspect.signature(model.forward).parameters + if "position_ids" not in model_forward_args: + self.skipTest("This model doesn't use `position_ids`") + + if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys(): + self.skipTest("This model doesn't use `inputs_embeds`") + + inputs_embeds = model.get_input_embeddings()(input_ids) + out_wo_positions = model.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=5 + ) + + # infer position ids from attn mask and generate again + position_ids = model.get_position_ids_from_attention_mask( + attention_mask, past_length=0, seq_length=input_ids.shape[-1], device=input_ids.device + ) + out_w_positions = model.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, max_new_tokens=5 + ) + + # The two sets of generated sequences must match if generate can infer position ids correctly + and can continue adding new ids to the already passed position ids + self.assertListEqual(out_wo_positions.tolist(), out_w_positions.tolist()) + @parameterized.expand([(1, False), (1, True), (4, False)]) def test_new_cache_format(self, num_beams, do_sample): # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). diff --git a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py index c8f76a144be703..84fd65711e34b4 100644 --- a/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_flax_encoder_decoder.py @@ -29,6 +29,8 @@ if is_flax_available(): + import jax.numpy as jnp + from transformers import ( AutoTokenizer, EncoderDecoderConfig, @@ -299,6 +301,14 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): flax_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + # when model weights are randomly initialized, masking with attn_mask still leads to a logits + mismatch, which does not happen with pre-trained models. 
That causes errors in encoder-decoder models + when a decoder-only model (GPT2) is used as the backbone, because GPT prepares position ids from the attn mask (in torch) + but as a plain arange in flax. That's why we init the attn mask with all `1`s + if "decoder_attention_mask" in pt_inputs: + pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) + inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) + with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple() diff --git a/tests/models/gemma/test_modeling_flax_gemma.py b/tests/models/gemma/test_modeling_flax_gemma.py index 0f3c5df4f13622..d0bd7d2c6a8cd5 100644 --- a/tests/models/gemma/test_modeling_flax_gemma.py +++ b/tests/models/gemma/test_modeling_flax_gemma.py @@ -143,6 +143,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/gpt2/test_modeling_flax_gpt2.py b/tests/models/gpt2/test_modeling_flax_gpt2.py index fbf2d6c333fd8a..c724a1a8a7d74f 100644 --- a/tests/models/gpt2/test_modeling_flax_gpt2.py +++ b/tests/models/gpt2/test_modeling_flax_gpt2.py @@ -158,6 +158,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, @@ -280,6 +283,7 @@ def test_equivalence_pt_to_flax(self): pt_inputs["attention_mask"][batch_idx, start_index:] = 1 prepared_inputs_dict["attention_mask"][batch_idx, :start_index] = 0 prepared_inputs_dict["attention_mask"][batch_idx, start_index:] = 1 + pt_model = pt_model_class(config).eval() fx_model = model_class(config, dtype=jnp.float32) diff --git a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py index ca41495a842c77..6d934ee0ea16d5 100644 --- a/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py +++ b/tests/models/gpt_neo/test_modeling_flax_gpt_neo.py @@ -150,6 +150,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/gptj/test_modeling_flax_gptj.py b/tests/models/gptj/test_modeling_flax_gptj.py index aa3b7a99aa0fdf..5ae55357eb837b 100644 --- a/tests/models/gptj/test_modeling_flax_gptj.py +++ b/tests/models/gptj/test_modeling_flax_gptj.py @@ -147,6 +147,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = 
jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/llama/test_modeling_flax_llama.py b/tests/models/llama/test_modeling_flax_llama.py index a81398786b43b6..f7d80e3d197ac9 100644 --- a/tests/models/llama/test_modeling_flax_llama.py +++ b/tests/models/llama/test_modeling_flax_llama.py @@ -149,8 +149,8 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input ) past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length) - position_ids = jnp.broadcast_to( - jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1) + position_ids = model.get_position_ids_from_attention_mask( + attention_mask_cache, input_ids.shape[0], input_ids.shape[-1] - 1 ) outputs_cache = model( @@ -159,7 +159,6 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input past_key_values=past_key_values, position_ids=position_ids, ) - position_ids = jnp.array(input_ids.shape[0] * [[input_ids.shape[-1] - 1]], dtype="i4") outputs_cache_next = model( input_ids[:, -1:], past_key_values=outputs_cache.past_key_values, diff --git a/tests/models/mistral/test_modeling_flax_mistral.py b/tests/models/mistral/test_modeling_flax_mistral.py index 047bf4c6d433d0..6bb69eadf40e1b 100644 --- a/tests/models/mistral/test_modeling_flax_mistral.py +++ b/tests/models/mistral/test_modeling_flax_mistral.py @@ -154,6 +154,9 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input max_decoder_length = 20 model = model_class_name(config) + # make full attn mask since below we are preparing position ids assuming it's all ones + attention_mask = jnp.ones_like(attention_mask) + attention_mask_cache = jnp.concatenate( [attention_mask, jnp.zeros((attention_mask.shape[0], max_decoder_length - attention_mask.shape[1]))], axis=-1, diff --git a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py index 62ce0d660a0abc..5e66beb215e85f 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_flax_speech_encoder_decoder.py @@ -414,6 +414,14 @@ def check_pt_flax_equivalence(self, pt_model, fx_model, inputs_dict): flax_inputs = inputs_dict pt_inputs = {k: torch.tensor(v.tolist()) for k, v in flax_inputs.items()} + # when model weights are randomly initialized, masking with attn_mask still leads to a logits + mismatch, which does not happen with pre-trained models. That causes errors in encoder-decoder models + when a decoder-only model (GPT2) is used as the backbone, because GPT prepares position ids from the attn mask (in torch) + but as a plain arange in flax. That's why we init the attn mask with all `1`s + if "decoder_attention_mask" in pt_inputs: + pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"]) + inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"]) + with torch.no_grad(): pt_outputs = pt_model(**pt_inputs).to_tuple()
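For reviewers who want to sanity-check the position-id inference this patch consolidates behind `get_position_ids_from_attention_mask`, here is a minimal standalone sketch of the cumsum/masked_fill pattern that the per-model copies above used; the tensor values are illustrative only and not taken from any test.

```python
import torch

# Left-padded batch: row 0 has two padding tokens, row 1 has none (illustrative values).
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
seq_length = 3  # pretend only the last 3 positions are fed to the model this step

# Cumulative count of real tokens gives 0-based positions for non-padding tokens...
position_ids = attention_mask.long().cumsum(-1) - 1
# ...padding positions get a dummy value (they are masked out anyway)...
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
# ...and only the ids for the slice currently being processed are kept.
position_ids = position_ids[..., -seq_length:]

print(position_ids)
# tensor([[0, 1, 2],
#         [2, 3, 4]])
```

With the shared helper, models reuse this computation instead of duplicating it in every `prepare_inputs_for_generation`.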
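And a hedged end-to-end sketch of what the new `test_generate_with_and_without_position_ids` exercises: computing `position_ids` explicitly from a left-padded attention mask and passing them to `generate`. The checkpoint name and prompts are placeholders, and the pad-token assignment is only there because Mistral's tokenizer ships without one; any decoder-only model updated in this patch is expected to behave the same way.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any decoder-only model touched by this patch should work the same way.
checkpoint = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # Mistral's tokenizer has no pad token by default
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer(["Hello", "A much longer prompt goes here"], return_tensors="pt", padding=True)

# Same call the new tests use: infer position ids from the (left-padded) attention mask.
position_ids = model.get_position_ids_from_attention_mask(
    inputs["attention_mask"],
    past_length=0,
    seq_length=inputs["input_ids"].shape[-1],
    device=inputs["input_ids"].device,
)

# With this patch, generate() keeps extending user-provided position ids at each step,
# so the result should match a run where position_ids is omitted entirely.
out = model.generate(**inputs, position_ids=position_ids, max_new_tokens=5)
print(tokenizer.batch_decode(out, skip_special_tokens=True))
```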