Skip to content

Commit

Permalink
other callbacks go first and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Sep 9, 2024
1 parent ab493fc commit 898812d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ def _sort_callbacks(trainer: Trainer):
def _sort_key(c: Callback) -> int:
# CheckpointSaver goes first because the blocking time is shortest while upload is async.
if isinstance(c, CheckpointSaver):
return 0
if isinstance(c, HuggingFaceCheckpointer):
return 1
return 2
if isinstance(c, HuggingFaceCheckpointer):
return 2
return 0

trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)

Expand Down
18 changes: 18 additions & 0 deletions tests/a_scripts/train/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
import os
import pathlib
from typing import Optional
from unittest.mock import Mock

import pytest
from composer.callbacks import CheckpointSaver
from composer.loggers import InMemoryLogger
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om

from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
from llmfoundry.command_utils import TrainConfig # noqa: E402
from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
from llmfoundry.command_utils.train import _sort_callbacks
from llmfoundry.utils.config_utils import (
make_dataclass_and_log_config,
update_batch_size_info,
Expand Down Expand Up @@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
-1][-1] == 0


def test_sort_callbacks():
trainer_mock = Mock()
trainer_mock.state.callbacks = [
CheckpointSaver(),
HuggingFaceCheckpointer('save-folder', '1ba'),
RunTimeoutCallback(),
]
_sort_callbacks(trainer_mock)

assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)


def test_train_multi_eval(tmp_path: pathlib.Path):
"""Test training run with multiple eval datasets."""
c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
Expand Down

0 comments on commit 898812d

Please sign in to comment.