From 9e299f326eabc5af76843d58637ae85a833257d1 Mon Sep 17 00:00:00 2001
From: Timothy Blattner
Date: Wed, 10 Jan 2024 10:55:42 -0500
Subject: [PATCH] Fix for checkpoint rename race condition (#28364)

* Changed logic for renaming the staging directory when saving a checkpoint so that it only operates on the main process.

Added fsync functionality to attempt to flush the write changes in case os.rename is not atomic.

* Updated styling using make fixup

* Updated check for main process to use built-in versions from trainer

Co-authored-by: Zach Mueller

* Fixed incorrect usage of trainer main process checks

Added "with open" usage to ensure better file closing, as suggested in the PR

Added rotate_checkpoints into main process logic

* Removed "with open" because it does not work with directories; os.open does work for directories.

---------

Co-authored-by: Zach Mueller
---
 src/transformers/trainer.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index cd46eb4d1c14bb..5aa903a80ea7c0 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2391,17 +2391,23 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process starting on process 0
-        if staging_output_dir != output_dir:
-            with self.args.main_process_first(
-                desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
-            ):
+
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
                 if os.path.exists(staging_output_dir):
                     os.rename(staging_output_dir, output_dir)
 
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+                    # Ensure rename completed in cases where os.rename is not atomic
+                    fd = os.open(output_dir, os.O_RDONLY)
+                    os.fsync(fd)
+                    os.close(fd)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
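
Note: the sketch below is not the Trainer code itself; it is a minimal, standalone illustration of the pattern this patch applies. Only the main process renames the fully written staging directory into its final location, and the directory file descriptor is then fsync'ed so the rename is flushed on filesystems where os.rename may not behave atomically. The helper name finalize_checkpoint and the is_main_process flag are invented for illustration; in the patch the equivalent checks are self.is_local_process_zero() / self.is_world_process_zero(), and every rank synchronizes via self.args.distributed_state.wait_for_everyone() before and after this block. Opening a directory with os.open() is a POSIX behavior and may not work on Windows.

```python
import os


def finalize_checkpoint(staging_output_dir: str, output_dir: str, is_main_process: bool) -> None:
    """Minimal sketch (not the Trainer API) of the rename-and-fsync pattern from the patch."""
    if not is_main_process:
        # Non-main ranks do nothing here; in the patch every rank hits a
        # wait_for_everyone() barrier before and after this block.
        return

    if staging_output_dir != output_dir and os.path.exists(staging_output_dir):
        os.rename(staging_output_dir, output_dir)

        # Flush the rename in case os.rename is not atomic on the underlying
        # filesystem. open() cannot open a directory, but os.open() can
        # (on POSIX), so the directory file descriptor is fsync'ed directly.
        fd = os.open(output_dir, os.O_RDONLY)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
```

This also shows why the last commit bullet dropped "with open": the built-in open() cannot open a directory, whereas os.open() can, which is what allows the directory entry itself to be fsync'ed after the rename.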