Skip to content

Commit

Permalink
Mamba & RecurrentGemma: enable strict signature (#31549)
Browse files Browse the repository at this point in the history
* enable strict signature

* this should not have been deleted

* recurrent_gemma too
  • Loading branch information
gante authored Jul 8, 2024
1 parent ae9dd02 commit 594c161
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 40 deletions.
63 changes: 27 additions & 36 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2692,13 +2692,12 @@ def _sample(
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

# forward pass to get next token
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
outputs = self(**model_inputs, return_dict=True)

if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
Expand Down Expand Up @@ -2919,6 +2918,10 @@ def _beam_search(
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

# if sequential is True, split the input to batches of batch_size and run sequentially
if sequential:
if any(
Expand All @@ -2944,24 +2947,13 @@ def _beam_search(
model_inputs, split_size=batch_size, full_batch_size=batch_beam_size
)
outputs_per_sub_batch = [
self(
**inputs_per_sub_batch,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
for inputs_per_sub_batch in inputs_per_sub_batches
self(**inputs_per_sub_batch, return_dict=True) for inputs_per_sub_batch in inputs_per_sub_batches
]

outputs = stack_model_outputs(outputs_per_sub_batch)

else: # Unchanged original behavior
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
outputs = self(**model_inputs, return_dict=True)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
Expand Down Expand Up @@ -3241,12 +3233,12 @@ def _group_beam_search(

# do one decoder step on all beams of all sentences in batch
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

outputs = self(**model_inputs, return_dict=True)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
Expand Down Expand Up @@ -3522,12 +3514,11 @@ def _constrained_beam_search(
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

outputs = self(**model_inputs, return_dict=True)

if synced_gpus and this_peer_finished:
cur_len = cur_len + 1
Expand Down Expand Up @@ -3793,11 +3784,11 @@ def _assisted_decoding(
model_inputs["num_logits_to_keep"] = candidate_length + 1

# 2.2. Run a forward pass on the candidate sequence
outputs = self(
**model_inputs,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

outputs = self(**model_inputs)

# 2.3. Process the new logits
new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present
Expand Down
2 changes: 0 additions & 2 deletions src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,6 @@ def forward(
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs, # `attention_mask` is passed by the tokenizer and we don't want it
) -> Union[Tuple, MambaOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -673,7 +672,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, MambaCausalLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,6 @@ def forward(
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
Expand Down Expand Up @@ -823,7 +822,6 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
**kwargs, # for now we need this for generation
) -> Union[Tuple, CausalLMOutput]:
r"""
Args:
Expand Down

0 comments on commit 594c161

Please sign in to comment.