Add inputs embeds in generation #30269

Merged
merged 9 commits on Apr 23, 2024
Changes from 5 commits
24 changes: 16 additions & 8 deletions src/transformers/models/bart/modeling_bart.py
@@ -1369,7 +1369,8 @@ def forward(
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input)
+        inputs_embeds = inputs_embeds * self.embed_scale
Member:

I don't think we should change this -- it is a breaking change: the same inputs_embeds input will yield a different output after this PR on a well-established model.

Does it solve any problem from the VLM side?

Member Author:

No, afaik these models are not used as backbones in VLMs, so prob no real benefit for now.

Close the PR?

Member:

@zucchini-nlp We can keep the PR, it's useful! Not a priority, but since it's done let's add it :D

Let's just revert this line 🤗

Member Author:

Hmm, that will make the equivalence tests between input_ids and inputs_embeds fail. Then the better decision is to disable inputs_embeds entirely for these models. Okay, will do
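For reference, a minimal sketch of the equivalence those tests assert (checkpoint name and tolerance are illustrative assumptions; the check is only meaningful for configs with scale_embedding=True, since with a scale of 1.0 both paths trivially agree):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
# Unscaled token embeddings, built the way a user would:
inputs_embeds = model.get_input_embeddings()(input_ids)

with torch.no_grad():
    logits_ids = model(input_ids=input_ids, decoder_input_ids=input_ids).logits
    logits_emb = model(inputs_embeds=inputs_embeds, decoder_input_ids=input_ids).logits

# Holds only if embed_scale is also applied to user-provided embeddings,
# i.e. the behavior the change above introduces and a revert would undo.
print(torch.allclose(logits_ids, logits_emb, atol=1e-5))
```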


         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
@@ -2276,7 +2277,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -2294,12 +2295,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
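The same first-step-only gating recurs in every prepare_inputs_for_generation touched by this PR. Distilled into a standalone sketch (not the library code itself):

```python
def select_model_inputs(input_ids, inputs_embeds=None, past_key_values=None):
    # First generation step: no cache exists yet, so honor the
    # user-provided embeddings.
    if inputs_embeds is not None and past_key_values is None:
        return {"inputs_embeds": inputs_embeds}
    # Later steps: the freshly sampled token ids suffice, because all
    # earlier positions are already covered by past_key_values.
    return {"input_ids": input_ids}
```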
@@ -2168,8 +2168,9 @@ def forward(
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

+        inputs_embeds *= self.embed_scale
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
@@ -3043,7 +3044,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -3052,12 +3053,20 @@ def prepare_inputs_for_generation(
         if past_key_values:
             input_ids = input_ids[:, -1:]
-        # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
21 changes: 14 additions & 7 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1562,7 +1562,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1580,12 +1580,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
@@ -1532,7 +1532,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1550,12 +1550,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
26 changes: 17 additions & 9 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -594,7 +594,7 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # Omit tokens covered by past_key_values
         if past_key_values:
@@ -621,14 +621,22 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
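Not visible in this hunk: when position_ids are not supplied, these decoder-only models derive them from the attention mask. A sketch of that standard pattern (the helper name is mine):

```python
import torch

def positions_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # Running count of non-pad tokens, shifted so the first real token is 0.
    position_ids = attention_mask.long().cumsum(-1) - 1
    # Left padding would leave -1 at pad slots; fill with a harmless index.
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

mask = torch.tensor([[0, 0, 1, 1, 1]])  # two pad tokens, then three real ones
print(positions_from_mask(mask))        # tensor([[1, 1, 0, 1, 2]])
```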
22 changes: 15 additions & 7 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -1212,6 +1212,7 @@ def prepare_inputs_for_generation(
         past_key_values: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> dict:
         if past_key_values is not None:
@@ -1234,13 +1235,20 @@
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        return {
-            "input_ids": input_ids,
-            "position_ids": position_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
26 changes: 17 additions & 9 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -1430,7 +1430,7 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # Omit tokens covered by past_key_values
         if past_key_values:
@@ -1459,14 +1459,22 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         else:
             position_ids = None

-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-            "token_type_ids": token_type_ids,
-        }
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "position_ids": position_ids,
+                "attention_mask": attention_mask,
+                "token_type_ids": token_type_ids,
+            }
+        )
+        return model_inputs

     @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
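With this change, generating from embeddings works end to end on these models. A usage sketch (checkpoint and prompt are illustrative); when only inputs_embeds is given, generate() returns just the newly sampled tokens:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
inputs_embeds = model.get_input_embeddings()(input_ids)

# The first forward pass consumes the embeddings; subsequent passes feed
# the sampled ids, per the past_key_values branch above.
output_ids = model.generate(inputs_embeds=inputs_embeds, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```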
21 changes: 14 additions & 7 deletions src/transformers/models/marian/modeling_marian.py
@@ -1684,7 +1684,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1702,12 +1702,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
21 changes: 14 additions & 7 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -2096,7 +2096,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -2114,12 +2114,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
21 changes: 14 additions & 7 deletions src/transformers/models/pegasus/modeling_pegasus.py
@@ -1658,7 +1658,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, use_cache=None, **kwargs
     ):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         if attention_mask is None:
@@ -1676,12 +1676,19 @@ def prepare_inputs_for_generation(

             input_ids = input_ids[:, remove_prefix_length:]
         # first step, decoder_cached_states are empty
-        return {
-            "input_ids": input_ids,  # encoder_outputs is defined. input_ids not needed
-            "attention_mask": attention_mask,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-        }
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids.contiguous()}
Collaborator:

Interesting - what's the reason for needing to add .contiguous here?

Member Author:

Not much of a reason right now; it's part of the compile compatibility work. Decided to leave it in case these models are also reworked for compile.
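For context, a tiny demonstration of what .contiguous() changes: slicing input_ids along the sequence dimension (as in input_ids[:, remove_prefix_length:]) generally yields a non-contiguous view, and some fused or compiled kernels want compact memory:

```python
import torch

x = torch.arange(12).reshape(3, 4)
view = x[:, 1:]              # a view into x's storage; strides are unchanged
print(view.is_contiguous())  # False
copy = view.contiguous()     # materializes a compact copy
print(copy.is_contiguous())  # True
```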


+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "past_key_values": past_key_values,
+                "use_cache": use_cache,
+            }
+        )
+        return model_inputs

     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):