Skip to content

Commit

Permalink
FIX Automatic checkpoint path inference issue (huggingface#1989)
Browse files Browse the repository at this point in the history
Resolves huggingface#1983

Fixes an issue where the checkpoint directory would be incorrectly set while
loading when using relative paths.
  • Loading branch information
BenjaminBossan authored Sep 19, 2023
1 parent 03deec2 commit 80da9cf
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2876,7 +2876,7 @@ def _inner(folder):
return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]

folders.sort(key=_inner)
input_dir = os.path.join(input_dir, folders[-1])
input_dir = folders[-1]
else:
raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.")
logger.info(f"Loading states from {input_dir}")
Expand Down
67 changes: 67 additions & 0 deletions tests/test_state_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import shutil
import tempfile
import unittest
import uuid
from contextlib import contextmanager

import pytest
import torch
Expand Down Expand Up @@ -201,6 +203,71 @@ def test_can_resume_training(self):
self.assertEqual(opt_state1, opt_state3)
self.assertEqual(ground_truth_rands, test_rands)

def test_can_resume_training_checkpoints_relative_path(self):
    # Regression test for #1983.
    # Mirrors test_can_resume_training, except that the project directory is a relative path and the first
    # load_state() call receives no argument, so the checkpoint location has to be inferred.

    @contextmanager
    def relative_scratch_dir():
        # Stand-in for tempfile.TemporaryDirectory() that hands out a relative path instead
        scratch = f"test_path_{uuid.uuid4()}"
        os.mkdir(scratch)
        try:
            yield scratch
        finally:
            shutil.rmtree(scratch)

    def fresh_run(project_dir, **config_kwargs):
        # Re-seed and rebuild every training component, in the same order each time, so that both runs start
        # from the same random state.
        set_seed(42)
        net = DummyModel()
        optim = torch.optim.Adam(params=net.parameters(), lr=1e-3)
        train_dl, valid_dl = dummy_dataloaders()
        config = ProjectConfiguration(**config_kwargs)
        acc = Accelerator(project_dir=project_dir, project_config=config)
        net, optim, train_dl, valid_dl = acc.prepare(net, optim, train_dl, valid_dl)
        return acc, net, optim, train_dl

    with relative_scratch_dir() as tmpdir:
        # Baseline run: checkpoint the initial state, then train for three epochs
        accelerator, model, optimizer, train_dataloader = fresh_run(tmpdir, automatic_checkpoint_naming=True)
        accelerator.save_state()
        a = model.a.item()
        b = model.b.item()
        opt_state = optimizer.state_dict()
        ground_truth_rands = train(3, model, train_dataloader, optimizer, accelerator)
        a1 = model.a.item()
        b1 = model.b.item()
        opt_state1 = optimizer.state_dict()

        # Second run: restore the initial checkpoint without naming its directory
        accelerator, model, optimizer, train_dataloader = fresh_run(
            tmpdir, iteration=1, automatic_checkpoint_naming=True
        )
        accelerator.load_state()  # <= the checkpoint directory must be inferred here
        a2 = model.a.item()
        b2 = model.b.item()
        opt_state2 = optimizer.state_dict()
        self.assertEqual(a, a2)
        self.assertEqual(b, b2)
        self.assertEqual(opt_state, opt_state2)

        # Train two epochs and checkpoint again
        test_rands = train(2, model, train_dataloader, optimizer, accelerator)
        accelerator.save_state()

        # Reload that checkpoint by explicit path, finish the last epoch and compare against the baseline
        resume_dir = os.path.join(tmpdir, "checkpoints", "checkpoint_1")
        accelerator.load_state(resume_dir)
        test_rands += train(1, model, train_dataloader, optimizer, accelerator)
        a3 = model.a.item()
        b3 = model.b.item()
        opt_state3 = optimizer.state_dict()
        self.assertEqual(a1, a3)
        self.assertEqual(b1, b3)
        self.assertEqual(opt_state1, opt_state3)
        self.assertEqual(ground_truth_rands, test_rands)

def test_invalid_registration(self):
t = torch.tensor([1, 2, 3])
t1 = torch.tensor([2, 3, 4])
Expand Down

0 comments on commit 80da9cf

Please sign in to comment.