Sort callbacks so that CheckpointSaver goes before HuggingFaceCheckpointer #1515

Merged (5 commits) on Sep 10, 2024
21 changes: 21 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -10,6 +10,7 @@
import torch
import torch.distributed
from composer import ComposerModel, Trainer
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.core.callback import Callback
from composer.profiler import (
    JSONTraceHandler,
@@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
    log.debug('Barrier test passed with device.')


def _sort_callbacks(trainer: Trainer):
    """Sort callbacks so that CheckpointSaver goes before HuggingFaceCheckpointer.

    All other callbacks keep their relative order and are placed ahead of both.

    Args:
        trainer (Trainer): Trainer object
    """

    def _sort_key(c: Callback) -> int:
        # CheckpointSaver goes before HuggingFaceCheckpointer: its save blocks for
        # the shortest time because its remote upload happens asynchronously.
        if isinstance(c, CheckpointSaver):
            return 1
        if isinstance(c, HuggingFaceCheckpointer):
            return 2
        return 0

    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)


def train(cfg: DictConfig) -> Trainer:
    code_paths = cfg.get('code_paths', [])
    # Import any user provided code
@@ -548,6 +567,8 @@ def train(cfg: DictConfig) -> Trainer:
        spin_dataloaders=train_cfg.spin_dataloaders,
    )

    _sort_callbacks(trainer)

    # Optionally just save an HF checkpoint
    if train_cfg.only_hf_checkpoint:
        hf_checkpointer_callbacks = [
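For reviewers skimming the diff, here is a minimal, dependency-free sketch of the ordering behavior. The callback classes below are hypothetical stand-ins for the real composer/llmfoundry ones; only the key function mirrors the code added above. Because Python's sorted() is stable, every callback that maps to key 0 keeps its original relative order, followed by CheckpointSaver and then HuggingFaceCheckpointer.

class Callback:  # stand-in for composer.core.callback.Callback
    pass

class CheckpointSaver(Callback):  # stand-in for composer's CheckpointSaver
    pass

class HuggingFaceCheckpointer(Callback):  # stand-in for llmfoundry's HuggingFaceCheckpointer
    pass

class RunTimeoutCallback(Callback):  # stand-in for any other callback
    pass

def _sort_key(c: Callback) -> int:
    # Same ordering as the PR: other callbacks (0), then CheckpointSaver (1),
    # then HuggingFaceCheckpointer (2).
    if isinstance(c, CheckpointSaver):
        return 1
    if isinstance(c, HuggingFaceCheckpointer):
        return 2
    return 0

callbacks = [HuggingFaceCheckpointer(), CheckpointSaver(), RunTimeoutCallback()]
print([type(c).__name__ for c in sorted(callbacks, key=_sort_key)])
# ['RunTimeoutCallback', 'CheckpointSaver', 'HuggingFaceCheckpointer']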
18 changes: 18 additions & 0 deletions tests/a_scripts/train/test_train.py
@@ -5,14 +5,18 @@
import os
import pathlib
from typing import Optional
from unittest.mock import Mock

import pytest
from composer.callbacks import CheckpointSaver
from composer.loggers import InMemoryLogger
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
from llmfoundry.command_utils import TrainConfig # noqa: E402
from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
from llmfoundry.command_utils.train import _sort_callbacks
from llmfoundry.utils.config_utils import (
    make_dataclass_and_log_config,
    update_batch_size_info,
@@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
-1][-1] == 0


def test_sort_callbacks():
    trainer_mock = Mock()
    trainer_mock.state.callbacks = [
        CheckpointSaver(),
        HuggingFaceCheckpointer('save-folder', '1ba'),
        RunTimeoutCallback(),
    ]
    _sort_callbacks(trainer_mock)

    assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
    assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
    assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)


def test_train_multi_eval(tmp_path: pathlib.Path):
"""Test training run with multiple eval datasets."""
c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
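If you want to double-check the ordering on a real run, the callback list can be inspected right after train() constructs the Trainer. This is a quick, hypothetical sanity check, not part of this PR; it assumes both checkpointing callbacks are configured and the Trainer is bound to a variable named trainer.

# Not part of the PR: inspect the sorted callback order on a live Trainer.
names = [type(cb).__name__ for cb in trainer.state.callbacks]
# CheckpointSaver should now come before HuggingFaceCheckpointer, with all
# other callbacks ahead of both.
assert names.index('CheckpointSaver') < names.index('HuggingFaceCheckpointer')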