Merge branch 'main' into final-register-only
irenedea authored Sep 10, 2024
2 parents bf5cc67 + 8a8de18 · commit f0363d4
Showing 13 changed files with 53 additions and 10 deletions.
21 changes: 21 additions & 0 deletions llmfoundry/command_utils/train.py
@@ -10,6 +10,7 @@
 import torch
 import torch.distributed
 from composer import ComposerModel, Trainer
+from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.core.callback import Callback
 from composer.profiler import (
     JSONTraceHandler,
@@ -187,6 +188,24 @@ def _initialize_dist_with_barrier(dist_timeout: Union[int, float]):
     log.debug('Barrier test passed with device.')


+def _sort_callbacks(trainer: Trainer):
+    """Sort callbacks so that checkpoint saving callbacks go first.
+
+    Args:
+        trainer (Trainer): Trainer object
+    """
+
+    def _sort_key(c: Callback) -> int:
+        # CheckpointSaver goes before HuggingFaceCheckpointer because the blocking time is shortest while upload is async.
+        if isinstance(c, CheckpointSaver):
+            return 1
+        if isinstance(c, HuggingFaceCheckpointer):
+            return 2
+        return 0
+
+    trainer.state.callbacks = sorted(trainer.state.callbacks, key=_sort_key)
+
+
 def train(cfg: DictConfig) -> Trainer:
     code_paths = cfg.get('code_paths', [])
     # Import any user provided code
@@ -548,6 +567,8 @@ def train(cfg: DictConfig) -> Trainer:
         spin_dataloaders=train_cfg.spin_dataloaders,
     )

+    _sort_callbacks(trainer)
+
     # Optionally just save an HF checkpoint
     if train_cfg.only_hf_checkpoint:
         hf_checkpointer_callbacks = [
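As a point of reference for the new helper, here is a minimal standalone sketch of the ordering _sort_callbacks produces (not part of the commit; the classes below are stand-ins for the real composer/llmfoundry callbacks). It relies only on Python's sorted() being stable: callbacks that are neither checkpointer stay first in their original order, followed by CheckpointSaver, then HuggingFaceCheckpointer.

# Standalone sketch; stand-in classes, not the real composer/llmfoundry types.
class Callback:
    pass

class CheckpointSaver(Callback):
    pass

class HuggingFaceCheckpointer(Callback):
    pass

class RunTimeoutCallback(Callback):
    pass

def _sort_key(c: Callback) -> int:
    # Same key as in the diff: CheckpointSaver before HuggingFaceCheckpointer,
    # everything else keeps its place at the front.
    if isinstance(c, CheckpointSaver):
        return 1
    if isinstance(c, HuggingFaceCheckpointer):
        return 2
    return 0

callbacks = [HuggingFaceCheckpointer(), CheckpointSaver(), RunTimeoutCallback()]
ordered = sorted(callbacks, key=_sort_key)  # sorted() is stable, so ties preserve input order
print([type(c).__name__ for c in ordered])
# -> ['RunTimeoutCallback', 'CheckpointSaver', 'HuggingFaceCheckpointer']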
4 changes: 4 additions & 0 deletions llmfoundry/models/layers/blocks.py
@@ -170,7 +170,9 @@ def forward(
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        if key_value_states is not None:
+            extra_kwargs['key_value_states'] = key_value_states

         if self.fuse_norm_attn_norm:
             x, m, attn_weights, past_key_value = self.norm_attn_norm(
                 x,
@@ -336,7 +338,9 @@ def forward(
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
+        if key_value_states is not None:
+            extra_kwargs['key_value_states'] = key_value_states

         b, attn_weights, past_key_value = self.attn(
             a,
             past_key_value=past_key_value,
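The blocks.py hunks extend the conditional-kwargs pattern already used for prev_layer_key_value: an optional tensor is added to extra_kwargs only when it is provided, so attention modules whose forward signatures do not accept key_value_states never receive an unexpected keyword. A minimal sketch of that pattern (the function and argument names here are illustrative, not taken from llm-foundry):

from typing import Any, Callable, Optional

def call_attn(
    attn_fn: Callable[..., Any],
    x: Any,
    prev_layer_key_value: Optional[Any] = None,
    key_value_states: Optional[Any] = None,
) -> Any:
    extra_kwargs: dict[str, Any] = {}
    # Forward optional inputs only when present, so implementations that lack
    # these parameters are not handed unexpected keyword arguments.
    if prev_layer_key_value is not None:
        extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
    if key_value_states is not None:
        extra_kwargs['key_value_states'] = key_value_states
    return attn_fn(x, **extra_kwargs)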
2 changes: 1 addition & 1 deletion mcli/mcli-1b-eval.yaml
@@ -9,7 +9,7 @@ integrations:
command: |
cd llm-foundry/scripts/
composer eval/eval.py /mnt/config/parameters.yaml
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: mpt-1b-eval

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-1b-max-seq-len-8k.yaml
@@ -17,7 +17,7 @@ command: |
--out_root ./my-copy-c4 --splits train_small val_small \
--concat_tokens 8192 --tokenizer EleutherAI/gpt-neox-20b --eos_text '<|endoftext|>'
composer train/train.py /mnt/config/parameters.yaml
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: mpt-1b-ctx-8k-gpus-8

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-1b.yaml
@@ -21,7 +21,7 @@ command: |
eval_loader.dataset.split=val_small \
max_duration=100ba \
eval_interval=0
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: mpt-1b-gpus-8

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-benchmark-mpt.yaml
@@ -6,7 +6,7 @@ compute:
# cluster: TODO # Name of the cluster to use for this run
# gpu_type: a100_80gb # Type of GPU to use. We use a100_80gb in our experiments

-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest

integrations:
- integration_type: git_repo
2 changes: 1 addition & 1 deletion mcli/mcli-convert-composer-to-hf.yaml
@@ -13,7 +13,7 @@ command: |
--hf_output_path s3://bucket/folder/hf/ \
--output_precision bf16 \
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: convert-composer-hf

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-hf-eval.yaml
@@ -16,7 +16,7 @@ gpu_num: 8
# gpu_type:
# cluster: # replace with your cluster here!

-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest

# The below is injected as a YAML file: /mnt/config/parameters.yaml
parameters:
2 changes: 1 addition & 1 deletion mcli/mcli-hf-generate.yaml
@@ -35,7 +35,7 @@ command: |
"Here's a quick recipe for baking chocolate chip cookies: Start by" \
"The best 5 cities to visit in Europe are"
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: hf-generate

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-llama2-finetune.yaml
@@ -9,7 +9,7 @@ integrations:
command: |
cd llm-foundry/scripts
composer train/train.py /mnt/config/parameters.yaml
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
name: llama2-finetune

compute:
2 changes: 1 addition & 1 deletion mcli/mcli-openai-eval.yaml
@@ -16,7 +16,7 @@ gpu_num: #
gpu_type: #
cluster: # replace with your cluster here!

-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest

# The below is injected as a YAML file: /mnt/config/parameters.yaml
parameters:
2 changes: 1 addition & 1 deletion mcli/mcli-pretokenize-oci-upload.yaml
@@ -1,5 +1,5 @@
name: c4-2k-pre-tokenized
-image: mosaicml/llm-foundry:2.4.0_cu124-latest
+image: mosaicml/llm-foundry:2.3.1_cu121-latest
compute:
gpus: 8 # Number of GPUs to use

18 changes: 18 additions & 0 deletions tests/a_scripts/train/test_train.py
@@ -5,14 +5,18 @@
 import os
 import pathlib
 from typing import Optional
+from unittest.mock import Mock

 import pytest
+from composer.callbacks import CheckpointSaver
 from composer.loggers import InMemoryLogger
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om

+from llmfoundry.callbacks import HuggingFaceCheckpointer, RunTimeoutCallback
 from llmfoundry.command_utils import TrainConfig  # noqa: E402
 from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config
+from llmfoundry.command_utils.train import _sort_callbacks
 from llmfoundry.utils.config_utils import (
     make_dataclass_and_log_config,
     update_batch_size_info,
@@ -110,6 +114,20 @@ def test_train_gauntlet(averages: Optional[dict], tmp_path: pathlib.Path):
         -1][-1] == 0


+def test_sort_callbacks():
+    trainer_mock = Mock()
+    trainer_mock.state.callbacks = [
+        CheckpointSaver(),
+        HuggingFaceCheckpointer('save-folder', '1ba'),
+        RunTimeoutCallback(),
+    ]
+    _sort_callbacks(trainer_mock)
+
+    assert isinstance(trainer_mock.state.callbacks[0], RunTimeoutCallback)
+    assert isinstance(trainer_mock.state.callbacks[1], CheckpointSaver)
+    assert isinstance(trainer_mock.state.callbacks[2], HuggingFaceCheckpointer)
+
+
 def test_train_multi_eval(tmp_path: pathlib.Path):
     """Test training run with multiple eval datasets."""
     c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
