Skip to content

Commit

Permalink
fix: sampling in flax keeps EOS (huggingface#28378)
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma authored and wgifford committed Jan 21, 2024
1 parent bde8185 commit 690d007
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/generation/flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,8 @@ def sample_search_body_fn(state):

next_token = jax.random.categorical(prng_key, logits, axis=-1)

next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
next_token = next_token[:, None]

next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
Expand Down

0 comments on commit 690d007

Please sign in to comment.