Skip to content

Commit

Permalink
Watermarking: fix order (#34849)
Browse files Browse the repository at this point in the history
fix watermarking order
  • Loading branch information
zucchini-nlp authored Nov 22, 2024
1 parent 4e90b99 commit 6a912ff
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,10 +1029,6 @@ def _get_logits_processor(
"You have explicitly specified `forced_decoder_ids`. Please remove the `forced_decoder_ids` argument "
"in favour of `input_ids` or `decoder_input_ids` respectively.",
)
if generation_config.watermarking_config is not None:
processors.append(
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
)

# TODO (joao): find a strategy to specify the order of the processors
processors = self._merge_criteria_processor_list(processors, logits_processor)
Expand Down Expand Up @@ -1085,6 +1081,12 @@ def _get_logits_processor(
)
)

# Watermarking should be after all logits processing is finished (see #34630)
if generation_config.watermarking_config is not None:
processors.append(
generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
)

# `LogitNormalization` should always be the last logit processor, when present
if generation_config.renormalize_logits is True:
processors.append(LogitNormalization())
Expand Down

0 comments on commit 6a912ff

Please sign in to comment.