
Commit 5edfe64

feat: add _prepare_output, update contrastive
dmarx committed Sep 5, 2024
1 parent 3dfb339 · commit 5edfe64
Showing 1 changed file with 44 additions and 23 deletions.
src/transformers/generation/utils.py
@@ -358,6 +358,33 @@ class GenerationMixin:
     To learn more about decoding strategies refer to the [text generation strategies guide](../generation_strategies).
     """

+    def _prepare_output(
+        self, *,
+        return_dict_in_generate,
+        **output_kwargs):
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                cls = GenerateEncoderDecoderOutput
+            else:
+                cls = GenerateDecoderOnlyOutput
+                # decoder-only output classes use the generic field names,
+                # so rename the decoder-prefixed kwargs ...
+                if 'decoder_attentions' in output_kwargs:
+                    output_kwargs['attentions'] = output_kwargs.pop('decoder_attentions')
+                if 'decoder_hidden_states' in output_kwargs:
+                    output_kwargs['hidden_states'] = output_kwargs.pop('decoder_hidden_states')
+
+                # ... and carry no encoder or cross-attention fields, so drop them
+                if 'encoder_attentions' in output_kwargs:
+                    output_kwargs.pop('encoder_attentions')
+                if 'encoder_hidden_states' in output_kwargs:
+                    output_kwargs.pop('encoder_hidden_states')
+                if 'cross_attentions' in output_kwargs:
+                    output_kwargs.pop('cross_attentions')
+
+            outv = cls(**output_kwargs)
+        else:
+            outv = output_kwargs['sequences']
+        return outv
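
A minimal, self-contained sketch of the renaming/dropping behavior the new helper implements on its decoder-only path. `DummyDecoderOnlyOutput` and `prepare_output_sketch` are hypothetical stand-ins for illustration, not the transformers implementation:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DummyDecoderOnlyOutput:  # hypothetical stand-in for GenerateDecoderOnlyOutput
        sequences: tuple
        attentions: Optional[tuple] = None
        hidden_states: Optional[tuple] = None

    def prepare_output_sketch(*, return_dict_in_generate, **output_kwargs):
        if not return_dict_in_generate:
            return output_kwargs["sequences"]
        # decoder-only outputs use the generic field names ...
        for old, new in (("decoder_attentions", "attentions"),
                         ("decoder_hidden_states", "hidden_states")):
            if old in output_kwargs:
                output_kwargs[new] = output_kwargs.pop(old)
        # ... and carry no encoder or cross-attention fields
        for key in ("encoder_attentions", "encoder_hidden_states", "cross_attentions"):
            output_kwargs.pop(key, None)
        return DummyDecoderOnlyOutput(**output_kwargs)

    out = prepare_output_sketch(
        return_dict_in_generate=True,
        sequences=(101, 7, 8, 102),
        decoder_attentions=("layer0-attn",),
        encoder_attentions=None,  # silently dropped on the decoder-only path
    )
    print(out.attentions)  # ('layer0-attn',)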


     def prepare_inputs_for_generation(self, *args, **kwargs):
         raise NotImplementedError(
             "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."

@@ -2546,6 +2573,11 @@ def _contrastive_search(
         cross_attentions = () if (return_dict_in_generate and output_attentions) else None
         decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

+        # initialize variables for self._prepare_output
+        encoder_attentions = encoder_hidden_states = None
+        next_step_cross_attentions = ()
+        next_step_decoder_attentions = ()
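
These sentinels matter because `_prepare_output` is now called unconditionally at the end of `_contrastive_search` with all of these names as keyword arguments; without defaults, any code path that never populates them would crash. A toy illustration of the failure mode the defaults prevent (hypothetical function name):

    def missing_sentinel(return_dict_in_generate=False):
        if return_dict_in_generate:
            encoder_attentions = ()  # only bound on this branch
        return encoder_attentions    # UnboundLocalError when the branch is skipped

    missing_sentinel()  # raises UnboundLocalError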

         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
         if return_dict_in_generate and self.config.is_encoder_decoder:
             encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None

@@ -2851,29 +2883,18 @@ def _contrastive_search(
                         past_key_values.append(tuple(layer_past_key_values))
                     model_kwargs["past_key_values"] = tuple(past_key_values)

-            if self.config.is_encoder_decoder:
-                return GenerateEncoderDecoderOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    encoder_attentions=encoder_attentions,
-                    encoder_hidden_states=encoder_hidden_states,
-                    decoder_attentions=decoder_attentions,
-                    cross_attentions=cross_attentions,
-                    decoder_hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-            else:
-                return GenerateDecoderOnlyOutput(
-                    sequences=input_ids,
-                    scores=scores,
-                    logits=raw_logits,
-                    attentions=decoder_attentions,
-                    hidden_states=decoder_hidden_states,
-                    past_key_values=model_kwargs.get("past_key_values"),
-                )
-        else:
-            return input_ids
+        return self._prepare_output(
+            return_dict_in_generate=return_dict_in_generate,
+            sequences=input_ids,
+            scores=scores,
+            logits=raw_logits,
+            encoder_attentions=encoder_attentions,
+            encoder_hidden_states=encoder_hidden_states,
+            decoder_attentions=decoder_attentions,
+            cross_attentions=cross_attentions,
+            decoder_hidden_states=decoder_hidden_states,
+            past_key_values=model_kwargs.get("past_key_values"),
+        )

     def _sample(
         self,
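
For context, a usage sketch of how this code path is reached through the public API: `generate` dispatches to `_contrastive_search` when `penalty_alpha > 0` and `top_k > 1` are passed, and `return_dict_in_generate=True` exercises the new helper. The model choice here is illustrative:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("The quick brown fox", return_tensors="pt")

    out = model.generate(
        **inputs,
        penalty_alpha=0.6,   # degeneration penalty for contrastive search
        top_k=4,             # candidate pool size for contrastive search
        max_new_tokens=16,
        return_dict_in_generate=True,
    )
    print(type(out).__name__)  # expected: GenerateDecoderOnlyOutput
    print(tok.decode(out.sequences[0], skip_special_tokens=True))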
