From 898812d9edff8e8754e26f5761f6bfcf091e1df5 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Mon, 9 Sep 2024 20:56:06 +0000
Subject: [PATCH] other callbacks go first and add tests

---
 llmfoundry/command_utils/train.py   |  6 +++---
 tests/a_scripts/train/test_train.py | 18 ++++++++++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 9119629d94..3897fbffd3 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -198,10 +198,10 @@ def _sort_callbacks(trainer: Trainer):
 
     def _sort_key(c: Callback) -> int:
         # CheckpointSaver goes first because the blocking time is shortest while upload is async.
         if isinstance(c, CheckpointSaver):
-            return 0
-        if isinstance(c, HuggingFaceCheckpointer):
             return 1
-        return 2
+        if isinstance(c, HuggingFaceCheckpointer):
+            return 2
+        return 0
 
     trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)
diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py
index 1f724a6070..9af96f9868 100644
--- a/tests/a_scripts/train/test_train.py
+++ b/tests/a_scripts/train/test_train.py
@@ -5,14 +5,18 @@
 import os
 import pathlib
 from typing import Optional
+from unittest.mock import Mock
 
 import pytest
+from composer.callbacks import CheckpointSaver
 from composer.loggers import InMemoryLogger
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
 from llmfoundry.command_utils import TrainConfig  # noqa: E402
 from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
+from llmfoundry.command_utils.train import _sort_callbacks
 from llmfoundry.utils.config_utils import (
     make_dataclass_and_log_config,
     update_batch_size_info,
@@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
             -1][-1] == 0
 
 
+def test_sort_callbacks():
+    trainer_mock = Mock()
+    trainer_mock.state.callbacks = [
+        CheckpointSaver(),
+        HuggingFaceCheckpointer('save-folder', '1ba'),
+        RunTimeoutCallback(),
+    ]
+    _sort_callbacks(trainer_mock)
+
+    assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
+    assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
+    assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)
+
+
 def test_train_multi_eval(tmp_path: pathlib.Path):
     """Test training run with multiple eval datasets."""
     c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)