From be9aeba5812cdcf9a47248b4ef05184cab6db200 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 11 Oct 2024 10:28:34 +0200 Subject: [PATCH] Idefics: fix position ids (#33907) * fix position ids * fix labels also * fix copies * oops, not that one * dont deprecate --- .../models/idefics/modeling_idefics.py | 111 ++++++++---------- .../models/idefics2/modeling_idefics2.py | 4 +- .../models/idefics3/modeling_idefics3.py | 4 +- .../models/llava/modeling_llava.py | 4 +- .../models/llava_next/modeling_llava_next.py | 4 +- .../modeling_llava_next_video.py | 4 +- .../modular_llava_next_video.py | 4 +- .../modeling_llava_onevision.py | 4 +- .../models/paligemma/modeling_paligemma.py | 3 +- .../video_llava/modeling_video_llava.py | 4 +- .../models/vipllava/modeling_vipllava.py | 2 +- 11 files changed, 79 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index f9a3538f488a54..02de8d61ae204c 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -183,51 +183,6 @@ def expand_inputs_for_generation( return input_ids, model_kwargs -def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs): - token_type_ids = kwargs.get("token_type_ids", None) - cache_position = kwargs.get("cache_position", None) - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - if past_key_values is not None: - if input_ids.shape[1] != cache_position.shape[0]: - input_ids = input_ids[:, cache_position] - if token_type_ids is not None: - token_type_ids = token_type_ids[:, -input_ids.shape[1] :] - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - pixel_values = kwargs.get("pixel_values", None) - image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None) - perceiver_embeddings = kwargs.get("perceiver_embeddings", None) - image_attention_mask = kwargs.get("image_attention_mask", None) - interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False) - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "cache_position": cache_position, - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": token_type_ids, - "pixel_values": pixel_values, - "image_encoder_embeddings": image_encoder_embeddings, - "perceiver_embeddings": perceiver_embeddings, - "image_attention_mask": image_attention_mask, - "interpolate_pos_encoding": interpolate_pos_encoding, - } - - def freeze_model(model, module_exceptions=[]): mapping = { "LayerNorm": nn.LayerNorm, @@ -1210,11 +1165,9 @@ def forward( # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids[:, -seq_length:] elif position_ids is None: - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2: raise ValueError( @@ -1684,7 +1637,9 @@ def forward( labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:].to(logits.device) + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() else: @@ -1707,19 +1662,57 @@ def forward( image_hidden_states=outputs.image_hidden_states, ) - def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + image_hidden_states=None, + use_cache=None, + cache_position=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + if past_key_values is not None: + if input_ids.shape[1] != cache_position.shape[0]: + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + model_inputs = {} image_hidden_states = kwargs.pop("image_hidden_states", None) if image_hidden_states is not None: if self.config.use_resampler: - kwargs["perceiver_embeddings"] = image_hidden_states + model_inputs["perceiver_embeddings"] = image_hidden_states else: - kwargs["image_encoder_embeddings"] = image_hidden_states - kwargs["pixel_values"] = None - inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs) - unwanted_kwargs = ["token_type_ids"] - for kwarg in unwanted_kwargs: - inputs.pop(kwarg, None) - return inputs + model_inputs["image_encoder_embeddings"] = image_hidden_states + pixel_values = None + + model_inputs.update( + { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "cache_position": cache_position, + "position_ids": position_ids, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "image_attention_mask": kwargs.get("image_attention_mask", None), + "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False), + } + ) + + return model_inputs @staticmethod def _expand_inputs_for_generation( diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 635eb9b9e15721..b53d0722587d5a 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1626,7 +1626,9 @@ def forward( labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:].to(logits.device) + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() else: diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index fe3067a145314a..757391175ea671 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -1213,7 +1213,9 @@ def forward( labels = labels.to(logits.device) # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:].to(logits.device) + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous() else: diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 0bc08f9f86864f..e793ca61c750d7 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -546,7 +546,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index b9d20a47e61ec2..705821c2b713e8 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -923,7 +923,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 58fed1832670f5..7df4cf20372bb7 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -1004,7 +1004,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index 39c55930a8d574..4b6be407dcab81 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -519,7 +519,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 4443ae68fa64ed..f65c0fe7cfb3e5 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -676,7 +676,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 5e695f3387d768..d75a05bda0e1ec 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -532,7 +532,8 @@ def forward( shift_labels = labels[..., 1:] if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. - shift_attention_mask = attention_mask[..., 1:] + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device) shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 008240d0d929e6..5711433c368d5e 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -656,7 +656,9 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + # we use the input attention mask to shift the logits and labels, because it is 2D. + # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index cacaaf6ac35ab1..26d92b9ac3dca4 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -539,7 +539,7 @@ def forward( if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: - shift_attention_mask = attention_mask[..., 1:] + shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: