Skip to content

Commit

Permalink
refactor: extract callback sorting into a `_sort_callbacks` helper and remove debug prints
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea committed Sep 9, 2024
1 parent 6e521c7 commit ab493fc
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions llmfoundry/command_utils/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.distributed
from composer import ComposerModel, Trainer
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.core.callback import Callback
from composer.profiler import (
JSONTraceHandler,
Expand Down Expand Up @@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
log.debug('Barrier test passed with device.')


def _sort_callbacks(trainer: Trainer) -> None:
    """Sort callbacks so that checkpoint-saving callbacks run first.

    Replaces ``trainer.state.callbacks`` with a sorted copy in which
    ``CheckpointSaver`` instances come first, ``HuggingFaceCheckpointer``
    instances second, and all other callbacks last. ``sorted`` is stable,
    so the relative order within each category is preserved.

    Args:
        trainer (Trainer): Trainer whose ``state.callbacks`` list is
            reordered (the attribute is reassigned, not mutated in place).
    """

    def _sort_key(c: Callback) -> int:
        # CheckpointSaver goes first because the blocking time is shortest while upload is async.
        if isinstance(c, CheckpointSaver):
            return 0
        if isinstance(c, HuggingFaceCheckpointer):
            return 1
        return 2

    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)


def train(cfg: DictConfig) -> Trainer:
code_paths = cfg.get('code_paths', [])
# Import any user provided code
Expand Down Expand Up @@ -548,14 +567,7 @@ def train(cfg: DictConfig) -> Trainer:
spin_dataloaders=train_cfg.spin_dataloaders,
)

from composer.callbacks.checkpoint_saver import CheckpointSaver

print('before', trainer.state.callbacks)


trainer.state.callbacks = sorted(trainer.state.callbacks, key=lambda c: 0 if isinstance(c, CheckpointSaver) else 1 if isinstance(c, HuggingFaceCheckpointer) else 2)

print('after', trainer.state.callbacks)
_sort_callbacks(trainer)

# Optionally just save an HF checkpoint
if train_cfg.only_hf_checkpoint:
Expand Down

0 comments on commit ab493fc

Please sign in to comment.