diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 15032a9b4c0f52..48ecab67970f59 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2908,7 +2908,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None): else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") if self.args.save_safetensors: - safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) + safetensors.torch.save_file( + state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"} + ) else: torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: