Fix bug with rotating checkpoints (#28009)
* Fix bug

* Write test

* Keep back old modification for grad accum steps

* Whitespace...

* Whitespace again

* Race condition

* Wait for everyone
muellerzr authored Dec 13, 2023
1 parent ec43d68 commit 9376625
Showing 2 changed files with 21 additions and 1 deletion.
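
The commit message and the new test point to a race during checkpoint rotation: with the staging-directory scheme, one rank could attempt the `os.rename` of the temporary `tmp-checkpoint-*` folder while other ranks were still writing to it, or after another rank had already moved it, surfacing as a `FileNotFoundError`. Below is a minimal sketch of the synchronization pattern the fix applies, written against plain `torch.distributed` rather than the Trainer/Accelerate helpers used in the diff; the function name and arguments are illustrative, not part of the commit.

import os

import torch.distributed as dist


def finalize_checkpoint(staging_dir: str, final_dir: str) -> None:
    # Illustration only: the real fix calls self.args.distributed_state.wait_for_everyone()
    # and self.args.main_process_first(...) inside Trainer._save_checkpoint.
    is_dist = dist.is_available() and dist.is_initialized()

    # 1. Wait until every rank has finished writing into the staging folder.
    if is_dist:
        dist.barrier()

    # 2. Let a single rank move the folder into place; the exists() guard keeps the
    #    step idempotent, so a rank that finds the folder already moved simply skips
    #    it instead of raising FileNotFoundError.
    if (not is_dist or dist.get_rank() == 0) and os.path.exists(staging_dir):
        os.rename(staging_dir, final_dir)

    # 3. Sync again so no rank reads final_dir before the rename has landed.
    if is_dist:
        dist.barrier()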
7 changes: 6 additions & 1 deletion src/transformers/trainer.py
@@ -2382,8 +2382,13 @@ def _save_checkpoint(self, model, trial, metrics=None):
             self._push_from_checkpoint(staging_output_dir)
 
         # Place checkpoint in final location after all saving is finished.
+        # First wait for everyone to finish writing
+        self.args.distributed_state.wait_for_everyone()
+        # Then go through the rewriting process starting on process 0
         if staging_output_dir != output_dir:
-            os.rename(staging_output_dir, output_dir)
+            with self.args.main_process_first(desc="Renaming model checkpoint folder to true location"):
+                if os.path.exists(staging_output_dir):
+                    os.rename(staging_output_dir, output_dir)
 
         # Maybe delete some older checkpoints.
         if self.args.should_save:
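The `main_process_first` context manager used in the diff above is not defined in this commit; its ordering semantics can be sketched with plain `torch.distributed` barriers roughly as below (an illustration of the behavior, not the actual TrainingArguments implementation). The main rank runs the body first and then releases the others; by the time the remaining ranks enter the block, the staging folder has already been renamed, so their `os.path.exists` check fails and they skip the rename.

from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def main_process_first():
    # Rough sketch of "main process first" ordering; illustration only.
    is_dist = dist.is_available() and dist.is_initialized()
    is_main = (not is_dist) or dist.get_rank() == 0
    try:
        if is_dist and not is_main:
            dist.barrier()  # non-main ranks wait here until the main rank has finished
        yield
    finally:
        if is_dist and is_main:
            dist.barrier()  # the main rank releases the waiting ranks after its turn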
15 changes: 15 additions & 0 deletions tests/trainer/test_trainer_distributed.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from pathlib import Path
 from typing import Dict
 
 import numpy as np
@@ -236,6 +237,20 @@ def compute_metrics(p: EvalPrediction) -> Dict:
 
         trainer.args.eval_accumulation_steps = None
 
+    # Check that saving does indeed work with temp dir rotation
+    # If this fails, will see a FileNotFoundError
+    model = RegressionModel()
+    training_args.max_steps = 1
+    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
+    sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda x: 1)
+    trainer = Trainer(
+        model, training_args, optimizers=(opt, sched), data_collator=DummyDataCollator(), eval_dataset=dataset
+    )
+    trainer._save_checkpoint(model=None, trial=None)
+    # Check that the temp folder does not exist
+    assert not (Path(training_args.output_dir) / "tmp-checkpoint-0").exists()
+    assert (Path(training_args.output_dir) / "checkpoint-0").exists()
+
     # Check that `dispatch_batches=False` will work on a finite iterable dataset
 
     train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
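Note that these assertions live in the script half of `test_trainer_distributed.py`, so they exercise the multi-process path only when the file is launched under `torch.distributed`. A hedged example of such a launch from Python follows; the process count and output directory are illustrative assumptions, not taken from this commit.

import subprocess
import sys

# Launch the distributed test script with 2 processes via torchrun
# (torch.distributed.run); adjust --nproc_per_node to the available GPUs.
subprocess.run(
    [
        sys.executable,
        "-m",
        "torch.distributed.run",
        "--nproc_per_node=2",
        "tests/trainer/test_trainer_distributed.py",
        "--output_dir",
        "/tmp/test_trainer_distributed",
    ],
    check=True,
)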
