Skip to content

Commit

Permalink
feat: update ._contrastive(streamer)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmarx committed Sep 10, 2024
1 parent 50536b7 commit 76775a4
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2818,8 +2818,6 @@ def _contrastive_search(

# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
if self.config.is_encoder_decoder:
next_step_cross_attentions = ()
next_step_decoder_attentions = ()
if output_attentions:
for layer in outputs.cross_attentions:
layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
Expand Down Expand Up @@ -2856,7 +2854,21 @@ def _contrastive_search(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
output_stub = self._prepare_output(
return_dict_in_generate=return_dict_in_generate,
sequences=next_tokens,
scores=(processed_logit_for_next_step,), # (scores,),
logits=(processed_logit_for_next_step,),
# NOTE: I think there is an issue with the contrastive search implementation: it currently returns the same values for `logits` as for `scores`.
# Alternatives tried for `logits`: `(logits[selected_idx, :],)` and `(logit_for_next_step,)`.
# With `logit_for_next_step` the values don't match; with `logits` the shapes don't match.
encoder_attentions=None, # probably doesn't make sense to stream this
encoder_hidden_states=None, # probably doesn't make sense to stream this
decoder_attentions=(next_step_decoder_attentions,),
# NOTE: very concerning -- if the argument above is replaced with `([0],)`, my tests still don't fail, so this value is apparently not covered by the tests.
cross_attentions=(next_step_cross_attentions,),
decoder_hidden_states=(next_decoder_hidden_states,),
past_key_values=None, # probably doesn't make sense to stream this
)
streamer.put(output_stub)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
Expand Down

0 comments on commit 76775a4

Please sign in to comment.