Fix for checkpoint rename race condition (huggingface#28364)
* Changed the logic for renaming the staging directory when saving a checkpoint so that only the main process performs the rename.
Added an fsync call to attempt to flush the rename to disk in case os.rename is not atomic.

* Updated styling by running make fixup.

* Updated the main-process check to use the Trainer's built-in helpers.

Co-authored-by: Zach Mueller <[email protected]>

* Fixed incorrect usage of the Trainer's main-process checks.
Added "with open" to ensure the file is closed properly, as suggested in PR review.
Moved the _rotate_checkpoints call into the main-process logic.

* Removed "with open" due to not working with directory. os.open seems to work for directories.

---------

Co-authored-by: Zach Mueller <[email protected]>
2 people authored and MadElf1337 committed Jan 15, 2024
1 parent e65ef40 commit 3554ef2
Showing 1 changed file with 14 additions and 8 deletions.
src/transformers/trainer.py (14 additions, 8 deletions)
@@ -2391,17 +2391,23 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process starting on process 0
-        if staging_output_dir != output_dir:
-            with self.args.main_process_first(
-                desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
-            ):
+
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
                 if os.path.exists(staging_output_dir):
                     os.rename(staging_output_dir, output_dir)
 
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+                    # Ensure rename completed in cases where os.rename is not atomic
+                    fd = os.open(output_dir, os.O_RDONLY)
+                    os.fsync(fd)
+                    os.close(fd)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
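For readers unfamiliar with the gating expression in the added code, here is a hedged, self-contained sketch of the same check with the Trainer replaced by plain variables; save_on_each_node, local_rank, and global_rank are illustrative stand-ins, not the real Trainer attributes:

save_on_each_node = False  # stand-in for self.args.save_on_each_node
local_rank = 0             # rank of this process within its node
global_rank = 0            # rank of this process across all nodes

# Same shape as the diff's condition:
#   self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero()
# When each node writes its own checkpoint, rank 0 of every node renames;
# otherwise only the single global rank 0 process renames and rotates.
is_main = (local_rank == 0) if save_on_each_node else (global_rank == 0)
if is_main:
    print("this process renames the staging directory and rotates checkpoints")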
