diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e3657550d0e7de..1e94e9d1ef875e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1029,10 +1029,6 @@ def _get_logits_processor( "You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument " "in favour of `input_ids` or `decoder_input_ids` respectively.", ) - if generation_config.watermarking_config is not None: - processors.append( - generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) - ) # TODO (joao): find a strategy to specify the order of the processors processors = self._merge_criteria_processor_list(processors, logits_processor) @@ -1085,6 +1081,12 @@ def _get_logits_processor( ) ) + # Watermarking should be after all logits processing is finished (see #34630) + if generation_config.watermarking_config is not None: + processors.append( + generation_config.watermarking_config.construct_processor(self.config.vocab_size, device) + ) + # `LogitNormalization` should always be the last logit processor, when present if generation_config.renormalize_logits is True: processors.append(LogitNormalization())