Calculate position ids in modeling utils for all generative models #30053

Closed · wants to merge 25 commits
Commits
d7fcb07
prepare position ids in modeling utils
zucchini-nlp Apr 4, 2024
eaecb4f
fix seq length when inputs embeds
zucchini-nlp Apr 4, 2024
d24796c
Merge remote-tracking branch 'upstream/main' into position_ids
zucchini-nlp Apr 4, 2024
cf66b56
forgot to fix starcoder2
zucchini-nlp Apr 5, 2024
ecba33f
fix copies
zucchini-nlp Apr 5, 2024
b769074
remove that print :)
zucchini-nlp Apr 9, 2024
b96725f
lets add same for assisted decoding
zucchini-nlp Apr 9, 2024
49f3495
Merge remote-tracking branch 'upstream/main' into position_ids
zucchini-nlp Apr 15, 2024
8e7a2bd
Merge remote-tracking branch 'upstream/main' into position_ids
zucchini-nlp Apr 17, 2024
ff4e424
framework equivalence?
zucchini-nlp Apr 17, 2024
5531118
final solution, lets make all frameworks same
zucchini-nlp Apr 18, 2024
6341955
Merge branch 'main' into position_ids
zucchini-nlp Apr 18, 2024
ce37742
new models
zucchini-nlp Apr 22, 2024
e28e551
tf fix cast
zucchini-nlp Apr 22, 2024
7e5e3bf
tf equivalence
zucchini-nlp Apr 22, 2024
15ad877
remove extra if conditions
zucchini-nlp Apr 23, 2024
0d9ee9f
make test parameterized
zucchini-nlp Apr 23, 2024
cbe4394
Merge remote-tracking branch 'upstream/main' into position_ids
zucchini-nlp Apr 25, 2024
0f20d92
fix failing flax cases
zucchini-nlp Apr 25, 2024
e749080
torch tests fail due to merge conflicts?
zucchini-nlp Apr 25, 2024
ef2494e
let the tests pass
zucchini-nlp Apr 29, 2024
d5f5989
import if available
zucchini-nlp Apr 29, 2024
cd10c73
fixes
zucchini-nlp Apr 29, 2024
0f1997c
encoder-decoder models
zucchini-nlp Apr 29, 2024
87befb7
fix llama flax
zucchini-nlp Apr 29, 2024
16 changes: 16 additions & 0 deletions src/transformers/generation/candidate_generator.py
@@ -190,6 +190,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,
self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
)
self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
self.assistant_kwargs = _prepare_position_ids(self.assistant_kwargs, new_cur_len)

# 2. Forecast next N tokens using the assistant model.
assistant_generation_kwargs = {
@@ -423,3 +424,18 @@ def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Di
token_type_copies = final_token_type.repeat(1, type_length_diff)
model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
return model_kwargs


def _prepare_position_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
position_ids = model_kwargs.get("position_ids")
if position_ids is None:
return model_kwargs

# we assume batch_size=1 for assisted decoding (needs rework if bs > 1)
length_diff = new_length - position_ids[0, -1]
if length_diff < 0:
position_ids = position_ids[:, :length_diff]
elif length_diff > 0:
new_position_ids = torch.arange(position_ids[0, -1], new_length, device=position_ids.device).unsqueeze(0)
Review comment (Member) on lines +436 to +439:

Can you add a comment briefly explaining when each situation can be triggered, and why we want that operation? Our future selves will probably be happy with that comment.

e.g. I'm assuming length_diff > 0 is used when candidates are proposed, and thus we want the corresponding position ids. But I'm not immediately seeing when length_diff < 0 can be triggered :)

Review comment (Member):

^ this function still needs better variable names and/or a docstring

model_kwargs["position_ids"] = torch.cat([model_kwargs["position_ids"], new_position_ids], dim=-1)
return model_kwargs
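For illustration, a simplified batch-size-1 sketch of the idea behind `_prepare_position_ids`: cached `position_ids` are trimmed when candidate tokens are rejected and extended when candidates are appended. The sketch assumes no left padding (so positions equal token indices) and works from sequence lengths, whereas the helper above reads the last position value; `resize_position_ids` is a hypothetical name, not part of the diff.

```python
import torch

# Simplified sketch (assumes batch_size=1 and no left padding): resize cached
# position_ids so they match a new target sequence length.
def resize_position_ids(position_ids: torch.LongTensor, new_length: int) -> torch.LongTensor:
    cur_length = position_ids.shape[-1]
    if new_length < cur_length:
        # candidate tokens were rejected -> drop the surplus positions
        return position_ids[:, :new_length]
    if new_length > cur_length:
        # candidate tokens were appended -> add the missing positions
        extra = torch.arange(cur_length, new_length, device=position_ids.device).unsqueeze(0)
        return torch.cat([position_ids, extra], dim=-1)
    return position_ids

ids = torch.arange(5).unsqueeze(0)     # tensor([[0, 1, 2, 3, 4]])
print(resize_position_ids(ids, 8))     # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
print(resize_position_ids(ids, 3))     # tensor([[0, 1, 2]])
```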
6 changes: 6 additions & 0 deletions src/transformers/generation/utils.py
@@ -43,6 +43,7 @@
PromptLookupCandidateGenerator,
_crop_past_key_values,
_prepare_attention_mask,
_prepare_position_ids,
_prepare_token_type_ids,
)
from .configuration_utils import GenerationConfig, GenerationMode
@@ -674,6 +675,10 @@ def _update_model_kwargs_for_generation(
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None:
position_ids = model_kwargs["position_ids"]
model_kwargs["position_ids"] = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=-1)

return model_kwargs

def _reorder_cache(self, past_key_values, beam_idx):
@@ -4685,6 +4690,7 @@ def _assisted_decoding(
candidate_kwargs = _prepare_attention_mask(
candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder
)
candidate_kwargs = _prepare_position_ids(candidate_kwargs, candidate_input_ids.shape[1])
candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
if "cache_position" in candidate_kwargs:
candidate_kwargs["cache_position"] = torch.cat(
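A standalone illustration of the `_update_model_kwargs_for_generation` change above: once `position_ids` are present in `model_kwargs`, each decoding step only appends the previous last position plus one. This also behaves sensibly for a left-padded row whose pad slots were filled with 1 by the helper, since only the running maximum matters for the new token. A minimal sketch, not taken from the diff:

```python
import torch

# Per-step update: append (last position + 1) for every row.
position_ids = torch.tensor([[1, 1, 0, 1, 2],    # left-padded row (pad slots filled with 1)
                             [0, 1, 2, 3, 4]])   # unpadded row
for _ in range(3):  # pretend three tokens are generated
    position_ids = torch.cat([position_ids, position_ids[:, -1:] + 1], dim=-1)

print(position_ids)
# tensor([[1, 1, 0, 1, 2, 3, 4, 5],
#         [0, 1, 2, 3, 4, 5, 6, 7]])
```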
13 changes: 13 additions & 0 deletions src/transformers/modeling_flax_utils.py
@@ -1248,6 +1248,19 @@ def register_for_auto_class(cls, auto_class="FlaxAutoModel"):

cls._auto_class = auto_class

def get_position_ids_from_attention_mask(self, attention_mask, batch_size, seq_length):
"""
Infers position ids from the attention mask, falling back to a simple range over `seq_length`
when no mask is given. All call sites where `position_ids=None` should use this method.
"""
if attention_mask is not None:
position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
position_ids = jnp.where(attention_mask == 0, 1, position_ids)
position_ids = position_ids[..., -seq_length:]
else:
position_ids = jnp.broadcast_to(jnp.arange(seq_length)[None, :], (batch_size, seq_length))
return position_ids


# To update the docstring, we need to copy the method, otherwise we change the original docstring.
FlaxPreTrainedModel.push_to_hub = copy_func(FlaxPreTrainedModel.push_to_hub)
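For reference, what the Flax helper above produces for a left-padded batch, reproduced outside the class (the `jnp` calls mirror the method body; the example data is made up):

```python
import jax.numpy as jnp

# Left-padded row (pad, pad, tok, tok, tok) next to an unpadded row.
attention_mask = jnp.array([[0, 0, 1, 1, 1],
                            [1, 1, 1, 1, 1]])
seq_length = attention_mask.shape[1]

position_ids = jnp.cumsum(attention_mask, axis=-1) - 1
position_ids = jnp.where(attention_mask == 0, 1, position_ids)   # pad slots become 1
position_ids = position_ids[..., -seq_length:]

print(position_ids)
# [[1 1 0 1 2]
#  [0 1 2 3 4]]
```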
14 changes: 14 additions & 0 deletions src/transformers/modeling_utils.py
@@ -4368,6 +4368,20 @@ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):

logger.warning_once(warn_string)

def get_position_ids_from_attention_mask(self, attention_mask, past_length, seq_length, device):
"""
Infers position ids from the attention mask and the past key-value cache length. All call sites
where `position_ids=None` should use this method.
"""
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
position_ids = position_ids[..., -seq_length:].view(-1, seq_length)
else:
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
return position_ids

@property
def _is_quantized_training_enabled(self):
warnings.warn(
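And the same computation on the PyTorch side during cached generation, sketched standalone: the full attention mask (past plus current tokens) is passed, and only the trailing `seq_length` positions are kept for the tokens actually fed to the model. The lines mirror the helper body; the mask values are made up.

```python
import torch

# Left-padded sequence: 2 pads, 3 prompt tokens, then 1 freshly generated token.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])
seq_length = 1   # only the newest token is fed; the rest sits in the kv cache

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids = position_ids.masked_fill(attention_mask == 0, 1)
position_ids = position_ids[..., -seq_length:].view(-1, seq_length)

print(position_ids)  # tensor([[3]])
```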
28 changes: 16 additions & 12 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -455,7 +455,8 @@ def forward(
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
@@ -467,8 +468,9 @@ def forward(
past_length = past_key_values[0][0].size(-2)

if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device
)

# Attention mask.
if attention_mask is not None:
@@ -496,9 +498,6 @@ def forward(
# head_mask has shape n_layer x batch x num_attention_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)

if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

hidden_states = inputs_embeds

if token_type_ids is not None:
@@ -597,6 +596,7 @@ def set_output_embeddings(self, new_embeddings):
def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# Omit tokens covered by past_key_values
past_length = 0
if past_key_values:
past_length = past_key_values[0][0].shape[2]

@@ -614,12 +614,16 @@ def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
seq_length = (
inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=seq_length, device=device
)
else:
position_ids = position_ids[:, -seq_length:]
Review comment (Member) on lines +617 to +626:

I think we can remove all this code, actually 👀 I see the following cases:

  1. position_ids is None -> the forward pass correctly computes position_ids, due to the changes in this PR
  2. position_ids is not None -> the user has defined position_ids, and it is their responsibility to pass them correctly

WDYT? (this logic would apply to all models, and would make maintenance easier for us 👼 )


# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
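The `seq_length` bookkeeping this PR repeats in each `prepare_inputs_for_generation` can be sketched on its own: with a populated cache only the trailing tokens are fed, so user-supplied `position_ids` are sliced to that same window. `pick_seq_length` is a hypothetical stand-in name, not part of the diff.

```python
import torch

# Hypothetical stand-in for the repeated pattern above.
def pick_seq_length(input_ids, inputs_embeds, past_key_values):
    if inputs_embeds is not None and past_key_values is None:
        return inputs_embeds.shape[1]   # first step driven by inputs_embeds
    return input_ids.shape[1]           # subsequent steps (or plain input_ids)

input_ids = torch.tensor([[42]])                  # one new token after a 5-token prompt
user_position_ids = torch.arange(6).unsqueeze(0)  # user passed ids for the whole sequence
seq_length = pick_seq_length(input_ids, None, past_key_values=[("dummy",)])

print(user_position_ids[:, -seq_length:])         # tensor([[5]])
```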
20 changes: 13 additions & 7 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -909,7 +909,9 @@ def forward(
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device
)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

@@ -1227,12 +1229,16 @@ def prepare_inputs_for_generation(
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
seq_length = (
inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=seq_length, device=device
)
else:
position_ids = position_ids[:, -seq_length:]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
17 changes: 15 additions & 2 deletions src/transformers/models/ctrl/modeling_ctrl.py
@@ -412,9 +412,11 @@ def forward(
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)

if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=input_shape[1], device=device
)

# Attention mask.
if attention_mask is not None:
@@ -525,6 +527,7 @@ def set_output_embeddings(self, new_embeddings):

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs):
# only last tokens for inputs_ids if past is defined in kwargs
past_length = 0
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

@@ -537,6 +540,16 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cac

input_ids = input_ids[:, remove_prefix_length:]

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if position_ids is None:
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=input_ids.shape[1], device=input_ids.device
)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]

return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache}

@add_start_docstrings_to_model_forward(CTRL_INPUTS_DOCSTRING)
15 changes: 12 additions & 3 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -344,8 +344,15 @@ def call(
else:
past_length = shape_list(past_key_values[0][0])[-2]
if position_ids is None:
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length, dtype=tf.int32), axis=0)
position_ids = tf.tile(position_ids, [input_shape[0], 1])
if attention_mask is not None:
position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
# create ones tensor to match dtypes, otherwise we get errors
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
position_ids = position_ids[..., -input_shape[-1] :]
position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
else:
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)
Review comment (Member) on lines +348 to +355:

Suggested change: replace

if attention_mask is not None:
position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
# create ones tensor to match dtypes, otherwise we get errors
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
position_ids = position_ids[..., -input_shape[-1] :]
position_ids = tf.reshape(position_ids, (-1, input_shape[-1]))
else:
position_ids = tf.expand_dims(tf.range(past_length, input_shape[-1] + past_length), axis=0)

with:

position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)

(see comment below)

Review comment (Member):

the same logic applies to other TF models


# Attention mask.
if attention_mask is not None:
@@ -702,7 +709,9 @@ def prepare_inputs_for_generation(self, inputs, past_key_values=None, use_cache=
attention_mask = kwargs.get("attention_mask", None)

if attention_mask is not None and position_ids is None:
position_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)
Review comment (Member):

this one should be correct, no? 🤔

(the same comment applies to other TF models)

position_ids = tf.cumsum(tf.cast(attention_mask, tf.int64), axis=-1) - 1
ones_tensor = tf.ones_like(position_ids, dtype=tf.int64)
position_ids = tf.where(attention_mask == 0, ones_tensor, position_ids)
if past_key_values:
position_ids = tf.expand_dims(position_ids[:, -1], -1)

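To make the TF discussion above concrete, a small comparison of the two variants, sketched outside the model: the PR's cumsum-minus-one with pad slots forced to 1, versus the reviewer's exclusive cumsum. The real tokens get identical positions either way; the variants differ only in what padded slots receive, which the attention mask should make irrelevant. The mask values below are made up.

```python
import tensorflow as tf

# Left-padded row: pad, pad, tok, tok, tok
attention_mask = tf.constant([[0, 0, 1, 1, 1]], dtype=tf.int64)

# PR variant: cumsum - 1, with pad slots set to 1
pr_ids = tf.cumsum(attention_mask, axis=-1) - 1
pr_ids = tf.where(attention_mask == 0, tf.ones_like(pr_ids), pr_ids)

# Reviewer's suggestion: exclusive cumsum
suggested_ids = tf.math.cumsum(attention_mask, axis=-1, exclusive=True)

print(pr_ids.numpy())         # [[1 1 0 1 2]]
print(suggested_ids.numpy())  # [[0 0 0 1 2]]
```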
21 changes: 14 additions & 7 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1145,7 +1145,10 @@ def forward(
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_seen_tokens, seq_length=inputs_embeds.shape[1], device=inputs_embeds.device
)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

# embed positions
@@ -1470,12 +1473,16 @@ def prepare_inputs_for_generation(
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
seq_length = (
inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=seq_length, device=device
)
else:
position_ids = position_ids[:, -seq_length:]

if self.generation_config.cache_implementation == "static":
# generation with static cache
@@ -593,6 +593,7 @@ def forward(
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}

print(self.encoder)
Review comment (Member):

Suggested change: remove the print(self.encoder) line.

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
@@ -737,12 +737,10 @@ def prepare_inputs_for_generation(
# Thus we can create a single static attention_mask here, which is more efficient for compilation
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
if decoder_attention_mask is not None:
decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
else:
decoder_position_ids = jnp.broadcast_to(
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)
)
decoder_position_ids = self.get_position_ids_from_attention_mask(
decoder_attention_mask, batch_size, seq_length
)

return {
"past_key_values": past_key_values,
24 changes: 14 additions & 10 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -1077,11 +1077,9 @@ def forward(
else:
alibi = None
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_key_values_length, seq_length=seq_length, device=inputs_embeds.device
)
position_ids = position_ids.unsqueeze(0)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
@@ -1215,6 +1213,7 @@ def prepare_inputs_for_generation(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
past_length = 0
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]

@@ -1228,12 +1227,17 @@ def prepare_inputs_for_generation(
input_ids = input_ids[:, remove_prefix_length:]

# Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE.
if not self.transformer.use_alibi and attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if not self.transformer.use_alibi:
seq_length = (
inputs_embeds.shape[1] if inputs_embeds is not None and past_key_values is None else input_ids.shape[1]
)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = self.get_position_ids_from_attention_mask(
attention_mask, past_length, seq_length=seq_length, device=device
)
else:
position_ids = position_ids[:, -seq_length:]

if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}