Fixed incorrect usage of trainer main process checks
Added with-open usage to ensure the file is properly closed, as suggested in the PR review.
Moved _rotate_checkpoints into the main-process logic.
tblattner committed Jan 9, 2024
1 parent 3c1c92f commit df9f6e9
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/transformers/trainer.py
@@ -2382,25 +2382,23 @@ def _save_checkpoint(self, model, trial, metrics=None):
         # Place checkpoint in final location after all saving is finished.
         # First wait for everyone to finish writing
         self.args.distributed_state.wait_for_everyone()
-        # Then go through the rewriting process, only renaming from main process(es)
-        if staging_output_dir != output_dir:
-            if (
-                self.is_local_process_zero
-                if self.args.save_on_each_node
-                else self.is_world_process_zero
-            ):
+
+        # Then go through the rewriting process, only renaming and rotating from main process(es)
+        if self.is_local_process_zero() if self.args.save_on_each_node else self.is_world_process_zero():
+            if staging_output_dir != output_dir:
                 if os.path.exists(staging_output_dir):
                     os.rename(staging_output_dir, output_dir)
 
                     # Ensure rename completed in cases where os.rename is not atomic
-                    fd = os.open(output_dir, os.O_RDONLY)
-                    os.fsync(fd)
+                    with open(output_dir, "r") as f:
+                        f.flush()
+                        os.fsync(f.fileno())
 
-        self.args.distributed_state.wait_for_everyone()
+            # Maybe delete some older checkpoints.
+            if self.args.should_save:
+                self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
 
-        # Maybe delete some older checkpoints.
-        if self.args.should_save:
-            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+        self.args.distributed_state.wait_for_everyone()
 
     def _save_rng_state(self, output_dir):
         # Save RNG state in non-distributed training
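
For context on the first fix: the old guard referenced self.is_local_process_zero and self.is_world_process_zero without calling them. A bound method object is always truthy in Python, so the check passed on every process; adding the call parentheses makes the guard respect the methods' return values. A minimal standalone sketch of the difference (the Trainer stub below is illustrative, not the transformers class):

```python
class Trainer:
    # Illustrative stub, not transformers' Trainer.
    def is_world_process_zero(self):
        return False  # pretend this rank is NOT the main process

t = Trainer()

# Old pattern: the bound method itself is truthy, so this always runs.
if t.is_world_process_zero:
    print("bug: executes on every rank")

# Fixed pattern: calling the method respects its return value.
if t.is_world_process_zero():
    print("only executes on the main rank")
```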
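The second change replaces a raw os.open/os.fsync pair, which never closed its descriptor, with a with-open block, so the descriptor is released even if fsync raises. A minimal sketch of the flush-then-fsync durability pattern on a regular file (the filename is made up for illustration):

```python
import os

# Write a file and push it to stable storage; the context manager
# guarantees the descriptor is closed even if fsync raises.
with open("checkpoint_marker.txt", "w") as f:
    f.write("checkpoint complete")
    f.flush()              # flush Python's userspace buffer to the OS
    os.fsync(f.fileno())   # ask the OS to commit the data to disk
```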

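Finally, the rotation call moves inside the main-process branch and the trailing barrier moves after it, giving a barrier / main-process-only work / barrier shape. A hedged sketch of that shape with stand-in names (barrier, is_main, and rotate are placeholders, not the Trainer API):

```python
import os

def finalize_checkpoint(staging_dir, final_dir, is_main, barrier, rotate):
    # barrier/is_main/rotate are stand-in callables, not the Trainer API.
    barrier()  # wait until every rank has finished writing staging_dir
    if is_main():
        # Only one process renames and prunes, avoiding races on the
        # shared filesystem.
        if os.path.exists(staging_dir):
            os.rename(staging_dir, final_dir)
        rotate()  # e.g. drop checkpoints beyond save_total_limit
    barrier()  # other ranks proceed only once the final dir exists
```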