diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index cd46eb4d1c14bb..5aa903a80ea7c0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2391,17 +2391,23 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process starting on process 0
-        if staging_output_dir != output_dir:
-            with self.args.main_process_first(
-                desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
-            ):
+
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
                 if os.path.exists(staging_output_dir):
                     os.rename(staging_output_dir, output_dir)
 
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+                # Ensure rename completed in cases where os.rename is not atomic
+                fd = os.open(output_dir, os.O_RDONLY)
+                os.fsync(fd)
+                os.close(fd)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
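
For reference, the save-to-staging-then-rename-then-fsync pattern the hunk implements can be sketched in isolation. The helper below is illustrative only: the function name, its arguments, and the is_main_process flag are not part of the Trainer API, and it assumes a POSIX filesystem where a directory can be opened read-only and fsynced (os.open on a directory fails on Windows).

import os


def finalize_checkpoint(staging_output_dir: str, output_dir: str, is_main_process: bool) -> None:
    """Move a fully written staging directory into its final location.

    Only the designated main process performs the rename and the fsync, mirroring
    the diff above; the fsync flushes the directory metadata so the rename is
    durable even on filesystems where os.rename is not atomic. Hypothetical
    helper for illustration, not the Trainer's method.
    """
    if not is_main_process:
        return
    if staging_output_dir != output_dir and os.path.exists(staging_output_dir):
        os.rename(staging_output_dir, output_dir)
        fd = os.open(output_dir, os.O_RDONLY)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)

# Example (assumed names): finalize_checkpoint("tmp-checkpoint-500", "checkpoint-500", is_main_process=True)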