From 8a8de18d31156f1acfbeb3b6853e17aa90573eb9 Mon Sep 17 00:00:00 2001
From: Irene Dea
Date: Tue, 10 Sep 2024 10:07:36 -0700
Subject: [PATCH] Sort callbacks so that CheckpointSaver goes before
 HuggingFaceCheckpointer (#1515)

---
 llmfoundry/command_utils/train.py   | 21 +++++++++++++++++++++
 tests/a_scripts/train/test_train.py | 18 ++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index 8e6309175a..73fa4c8d5a 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed
 from composer import ComposerModel, Trainer
+from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.core.callback import Callback
 from composer.profiler import (
     JSONTraceHandler,
@@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
     log.debug('Barrier test passed with device.')
 
 
+def _sort_callbacks(trainer: Trainer):
+    """Sort callbacks so that CheckpointSaver runs before HuggingFaceCheckpointer.
+
+    Args:
+        trainer (Trainer): Trainer object
+    """
+
+    def _sort_key(c: Callback) -> int:
+        # CheckpointSaver goes before HuggingFaceCheckpointer because its blocking time is shortest: its upload is async.
+        if isinstance(c, CheckpointSaver):
+            return 1
+        if isinstance(c, HuggingFaceCheckpointer):
+            return 2
+        return 0
+
+    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)
+
+
 def train(cfg: DictConfig) -> Trainer:
     code_paths = cfg.get('code_paths', [])
     # Import any user provided code
@@ -548,6 +567,8 @@ def train(cfg: DictConfig) -> Trainer:
         spin_dataloaders=train_cfg.spin_dataloaders,
     )
 
+    _sort_callbacks(trainer)
+
     # Optionally just save an HF checkpoint
     if train_cfg.only_hf_checkpoint:
         hf_checkpointer_callbacks = [
diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py
index 1f724a6070..9af96f9868 100644
--- a/tests/a_scripts/train/test_train.py
+++ b/tests/a_scripts/train/test_train.py
@@ -5,14 +5,18 @@
 import os
 import pathlib
 from typing import Optional
+from unittest.mock import Mock
 
 import pytest
+from composer.callbacks import CheckpointSaver
 from composer.loggers import InMemoryLogger
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
 from llmfoundry.command_utils import TrainConfig  # noqa: E402
 from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
+from llmfoundry.command_utils.train import _sort_callbacks
 from llmfoundry.utils.config_utils import (
     make_dataclass_and_log_config,
     update_batch_size_info,
@@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
         -1][-1] == 0
 
 
+def test_sort_callbacks():
+    trainer_mock = Mock()
+    trainer_mock.state.callbacks = [
+        CheckpointSaver(),
+        HuggingFaceCheckpointer('save-folder', '1ba'),
+        RunTimeoutCallback(),
+    ]
+    _sort_callbacks(trainer_mock)
+
+    assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
+    assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
+    assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)
+
+
 def test_train_multi_eval(tmp_path: pathlib.Path):
     """Test training run with multiple eval datasets."""
     c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
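
Note (illustration, not part of the patch): the ordering `_sort_key` produces relies only on key-based bucketing plus the stability of Python's `sorted`. The minimal standalone sketch below shows that behavior; `Saver`, `HFWriter`, and `Other` are hypothetical stand-ins for `CheckpointSaver`, `HuggingFaceCheckpointer`, and any other callback.

    # Hypothetical stand-in classes; only the isinstance checks matter here.
    class Saver: pass
    class HFWriter: pass
    class Other: pass

    def sort_key(c) -> int:
        # Same bucketing as _sort_key: everything else (0), then Saver (1), then HFWriter (2).
        if isinstance(c, Saver):
            return 1
        if isinstance(c, HFWriter):
            return 2
        return 0

    callbacks = [HFWriter(), Saver(), Other()]
    ordered = sorted(callbacks, key=sort_key)  # sorted() is stable: ties keep registration order
    print([type(c).__name__ for c in ordered])  # ['Other', 'Saver', 'HFWriter']

Because `sorted` is stable, callbacks sharing a bucket keep the relative order in which they were registered; only the two checkpointing callbacks move to the end, with CheckpointSaver (short blocking time, async upload) ahead of the HuggingFace checkpointer.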