Fix for checkpoint rename race condition #28364

Merged
src/transformers/trainer.py (22 changes: 14 additions & 8 deletions)
@@ -2382,17 +2382,23 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process starting on process 0
-        if staging_output_dir != output_dir:
-            with self.args.main_process_first(
-                desc="Renaming model checkpoint folder to true location", local=self.args.save_on_each_node
-            ):
+
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
                 if os.path.exists(staging_output_dir):
                     os.rename(staging_output_dir, output_dir)
 
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+                    # Ensure rename completed in cases where os.rename is not atomic
+                    fd = os.open(output_dir, os.O_RDONLY)
+                    os.fsync(fd)
+                    os.close(fd)
+
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
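
Outside the Trainer, the durability trick in the added lines can be written as a standalone helper. A minimal sketch, with a hypothetical `durable_rename` name: note that the diff fsyncs the renamed directory itself, whereas the textbook POSIX recipe fsyncs the parent directory, which owns the renamed entry.

```python
import os

def durable_rename(src: str, dst: str) -> None:
    """Hypothetical helper distilling the rename-then-fsync pattern in the diff."""
    os.rename(src, dst)
    # Flush the rename to disk so it survives a crash, covering filesystems
    # where os.rename is not atomic or not immediately durable. The diff
    # opens the renamed directory itself; fsyncing the parent directory is
    # the stricter POSIX form. POSIX-only: os.open cannot open a directory
    # on Windows.
    fd = os.open(dst, os.O_RDONLY)
    try:
        os.fsync(fd)
    finally:
        os.close(fd)
```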
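
The gate `self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero()` is a conditional expression choosing which predicate to call: with `save_on_each_node`, each node's local rank 0 renames the checkpoint on its node-local disk; otherwise only global rank 0 acts, and the barriers on both sides keep the other workers from racing past. A sketch of the same shape in plain `torch.distributed`, assuming the process group is already initialized and a torchrun-style launcher sets `LOCAL_RANK`; the function name and signature are illustrative, not Trainer API.

```python
import os
import torch.distributed as dist

def finalize_checkpoint(staging_dir: str, final_dir: str, save_on_each_node: bool) -> None:
    # First barrier: every rank must finish writing before anyone renames.
    dist.barrier()

    # Same choice as the diff's conditional expression: one renamer per
    # node when checkpoints sit on node-local storage, else one global one.
    if save_on_each_node:
        is_renamer = int(os.environ.get("LOCAL_RANK", "0")) == 0  # set by torchrun
    else:
        is_renamer = dist.get_rank() == 0

    if is_renamer and os.path.exists(staging_dir):
        os.rename(staging_dir, final_dir)

    # Second barrier: no rank may read final_dir until the rename lands.
    dist.barrier()
```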