Calculate position ids in modeling utils for all generative models #30053
Changes from all commits: d7fcb07, eaecb4f, d24796c, cf66b56, ecba33f, b769074, b96725f, 49f3495, 8e7a2bd, ff4e424, 5531118, 6341955, ce37742, e28e551, 7e5e3bf, 15ad877, 0d9ee9f, cbe4394, 0f20d92, e749080, ef2494e, d5f5989, cd10c73, 0f1997c, 87befb7
File 1 (PyTorch modeling code):

```diff
@@ -455,7 +455,8 @@ def forward(
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         device = input_ids.device if input_ids is not None else inputs_embeds.device
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)

         if token_type_ids is not None:
             token_type_ids = token_type_ids.view(-1, input_shape[-1])
```
```diff
@@ -467,8 +468,9 @@ def forward(
             past_length = past_key_values[0][0].size(-2)

         if position_ids is None:
-            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
-            position_ids = position_ids.unsqueeze(0)
+            position_ids = self.get_position_ids_from_attention_mask(
+                attention_mask, past_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device
+            )

         # Attention mask.
         if attention_mask is not None:
```
```diff
@@ -496,9 +498,6 @@ def forward(
         # head_mask has shape n_layer x batch x num_attention_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-
         hidden_states = inputs_embeds

         if token_type_ids is not None:
```
```diff
@@ -597,6 +596,7 @@ def set_output_embeddings(self, new_embeddings):
     def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # Omit tokens covered by past_key_values
+        past_length = 0
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
```
```diff
@@ -614,12 +614,16 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
         attention_mask = kwargs.get("attention_mask", None)
         position_ids = kwargs.get("position_ids", None)

-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+        seq_length = (
+            inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
+        )
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = self.get_position_ids_from_attention_mask(
+                attention_mask, past_length, seq_length=seq_length, device=device
+            )
+        else:
+            position_ids = position_ids[:, -seq_length:]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
```

Comment on lines +617 to +626:

I think we can remove all this code, actually 👀 I see the following cases:

WDYT? (this logic would apply to all models, and would make maintenance easier for us 👼 )
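The shared helper's body is not part of the hunks shown here. As a rough sketch only, assuming it simply centralizes the mask-based logic the diff removes from the per-model files (the name and signature are taken from the call sites above; everything else below is an assumption, not the PR's actual implementation):

```python
import torch


def get_position_ids_from_attention_mask(attention_mask, past_length, seq_length, device):
    """Sketch of a mask-aware position-id helper (illustrative, not the PR's code)."""
    if attention_mask is None:
        # No mask available: fall back to the classic contiguous range.
        position_ids = torch.arange(past_length, past_length + seq_length, dtype=torch.long, device=device)
        return position_ids.unsqueeze(0)

    # Each non-padding token gets its 0-based position among the real tokens.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Padding slots are never attended to, so give them a harmless dummy index.
    position_ids.masked_fill_(attention_mask == 0, 1)
    # The mask covers past + current tokens, so slicing the last `seq_length`
    # entries keeps only the positions for the tokens fed in this forward pass
    # (which is why `past_length` is only needed by the fallback branch above).
    return position_ids[:, -seq_length:].to(device)
```

For a left-padded batch with mask `[[0, 0, 1, 1], [1, 1, 1, 1]]` this returns `[[1, 1, 0, 1], [0, 1, 2, 3]]`, while the old `torch.arange` path assigns `[0, 1, 2, 3]` to both rows; that mismatch under left padding is what moving the computation into a shared, mask-aware helper addresses.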
File 2 (TensorFlow modeling code):

```diff
@@ -344,8 +344,15 @@ def call(
         else:
             past_length = shape_list(past_key_values[0][0])[-2]
         if position_ids is None:
-            position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0)
-            position_ids = tf.tile(position_ids, [input_shape[0], 1])
+            if attention_mask is not None:
+                position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
+                # create ones tensor to match dtypes, otherwise we get errors
+                ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
+                position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
+                position_ids = position_ids[..., -input_shape[-1] :]
+                position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
+            else:
+                position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

         # Attention mask.
         if attention_mask is not None:
```

Comment on lines +348 to +355: Suggested change (see comment below).

Comment: the same logic applies to other TF models
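To see concretely what the new branch in `call` computes, here is a small self-contained TF snippet that runs the same steps as lines +348 to +355; the toy mask and shapes are made up for illustration and are not from the PR:

```python
import tensorflow as tf

# Toy left-padded batch: row 0 starts with two padding tokens, row 1 has none.
attention_mask = tf.constant([[0, 0, 1, 1],
                              [1, 1, 1, 1]])
input_shape = [2, 4]  # [batch_size, seq_length], as shape_list(input_ids) would give

# Same steps as the new `if attention_mask is not None:` branch above.
position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)  # match dtypes for tf.where
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
position_ids = position_ids[..., -input_shape[-1]:]
position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))

print(position_ids.numpy())
# [[1 1 0 1]
#  [0 1 2 3]]
```

The padded slots get the dummy value 1, but the attention mask excludes them from attention anyway; the real tokens in each row get positions starting from 0 regardless of how much left padding precedes them.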
```diff
@@ -702,7 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=
         attention_mask = kwargs.get("attention_mask", None)

         if attention_mask is not None and position_ids is None:
-            position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
+            position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
+            ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
+            position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
             if past_key_values:
                 position_ids = tf.expand_dims(position_ids[:, -1], -1)
```

Comment (on the removed `tf.math.cumsum(..., exclusive=True)` line): this one should be correct, no? 🤔 (the same comment applies to other TF models)
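On the question above: the two formulations agree on every real token and differ only on padding slots, which the attention mask excludes from attention anyway. A quick standalone check with a toy left-padded mask (illustrative only, not from the PR):

```python
import tensorflow as tf

attention_mask = tf.constant([[0, 0, 1, 1, 1]], dtype=tf.int64)  # left-padded row

# Old line: exclusive cumsum counts the non-padding tokens *before* each slot.
old_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
# -> [[0 0 0 1 2]]

# New lines: inclusive cumsum minus one, padded slots overwritten with 1.
new_ids = tf.cumsum(attention_mask, axis=-1) - 1
new_ids = tf.where(attention_mask == 0, tf.ones_like(new_ids), new_ids)
# -> [[1 1 0 1 2]]

print(old_ids.numpy())
print(new_ids.numpy())
# Positions of the real tokens (mask == 1) are 0, 1, 2 in both versions.
```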
File 3 (encoder-decoder modeling code):

```diff
@@ -593,6 +593,7 @@ def forward(
             argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
         }

+        print(self.encoder)

         if encoder_outputs is None:
             encoder_outputs = self.encoder(
                 input_ids=input_ids,
```

Comment on line +596: Suggested change (presumably dropping the stray `print(self.encoder)` debug statement).
Comment: Can you add a comment briefly explaining when each situation can be triggered, and why we want that operation? Our future selves will probably be happy with that comment. E.g. I'm assuming `length_diff > 0` is used when candidates are proposed, and thus we want the corresponding position ids. But I'm not immediately seeing when `length_diff < 0` can be triggered :)

Comment: ^ this function still needs better variable names and/or a docstring