From 76775a41b6c6fe98da92d95deefa065455291eee Mon Sep 17 00:00:00 2001
From: David Marx
Date: Wed, 4 Sep 2024 20:10:31 -0700
Subject: [PATCH] feat: update ._contrastive(streamer)

---
 src/transformers/generation/utils.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0c3f220e2a0436..3e90d6b999bbc2 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2818,8 +2818,6 @@ def _contrastive_search(
             # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
             if self.config.is_encoder_decoder:
-                next_step_cross_attentions = ()
-                next_step_decoder_attentions = ()
                 if output_attentions:
                     for layer in outputs.cross_attentions:
                         layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
@@ -2856,7 +2854,21 @@ def _contrastive_search(
             # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
             if streamer is not None:
-                streamer.put(next_tokens.cpu())
+                output_stub = self._prepare_output(
+                    return_dict_in_generate=return_dict_in_generate,
+                    sequences=next_tokens,
+                    scores=(processed_logit_for_next_step,),  # (scores,),
+                    logits=(processed_logit_for_next_step,),
+                    # I think there's an issue with the contrastive sampling implementation that is currently returning the same values for logits as scores #(logits[selected_idx,:],), #(logit_for_next_step,), # `logit_for_next_step`: values don't match, `logits`: shapes don't match
+                    encoder_attentions=None,  # probably doesn't make sense to stream this
+                    encoder_hidden_states=None,  # probably doesn't make sense to stream this
+                    decoder_attentions=(next_step_decoder_attentions,),
+                    # ([0],),# very concerning that if I set this to `([0],)` my tests don't fail
+                    cross_attentions=(next_step_cross_attentions,),
+                    decoder_hidden_states=(next_decoder_hidden_states,),
+                    past_key_values=None,  # probably doesn't make sense to stream this
+                )
+                streamer.put(output_stub)
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs,
                 model_kwargs,