Skip to content

Commit

Permalink
[fix bug] logits's shape different from label's shape in preprocess_l…
Browse files Browse the repository at this point in the history
…ogits_for_metrics (#31447)

* [fix BUG] pad labels before use it in preprocess_logits_for_metrics

* a more readable fix

labels can't use  `gather` before pass to `preprocess_logits_for_metrics`, so must split into 2 if-block

* add a comment

* oh code quality check
  • Loading branch information
wiserxin authored Jul 3, 2024
1 parent 65a02cd commit 534cbf8
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3839,6 +3839,9 @@ def evaluation_loop(
inputs_decode = self.gather_function((inputs_decode))
if not self.args.batch_eval_metrics or description == "Prediction":
all_inputs.add(inputs_decode)
if labels is not None:
# Pad labels here, preparing for preprocess_logits_for_metrics in next logits block.
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
if logits is not None:
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
Expand All @@ -3847,7 +3850,6 @@ def evaluation_loop(
if not self.args.batch_eval_metrics or description == "Prediction":
all_preds.add(logits)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
labels = self.gather_function((labels))
if not self.args.batch_eval_metrics or description == "Prediction":
all_labels.add(labels)
Expand Down

0 comments on commit 534cbf8

Please sign in to comment.