diff --git a/llmfoundry/command_utils/train.py b/llmfoundry/command_utils/train.py
index b4361c6307..9119629d94 100644
--- a/llmfoundry/command_utils/train.py
+++ b/llmfoundry/command_utils/train.py
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed
 from composer import ComposerModel, Trainer
+from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.core.callback import Callback
 from composer.profiler import (
     JSONTraceHandler,
@@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
     log.debug('Barrier test passed with device.')
 
 
+def _sort_callbacks(trainer: Trainer):
+    """Sort callbacks so that checkpoint saving callbacks go first.
+
+    Args:
+        trainer (Trainer): Trainer object.
+    """
+
+    def _sort_key(c: Callback) -> int:
+        # CheckpointSaver goes first: its blocking time is shortest since uploads are async.
+        if isinstance(c, CheckpointSaver):
+            return 0
+        if isinstance(c, HuggingFaceCheckpointer):
+            return 1
+        return 2
+
+    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)
+
+
 def train(cfg: DictConfig) -> Trainer:
     code_paths = cfg.get('code_paths', [])
     # Import any user provided code
@@ -548,14 +567,7 @@ def train(cfg: DictConfig) -> Trainer:
         spin_dataloaders=train_cfg.spin_dataloaders,
     )
 
-    from composer.callbacks.checkpoint_saver import CheckpointSaver
-
-    print('before', trainer.state.callbacks)
-
-
-    trainer.state.callbacks = sorted(trainer.state.callbacks, key=lambda c: 0 if isinstance(c, CheckpointSaver) else 1 if isinstance(c, HuggingFaceCheckpointer) else 2)
-
-    print('after', trainer.state.callbacks)
+    _sort_callbacks(trainer)
 
     # Optionally just save an HF checkpoint
     if train_cfg.only_hf_checkpoint:
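
Note (not part of the patch): the sketch below illustrates the ordering that _sort_callbacks produces. The classes are stand-ins for composer's Callback/CheckpointSaver and llmfoundry's HuggingFaceCheckpointer so the snippet runs without dependencies, and LRMonitor is a hypothetical third callback added only for illustration.

    # Standalone sketch of the sort-key pattern used in _sort_callbacks.
    class Callback: ...
    class CheckpointSaver(Callback): ...          # stand-in for composer's CheckpointSaver
    class HuggingFaceCheckpointer(Callback): ...  # stand-in for llmfoundry's HuggingFaceCheckpointer
    class LRMonitor(Callback): ...                # hypothetical third callback

    def _sort_key(c: Callback) -> int:
        # Lower keys run first: CheckpointSaver, then HuggingFaceCheckpointer,
        # then everything else.
        if isinstance(c, CheckpointSaver):
            return 0
        if isinstance(c, HuggingFaceCheckpointer):
            return 1
        return 2

    callbacks = [LRMonitor(), HuggingFaceCheckpointer(), CheckpointSaver()]
    ordered = sorted(callbacks, key=_sort_key)
    print([type(c).__name__ for c in ordered])
    # -> ['CheckpointSaver', 'HuggingFaceCheckpointer', 'LRMonitor']

Because sorted() is stable, callbacks that share a key keep their original relative order, so the change only moves the two checkpointing callbacks to the front without reshuffling anything else.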