Calculate position ids in modeling utils for all generative models #30053

Closed
wants to merge 25 commits into from

Conversation

zucchini-nlp
Member

@zucchini-nlp zucchini-nlp commented Apr 4, 2024

What does this PR do?

As discussed under this PR, position ids in some models are not calculated/inferred from the attention mask in forward, which gives incorrect positions when the inputs are left-padded.

For consistency and ease of maintenance, the logic for inferring position ids is moved to "modeling_utils.py", and all generative models call that method in their forward and prepare_inputs_for_generation. I added two tests to check whether model outputs are the same when position ids are passed by the user vs. inferred from input ids or embeds.
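
For reference, a minimal sketch of the cumsum-based inference (the helper name and exact slicing here are illustrative, not necessarily the code added in this PR):

```python
import torch

def infer_position_ids(attention_mask: torch.Tensor, past_length: int = 0) -> torch.Tensor:
    # attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    position_ids = attention_mask.long().cumsum(-1) - 1
    # padded positions are never attended to, so their value is arbitrary (filled with 1 here)
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids[:, past_length:]

mask = torch.tensor([[0, 0, 1, 1, 1]])   # left-padded input
print(infer_position_ids(mask))          # tensor([[1, 1, 0, 1, 2]])
```

This way left padding no longer shifts the positions of the real tokens.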

Also Fixes #29149.

The newly added tests are passing, plus the slow tests on vision models (they still do not have GenerationTesterMixin).

Btw, I see that non-generative models already use a create_position_ids_from_input_ids method, which is copied separately into each model's file. The logic is a bit different from generative models because they start counting from "padding_idx" and not "0". Anyway, I guess it is still possible to merge that method with the one proposed here, to have one "get_position_id" for all models in "modeling_utils".
@gante WDYT ?

@zucchini-nlp zucchini-nlp requested a review from gante April 4, 2024 17:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp
Member Author

About the framework changes: I found that TF/Flax get position_ids in a slightly different way from the torch models. Those frameworks generate position ids in forward without taking the attention mask into account, which is the same thing we had in torch before these fixes.

I made TF and Flax work the same way as torch does now, with a cumsum over the attention mask, so that the cross-framework equivalence tests pass. I am not sure if we need a test similar to "test_position_ids" for TF/Flax. Tests should pass now; at least locally everything seemed okay.
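
For illustration, a small cross-framework check of that idea (a sketch only, not the PR's test code):

```python
import numpy as np
import torch
import tensorflow as tf
import jax.numpy as jnp

mask = np.array([[0, 0, 1, 1, 1]])  # left-padded example

# torch convention in this PR: cumsum - 1, padded slots filled with a dummy value
pt_ids = torch.tensor(mask).cumsum(-1) - 1
pt_ids = pt_ids.masked_fill(torch.tensor(mask) == 0, 1)

# tf: exclusive cumsum yields the same 0-based positions for attended tokens
tf_ids = tf.math.cumsum(tf.constant(mask), axis=-1, exclusive=True)

# flax/jax: same idea with jnp
fx_ids = jnp.where(jnp.asarray(mask) == 0, 1, jnp.cumsum(jnp.asarray(mask), axis=-1) - 1)

# all three agree wherever the mask is 1 (padded slots differ, but they are masked out anyway)
print(pt_ids.numpy() * mask)      # [[0 0 0 1 2]]
print(tf_ids.numpy() * mask)      # [[0 0 0 1 2]]
print(np.asarray(fx_ids) * mask)  # [[0 0 0 1 2]]
```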

Member

@gante gante left a comment

In general looks good, thank you for tackling this refactor 💪

A few notes:

  1. No TF function to infer the position IDs? 😢 TF feels neglected 💔
  2. There are CI errors in the model equivalence tests. Model equivalence is flaky by nature, so make sure you run it for all models with the flake finder locally!
  3. After you're happy with the changes, commit with [test_all] and tag me again. I've glanced over the model-level changes after the first few models; I'll do a final, more careful check after the full CI is green 🤗

Comment on lines +436 to +439
if length_diff < 0:
    position_ids = position_ids[:, :length_diff]
elif length_diff > 0:
    new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0)
Member

Can you add a comment briefly explaining when each situation can be triggered, and why we want that operation? Our future selves will probably be happy with that comment

e.g. I'm assuming length_diff > 0 is used when candidates are proposed, and thus we want the corresponding position ids. But I'm not immediately seeing when length_diff < 0 can be triggered :)

Member

^ this function still needs better variable names and/or a docstring

src/transformers/models/codegen/modeling_codegen.py (outdated, resolved)
@@ -1189,6 +1189,66 @@ def test_assisted_decoding_matches_greedy_search(self):
for output in (output_greedy, output_assisted):
    self._check_outputs(output, input_ids, model.config, use_cache=True)

@is_flaky()
def test_assisted_decoding_position_ids(self):
Member

Like in the other PR where you added an assisted generation test: let's make this a parameterization of the original test, since it's a minor variation :)

tests/generation/test_utils.py (outdated, resolved)
tests/generation/test_utils.py (outdated, resolved)
@zucchini-nlp
Member Author

Okay, will work on it.

  1. TF has only 3 decoder-only models, so I thought we would not need it. Okay, I can add it in the same way.
  2. Those are the composite models for Flax; I think it needs a fix, but I could not find where yet.
  3. Okay :)

@zucchini-nlp
Member Author

@gante the comments are addressed now. TF cannot have the "get_position_ids" method in PretrainedModel because all input-related preparation in TF happens in a "keras.layers.Layer" class. I am not sure if we can or should move the position_id preparation into the "PretrainedModel", since there are only 3 TF models that needed the change.

Also, a note on Flax-based encoder-decoder models: the attention mask for the decoder part is overridden to be all ones, because when a decoder-only model is used as the decoder, the position ids are calculated differently (I mean only for the unattended part). In randomly initialized models this causes a logits mismatch, even though the attention masks out the unattended positions. In pre-trained models that does not happen 🤔


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@zucchini-nlp
Member Author

Hold it for a while, not stale

Member

@gante gante left a comment

A few more pattern fixes and this should be ready to go 🤞

Comment on lines +436 to +439
if length_diff < 0:
    position_ids = position_ids[:, :length_diff]
elif length_diff > 0:
    new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0)
Member

^ this function still needs better variable names and/or a docstring

Comment on lines +617 to +626
seq_length = (
    inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
)
if position_ids is None:
    device = input_ids.device if input_ids is not None else inputs_embeds.device
    position_ids = self.get_position_ids_from_attention_mask(
        attention_mask, past_length, seq_length=seq_length, device=device
    )
else:
    position_ids = position_ids[:, -seq_length:]
Member

I think we can remove all this code, actually 👀 I see the following cases:

  1. position_ids is None -> the forward pass correctly computes position_ids, due to the changes in this PR
  2. position_ids is not None -> the user has defined position_ids, and it's their own responsibility to pass them correctly

WDYT? (this logic would apply to all models, and would make maintenance easier for us 👼 )
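
A rough sketch of what the simplified call site could look like under that proposal (illustrative, not the PR's exact code):

```python
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, position_ids=None, **kwargs):
    # no slicing or recomputation here: forward infers position_ids when they are None (case 1),
    # and user-provided position_ids are passed through untouched (case 2)
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "use_cache": kwargs.get("use_cache"),
    }
```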

@@ -593,6 +593,7 @@ def forward(
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}

print(self.encoder)
Member

Suggested change
print(self.encoder)

@@ -702,7 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=
attention_mask = kwargs.get("attention_mask", None)

if attention_mask is not None and position_ids is None:
    position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
Member

this one should be correct, no? 🤔

(the same comment applies to other TF models)

Comment on lines +348 to +355
    position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
    # create ones tensor to match dtypes, otherwise we get errors
    ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
    position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
    position_ids = position_ids[..., -input_shape[-1] :]
    position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
else:
    position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
Member

Suggested change
    position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
    # create ones tensor to match dtypes, otherwise we get errors
    ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
    position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
    position_ids = position_ids[..., -input_shape[-1] :]
    position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
else:
    position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)

(see comment below)

Member

the same logic applies to other TF models

Comment on lines +304 to +310
# when model weights are random init masking with attn_mask still leads to logits
# mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models
# when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch)
# and as arange in flax. That's why we init attn mask with all `1`
if "decoder_attention_mask" in pt_inputs:
pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"])
inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"])
Member

This change should no longer be needed, correct?

(as a general rule, we shouldn't fudge these equivalence tests :) )

Comment on lines +146 to +147
# make full attn mask since below we are preparing position ids assuming it's all ones
attention_mask = jnp.ones_like(attention_mask)
Member

the other way around: we should update the creation of position_ids (below) to match the mask

The same comment applies to other FLAX test changes
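
For instance, something along these lines on the Flax side (a sketch only, mirroring the torch convention):

```python
import jax.numpy as jnp

attention_mask = jnp.array([[0, 0, 1, 1, 1]])  # left-padded example
position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
# padded slots are masked out anyway; fill them with a dummy value, as the torch helper does
position_ids = jnp.where(attention_mask == 0, 1, position_ids)
print(position_ids)  # [[1 1 0 1 2]]
```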

@@ -149,8 +149,8 @@ def check_use_cache_forward_with_attn_mask(self, model_class_name, config, input
)

past_key_values = model.init_cache(input_ids.shape[0], max_decoder_length)
-position_ids = jnp.broadcast_to(
-    jnp.arange(input_ids.shape[-1] - 1)[None, :], (input_ids.shape[0], input_ids.shape[-1] - 1)
+position_ids = model.get_position_ids_from_attention_mask(
Member

yes, like this!

Comment on lines +417 to +424
# when model weights are random init masking with attn_mask still leads to logits
# mismatch, which does not happen if pre-trained models are used. That causes error in encoder-decoder models
# when decoder_only is used in as backbone (GPT2), because GPT prepares positions depending on attn mask (for torch)
# and as arange in flax. That's why we init attn mask with all `1`
if "decoder_attention_mask" in pt_inputs:
pt_inputs["decoder_attention_mask"] = torch.ones_like(pt_inputs["decoder_attention_mask"])
inputs_dict["decoder_attention_mask"] = jnp.ones_like(inputs_dict["decoder_attention_mask"])

Member

Same comment as for this pattern above: we should remove this

@github-actions github-actions bot closed this Jun 27, 2024
@zucchini-nlp
Member Author

Will reopen this later as a new PR. It will need merge conflicts resolved, changes propagated to new models, and the PR review comments addressed.

@huggingface huggingface deleted a comment from github-actions bot Jun 28, 2024
Successfully merging this pull request may close these issues.

Generate: support passing position_ids