Fix Seq2seqTrainer decoder attention mask (huggingface#26841)
Don't drop decoder_input_ids without also dropping decoder_attention_mask
Rocketknight1 authored and EduardoPach committed Nov 19, 2023
1 parent 21b4aee commit 58f99cf
Showing 1 changed file with 3 additions and 1 deletion.
src/transformers/trainer_seq2seq.py (3 additions, 1 deletion)
```diff
@@ -288,7 +288,9 @@ def prediction_step(
             and "decoder_input_ids" in generation_inputs
             and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
         ):
-            generation_inputs = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
+            generation_inputs = {
+                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
+            }
         generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
```
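For context, `prediction_step` evicts `decoder_input_ids` that were derived from padded `labels` so that `generate()` can start decoding from scratch. Before this fix, the matching `decoder_attention_mask` survived the eviction and was passed to `generate()` alongside decoder inputs it no longer matched. A minimal sketch of the failure mode, using hypothetical tensors (the shapes and pad values are illustrative, not taken from the trainer):

```python
import torch

# Hypothetical batch: decoder_input_ids shifted from padded labels,
# with a decoder_attention_mask of the same (batch, seq_len) shape.
inputs = {
    "input_ids": torch.tensor([[5, 6, 7], [8, 9, 10]]),
    "labels": torch.tensor([[42, 7, 1, -100], [9, 1, -100, -100]]),
    "decoder_input_ids": torch.tensor([[0, 42, 7, 1], [0, 9, 1, 1]]),
    "decoder_attention_mask": torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]]),
}

# Before the fix: the (2, 4) mask survives, but generate() builds fresh
# decoder inputs of shape (2, 1), so the leftover mask describes tensors
# that no longer exist.
buggy = {k: v for k, v in inputs.items() if k != "decoder_input_ids"}
assert "decoder_attention_mask" in buggy  # stale mask still passed along

# After the fix: both tensors are evicted together, leaving generate()
# free to create its own decoder inputs (and mask) from scratch.
fixed = {k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")}
assert "decoder_attention_mask" not in fixed
```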
