From 690d0076591ad60761dbcd7696e1faaf678059c8 Mon Sep 17 00:00:00 2001 From: Boris Dayma Date: Mon, 15 Jan 2024 11:12:09 -0700 Subject: [PATCH] fix: sampling in flax keeps EOS (#28378) --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 4fce8970f8647c..1e063be8638650 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -716,8 +716,8 @@ def sample_search_body_fn(state): next_token = jax.random.categorical(prng_key, logits, axis=-1) + next_token = next_token * ~state.is_sent_finished + pad_token_id * state.is_sent_finished next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id) - next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished next_token = next_token[:, None] next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))