Idefics: fix position ids (huggingface#33907)
* fix position ids

* fix labels also

* fix copies

* oops, not that one

* dont deprecate
zucchini-nlp authored Oct 11, 2024
1 parent 7d97cca commit be9aeba
Showing 11 changed files with 79 additions and 69 deletions.
111 changes: 52 additions & 59 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -183,51 +183,6 @@ def expand_inputs_for_generation(
return input_ids, model_kwargs


def prepare_inputs_for_generation(input_ids, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
cache_position = kwargs.get("cache_position", None)
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

pixel_values = kwargs.get("pixel_values", None)
image_encoder_embeddings = kwargs.get("image_encoder_embeddings", None)
perceiver_embeddings = kwargs.get("perceiver_embeddings", None)
image_attention_mask = kwargs.get("image_attention_mask", None)
interpolate_pos_encoding = kwargs.get("interpolate_pos_encoding", False)

return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"pixel_values": pixel_values,
"image_encoder_embeddings": image_encoder_embeddings,
"perceiver_embeddings": perceiver_embeddings,
"image_attention_mask": image_attention_mask,
"interpolate_pos_encoding": interpolate_pos_encoding,
}


def freeze_model(model, module_exceptions=[]):
mapping = {
"LayerNorm": nn.LayerNorm,
@@ -1210,11 +1165,9 @@ def forward(
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids[:, -seq_length:]
elif position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0)
position_ids = cache_position.unsqueeze(0)

if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
raise ValueError(
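
Note on the two position-id branches above: with left padding, position ids are rebuilt from the attention mask by a cumulative sum, padding slots are filled with a dummy value, and only the positions for the current window are kept; when no mask is given they simply mirror `cache_position`. A minimal sketch of that arithmetic in plain PyTorch (tensors are illustrative, not taken from the model):

```python
import torch

# Two left-padded sequences: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative sum gives each real token its 0-based position;
# padding positions are filled with a dummy value of 1.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])

# During cached decoding only the last `seq_length` positions are needed,
# e.g. seq_length == 1 when one new token is processed per step.
seq_length = 1
print(position_ids[:, -seq_length:])   # tensor([[2], [4]])

# Without an attention mask, positions fall back to the cache positions
# of the tokens being processed (values here are illustrative).
cache_position = torch.arange(5, 6)    # the sixth token overall
print(cache_position.unsqueeze(0))     # tensor([[5]])
```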
@@ -1684,7 +1637,9 @@ def forward(
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:].to(logits.device)
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
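
The crop above matters when the attention mask is longer than the logits, for example when PEFT prefix tuning prepends virtual tokens that appear in the mask but not in the returned logits. A small sketch of why slicing the last `logits.shape[1] - 1` columns keeps logits, labels, and mask aligned (shapes are illustrative; when mask and logits have equal length this reduces to the old `[..., 1:]` slice):

```python
import torch

batch, seq_len, vocab = 1, 4, 7
num_virtual_tokens = 3          # e.g. a prefix-tuning prefix (illustrative)

logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# The attention mask also covers the virtual prefix, so it is longer than
# the logits along the sequence dimension.
attention_mask = torch.ones(batch, num_virtual_tokens + seq_len, dtype=torch.long)

# The old `[..., 1:]` slice would keep num_virtual_tokens + seq_len - 1 = 6
# columns and no longer line up with the shifted logits; cropping to the last
# logits.shape[1] - 1 columns keeps everything aligned.
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1):]

shift_logits = logits[..., :-1, :][shift_attention_mask != 0]
shift_labels = labels[..., 1:][shift_attention_mask != 0]
print(shift_logits.shape, shift_labels.shape)  # torch.Size([3, 7]) torch.Size([3])
```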
@@ -1707,19 +1662,57 @@ def forward(
image_hidden_states=outputs.image_hidden_states,
)

def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
position_ids=None,
pixel_values=None,
image_hidden_states=None,
use_cache=None,
cache_position=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]

# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

model_inputs = {}
image_hidden_states = kwargs.pop("image_hidden_states", None)
if image_hidden_states is not None:
if self.config.use_resampler:
kwargs["perceiver_embeddings"] = image_hidden_states
model_inputs["perceiver_embeddings"] = image_hidden_states
else:
kwargs["image_encoder_embeddings"] = image_hidden_states
kwargs["pixel_values"] = None
inputs = prepare_inputs_for_generation(input_ids, past=past, **kwargs)
unwanted_kwargs = ["token_type_ids"]
for kwarg in unwanted_kwargs:
inputs.pop(kwarg, None)
return inputs
model_inputs["image_encoder_embeddings"] = image_hidden_states
pixel_values = None

model_inputs.update(
{
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"cache_position": cache_position,
"position_ids": position_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_attention_mask": kwargs.get("image_attention_mask", None),
"interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
}
)

return model_inputs

@staticmethod
def _expand_inputs_for_generation(
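
The method above also covers two generation-time details: slicing `input_ids` down to the tokens the cache has not processed yet, and feeding cached image features back in so the vision tower runs only once. A condensed, self-contained sketch of both (the helper function and tensor shapes are illustrative, not the actual method):

```python
import torch

# --- cache_position slicing: keep only tokens the cache has not seen yet ---
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16]])   # full sequence so far
cache_position = torch.tensor([5])                      # only the last token is new
if input_ids.shape[1] != cache_position.shape[0]:
    input_ids = input_ids[:, cache_position]
print(input_ids)                                        # tensor([[16]])

# --- image feature routing: reuse cached features instead of re-encoding ---
def route_image_inputs(image_hidden_states, pixel_values, use_resampler):
    """Sketch of the routing idea: once image features are cached, feed them
    back as embeddings and skip the pixel inputs on later decoding steps."""
    model_inputs = {}
    if image_hidden_states is not None:
        if use_resampler:
            model_inputs["perceiver_embeddings"] = image_hidden_states
        else:
            model_inputs["image_encoder_embeddings"] = image_hidden_states
        pixel_values = None        # the vision tower is skipped on later steps
    model_inputs["pixel_values"] = pixel_values
    return model_inputs

cached = torch.zeros(1, 64, 1024)                       # illustrative feature shape
print(route_image_inputs(cached, None, use_resampler=True).keys())
# dict_keys(['perceiver_embeddings', 'pixel_values'])
```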
4 changes: 3 additions & 1 deletion src/transformers/models/idefics2/modeling_idefics2.py
@@ -1626,7 +1626,9 @@ def forward(
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:].to(logits.device)
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
4 changes: 3 additions & 1 deletion src/transformers/models/idefics3/modeling_idefics3.py
@@ -1213,7 +1213,9 @@ def forward(
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:].to(logits.device)
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
else:
4 changes: 3 additions & 1 deletion src/transformers/models/llava/modeling_llava.py
@@ -546,7 +546,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
4 changes: 3 additions & 1 deletion src/transformers/models/llava_next/modeling_llava_next.py
@@ -923,7 +923,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
(another changed file; path not shown)
@@ -1004,7 +1004,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
(another changed file; path not shown)
@@ -519,7 +519,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
(another changed file; path not shown)
@@ -676,7 +676,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
3 changes: 2 additions & 1 deletion src/transformers/models/paligemma/modeling_paligemma.py
@@ -532,7 +532,8 @@ def forward(
shift_labels = labels[..., 1:]
if attention_mask is not None:
# we use the input attention mask to shift the logits and labels, because it is 2D.
shift_attention_mask = attention_mask[..., 1:]
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
else:
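
PaliGemma shifts the labels first, so its crop uses the already-shifted length `shift_logits.shape[1]` instead of `logits.shape[1] - 1`; the two slices pick the same columns. A quick equivalence check (shapes are illustrative):

```python
import torch

logits = torch.randn(2, 6, 11)
attention_mask = torch.ones(2, 9, dtype=torch.long)    # 3 extra prefix positions

shift_logits = logits[..., :-1, :]
a = attention_mask[:, -(logits.shape[1] - 1):]         # slice used in the other models
b = attention_mask[:, -shift_logits.shape[1]:]         # slice used in PaliGemma
assert torch.equal(a, b) and a.shape[1] == logits.shape[1] - 1
```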
4 changes: 3 additions & 1 deletion src/transformers/models/video_llava/modeling_video_llava.py
@@ -656,7 +656,9 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
# we use the input attention mask to shift the logits and labels, because it is 2D.
# we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
2 changes: 1 addition & 1 deletion src/transformers/models/vipllava/modeling_vipllava.py
@@ -539,7 +539,7 @@ def forward(
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
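
All of the code paths touched here are exercised by an ordinary `generate()` call. A minimal usage sketch for reference (the checkpoint, prompt, and exact processor call signature are assumptions and may differ between transformers versions):

```python
import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b"  # assumed checkpoint; any Idefics weights work
processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
)

# Interleaved text/image prompt; the image is passed as a URL for the processor to fetch.
prompts = [
    [
        "User: What is in this image?",
        "https://upload.wikimedia.org/wikipedia/commons/8/86/Id%C3%A9fix.JPG",
        "<end_of_utterance>",
        "\nAssistant:",
    ]
]
inputs = processor(text=prompts, return_tensors="pt").to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```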
