Commit 8a8de18
Sort callbacks so that CheckpointSaver goes before HuggingFaceCheckpointer (#1515)
irenedea authored Sep 10, 2024
1 parent e8eca4f commit 8a8de18
Showing 2 changed files with 39 additions and 0 deletions.
21 changes: 21 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -10,6 +10,7 @@
import torch
import torch.distributed
from composer import ComposerModel, Trainer
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.core.callback import Callback
from composer.profiler import (
    JSONTraceHandler,
@@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
    log.debug('Barrier test passed with device.')


def _sort_callbacks(trainer: Trainer):
    """Sort callbacks so that checkpoint saving callbacks go first.

    Args:
        trainer (Trainer): Trainer object
    """

    def _sort_key(c: Callback) -> int:
        # CheckpointSaver goes before HuggingFaceCheckpointer because its
        # blocking time is shortest; the HF upload itself runs asynchronously.
        if isinstance(c, CheckpointSaver):
            return 1
        if isinstance(c, HuggingFaceCheckpointer):
            return 2
        return 0

    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)


def train(cfg: DictConfig) -> Trainer:
    code_paths = cfg.get('code_paths', [])
    # Import any user provided code
@@ -548,6 +567,8 @@ def train(cfg: DictConfig) -> Trainer:
        spin_dataloaders=train_cfg.spin_dataloaders,
    )

    _sort_callbacks(trainer)

    # Optionally just save an HF checkpoint
    if train_cfg.only_hf_checkpoint:
        hf_checkpointer_callbacks = [
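The ordering relies on Python's sorted() being stable: callbacks matching neither class map to key 0 and keep their original relative order at the front, CheckpointSaver maps to 1, and HuggingFaceCheckpointer maps to 2 and therefore runs last. Below is a minimal, self-contained sketch of the same stable-sort-key pattern; the stub classes are hypothetical stand-ins, not the real Composer or llm-foundry callbacks.

# Illustrative sketch only: OtherStub, SaverStub, and HFStub are hypothetical
# stand-ins for the real callback classes.
class OtherStub:
    pass

class SaverStub:
    pass

class HFStub:
    pass

def sort_key(c: object) -> int:
    # Lower keys run earlier; callbacks with equal keys keep their original
    # relative order because sorted() is stable.
    if isinstance(c, SaverStub):
        return 1
    if isinstance(c, HFStub):
        return 2
    return 0

callbacks = [SaverStub(), HFStub(), OtherStub()]
ordered = sorted(callbacks, key=sort_key)
assert [type(c).__name__ for c in ordered] == ['OtherStub', 'SaverStub', 'HFStub']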
18 changes: 18 additions & 0 deletions tests/a_scripts/train/test_train.py
@@ -5,14 +5,18 @@
import os
import pathlib
from typing import Optional
from unittest.mock import Mock

import pytest
from composer.callbacks import CheckpointSaver
from composer.loggers import InMemoryLogger
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
from llmfoundry.command_utils import TrainConfig # noqa: E402
from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
from llmfoundry.command_utils.train import _sort_callbacks
from llmfoundry.utils.config_utils import (
    make_dataclass_and_log_config,
    update_batch_size_info,
@@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
-1][-1] == 0


def test_sort_callbacks():
    trainer_mock = Mock()
    trainer_mock.state.callbacks = [
        CheckpointSaver(),
        HuggingFaceCheckpointer('save-folder', '1ba'),
        RunTimeoutCallback(),
    ]
    _sort_callbacks(trainer_mock)

    assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
    assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
    assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)


def test_train_multi_eval(tmp_path: pathlib.Path):
    """Test training run with multiple eval datasets."""
    c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
