diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index affc7b725e8a70..3b69d0114321e3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -4643,10 +4643,9 @@ def create_accelerator_and_postprocess(self):
             wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
             raise ValueError(f"{wrapper} can't be used with `save_only_model` along with `load_best_model_at_end`.")
 
-        # `auto_find_batch_size` isn't yet supported with DeepSpeed/FSDP
-        if (self.is_deepspeed_enabled or self.is_fsdp_enabled) and self.args.auto_find_batch_size:
-            wrapper = "DeepSpeed" if self.is_deepspeed_enabled else "FSDP"
-            raise NotImplementedError(f"`{wrapper}` doesn't support `auto_find_batch_size`.")
+        # `auto_find_batch_size` isn't yet supported with DeepSpeed
+        if self.is_deepspeed_enabled and self.args.auto_find_batch_size:
+            raise NotImplementedError(f"`DeepSpeed` doesn't support `auto_find_batch_size`.")
 
     def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
         """
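
For context, a minimal usage sketch (not part of the patch) of the combination this change permits: after the guard is narrowed to DeepSpeed only, `auto_find_batch_size` can be used together with FSDP, while pairing it with DeepSpeed still raises `NotImplementedError`. The `output_dir` value and the DeepSpeed config path mentioned in the comments are hypothetical placeholders.

from transformers import TrainingArguments

# Previously this combination was rejected with
# NotImplementedError("`FSDP` doesn't support `auto_find_batch_size`.");
# with this patch it is accepted.
args = TrainingArguments(
    output_dir="out",           # hypothetical output path
    auto_find_batch_size=True,  # retry training with a smaller batch size on OOM
    fsdp="full_shard",          # enable FSDP; no longer blocked by this check
)

# By contrast, passing deepspeed="ds_config.json" (hypothetical config path)
# together with auto_find_batch_size=True still hits the NotImplementedError
# raised in create_accelerator_and_postprocess().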