From 6a912ff2c5b3eadb9a0583d77083aae27d35d28d Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Fri, 22 Nov 2024 08:25:14 +0100 Subject: [PATCH] Watermarking: fix order (#34849) fix watermarking order --- src/transformers/generation/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3657550d0e7de..1e94e9d1ef875e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1029,10 +1029,6 @@ def _get_logits_processor( "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " "in favour of `input_ids` or `decoder_input_ids` respectively.", ) - if generation_config.watermarking_config is not None: - processors.append( - generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) - ) # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) @@ -1085,6 +1081,12 @@ def _get_logits_processor( ) ) + # Watermarking should be after all logits processing is finished (see #34630) + if generation_config.watermarking_config is not None: + processors.append( + generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) + ) + # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization())