Fix batch size handling in prediction_loop for DataLoaderShard (huggingface#34343)

* Fix batch size handling in prediction_loop for DataLoaderShard

Updated the prediction_loop method in the Trainer class to correctly handle the batch size when using DataLoaderShard. The batch size is now retrieved from total_batch_size in distributed training scenarios, preventing a TypeError caused by a None batch size during evaluation.

* Update src/transformers/trainer.py

Co-authored-by: Zach Mueller <[email protected]>

* Applied the fix to remove unused imports

---------

Co-authored-by: Zach Mueller <[email protected]>
2 people authored and BernardZach committed Dec 6, 2024
1 parent 4cff4e7 commit 1e84943
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/transformers/trainer.py
@@ -4714,7 +4714,17 @@ def prediction_loop(
         elif args.bf16_full_eval:
             model = model.to(dtype=torch.bfloat16, device=args.device)
 
-        batch_size = dataloader.batch_size
+        batch_size = (
+            dataloader.total_batch_size
+            if getattr(dataloader, "_is_accelerate_prepared", False)
+            else dataloader.batch_size
+        )
+
+        if batch_size is None:
+            raise ValueError(
+                "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
+            )
+
         num_examples = self.num_examples(dataloader)
         logger.info(f"\n***** Running {description} *****")
         logger.info(f" Num examples = {num_examples}")
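For reference, a minimal standalone sketch of the selection logic this diff introduces. The helper name _resolve_batch_size and the eval_dataloader variable in the usage comment are hypothetical; the attributes total_batch_size, batch_size, and _is_accelerate_prepared are the ones read in the change above.

def _resolve_batch_size(dataloader):
    """Return the effective batch size, whether or not accelerate has prepared the dataloader."""
    if getattr(dataloader, "_is_accelerate_prepared", False):
        # Accelerate-prepared loaders (e.g. DataLoaderShard) report the effective
        # batch size via total_batch_size; their plain batch_size attribute may be None.
        batch_size = getattr(dataloader, "total_batch_size", None)
    else:
        batch_size = dataloader.batch_size
    if batch_size is None:
        raise ValueError(
            "Batch size cannot be None. Ensure the dataloader has a valid batch_size or total_batch_size."
        )
    return batch_size

# Hypothetical usage: batch_size = _resolve_batch_size(eval_dataloader)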
