Skip to content

Commit

Permalink
Add Trainer.maybe_load_checkpoint method
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Aug 28, 2024
1 parent 1744420 commit 6758b08
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
24 changes: 23 additions & 1 deletion src/olmo_core/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
)

def load_checkpoint(
self, dir: PathOrStr, load_optimizer_state: bool = True, load_trainer_state: bool = True
self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True
):
"""
Load a checkpoint.
Expand Down Expand Up @@ -543,6 +543,28 @@ def load_checkpoint(

log.info("Checkpoint successfully loaded")

def maybe_load_checkpoint(
    self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True
) -> bool:
    """
    Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided.

    :param dir: The path/URL to the checkpoint directory to check.
    :param load_optimizer_state: Also load the optimizer state.
    :param load_trainer_state: Also load the trainer state.

    :returns: If a checkpoint was loaded.
    """
    # Only rank 0 probes the (potentially remote) directory; the result is then
    # shared via ``scatter_object`` so every rank agrees on whether to load.
    # NOTE(review): assumes ``scatter_object`` distributes rank 0's value — confirm.
    should_load: bool = True
    if get_rank() == 0:
        should_load = self.checkpointer.contains_checkpoint(dir)
    should_load = scatter_object(should_load)
    if should_load:
        self.load_checkpoint(
            dir,
            load_optimizer_state=load_optimizer_state,
            load_trainer_state=load_trainer_state,
        )
    else:
        # Lazy %-style args (not an f-string) so the message is only formatted
        # when the record is actually emitted.
        log.warning("No checkpoint found in '%s', will train from scratch...", dir)
    return should_load

def record_metric(
self, name: str, value: Union[float, torch.Tensor], reduce_type: Optional[ReduceType] = None
):
Expand Down
25 changes: 11 additions & 14 deletions src/scripts/train/OLMo-7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,17 @@ def train(config: ExperimentConfig):
dataset = config.dataset.build()
trainer = config.trainer.build(model, optim, dataset)

if (load_path := config.load_path) is not None:
# Maybe load a checkpoint.
should_load: bool = True
if config.load_strategy == LoadStrategy.never:
should_load = False
elif config.load_strategy == LoadStrategy.if_available:
if get_rank() == 0:
should_load = trainer.checkpointer.contains_checkpoint(load_path)
should_load = scatter_object(should_load)

if should_load:
trainer.load_checkpoint(load_path)
elif get_rank() == 0:
# Save config to file.
# Maybe load a checkpoint.
checkpoint_loaded = False
if config.load_strategy == LoadStrategy.always:
assert config.load_path is not None
trainer.load_checkpoint(config.load_path)
checkpoint_loaded = True
elif config.load_strategy == LoadStrategy.if_available and config.load_path is not None:
checkpoint_loaded = trainer.maybe_load_checkpoint(config.load_path)

# Otherwise save the config to file.
if not checkpoint_loaded and get_rank() == 0:
trainer.write_file("config.json", json.dumps(config_dict, indent=2))

# Train.
Expand Down

0 comments on commit 6758b08

Please sign in to comment.