From 93766251cb0e07afa8e6e25dfeacf525db39cead Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 13 Dec 2023 12:17:30 -0500
Subject: [PATCH] Fix bug with rotating checkpoints (#28009)

* Fix bug

* Write test

* Keep back old modification for grad accum steps

* Whitespace...

* Whitespace again

* Race condition

* Wait for everyone
---
 src/transformers/trainer.py               |  7 ++++++-
 tests/trainer/test_trainer_distributed.py | 15 +++++++++++++++
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d6ccc4334dd46d..3a4ff5528047ae 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2382,8 +2382,13 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self._push_from_checkpoint(staging_output_dir)
 
         # Place checkpoint in final location after all saving is finished.
+        # First wait for everyone to finish writing
+        self.args.distributed_state.wait_for_everyone()
+        # Then go through the rewriting process starting on process 0
         if staging_output_dir != output_dir:
-            os.rename(staging_output_dir, output_dir)
+            with self.args.main_process_first(desc="Renaming model checkpoint folder to true location"):
+                if os.path.exists(staging_output_dir):
+                    os.rename(staging_output_dir, output_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
diff --git a/tests/trainer/test_trainer_distributed.py b/tests/trainer/test_trainer_distributed.py
index 8f867cf0beba37..2850d6c40b4e1c 100644
--- a/tests/trainer/test_trainer_distributed.py
+++ b/tests/trainer/test_trainer_distributed.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pathlib import Path
 from typing import Dict
 
 import numpy as np
@@ -236,6 +237,20 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
     trainer.args.eval_accumulation_steps = None
 
+    # Check that saving does indeed work with temp dir rotation
+    # If this fails, will see a FileNotFoundError
+    model = RegressionModel()
+    training_args.max_steps = 1
+    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
+    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: 1)
+    trainer = Trainer(
+        model, training_args, optimizers=(opt, sched), data_collator=DummyDataCollator(), eval_dataset=dataset
+    )
+    trainer._save_checkpoint(model=None, trial=None)
+    # Check that the temp folder does not exist
+    assert not (Path(training_args.output_dir) / "tmp-checkpoint-0").exists()
+    assert (Path(training_args.output_dir) / "checkpoint-0").exists()
+
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
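
Note on the fix: the race is that several processes can reach the rename step while others are still writing into the staging folder, or after another process has already moved it, so the bare os.rename could hit a missing directory. The patch orders things as: synchronize, then rename starting from process 0, and guard the rename with an existence check. Below is a minimal, standalone sketch of that pattern; finalize_checkpoint, is_main_process, and barrier are hypothetical stand-ins for the Trainer/Accelerate machinery (distributed_state.wait_for_everyone() and the main_process_first() context manager) and are not part of the transformers API.

# Minimal sketch (not transformers code) of the race-free checkpoint finalization
# pattern: wait until every process has finished writing the staged checkpoint,
# then let a single process rename it, tolerating the folder already being gone.
import os


def finalize_checkpoint(staging_dir: str, output_dir: str, is_main_process: bool, barrier) -> None:
    # 1) Make sure every process has finished writing into `staging_dir`.
    barrier()
    # 2) Rename on the main process only; the existence check makes this a no-op
    #    if the folder has already been moved.
    if staging_dir != output_dir and is_main_process and os.path.exists(staging_dir):
        os.rename(staging_dir, output_dir)
    # 3) Synchronize again so no process proceeds while the rename is in flight.
    barrier()


if __name__ == "__main__":
    # Single-process usage example; the barrier is a no-op here.
    os.makedirs("tmp-checkpoint-0", exist_ok=True)
    finalize_checkpoint("tmp-checkpoint-0", "checkpoint-0", is_main_process=True, barrier=lambda: None)
    assert os.path.isdir("checkpoint-0") and not os.path.exists("tmp-checkpoint-0")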