Fix TE HF checkpoint saving (#1280)
j316chuck authored Jun 18, 2024
1 parent 618db6f commit c23be4a
Showing 2 changed files with 73 additions and 27 deletions.
22 changes: 19 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from composer.core import Callback, Event, State, Time, TimeUnit
+from composer.core import Callback, Event, Precision, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
 from composer.models import HuggingFaceModel
@@ -37,6 +37,12 @@
 from llmfoundry.utils.huggingface_hub_utils import \
     edit_files_for_hf_compatibility
 
+try:
+    import transformer_engine.pytorch as te
+    is_te_imported = True
+except ModuleNotFoundError:
+    is_te_imported = False
+
 log = logging.getLogger(__name__)
 
 __all__ = ['HuggingFaceCheckpointer']
@@ -486,9 +492,19 @@ def dtensor_to_tensor_hook(
     )
 
     log.debug('Saving Hugging Face checkpoint to disk')
-    new_model_instance.save_pretrained(temp_save_dir)
+    # This context manager casts the TE extra state from io.BytesIO format to
+    # tensor format, which is needed for proper HF checkpoint saving.
+    context_manager = te.onnx_export(
+        True,
+    ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
+    )
+    with context_manager:
+        new_model_instance.save_pretrained(temp_save_dir)
     if original_tokenizer is not None:
-        assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
+        assert isinstance(
+            original_tokenizer,
+            PreTrainedTokenizerBase,
+        )
         original_tokenizer.save_pretrained(temp_save_dir)
 
     # Only need to edit files for MPT because it has custom code
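For context, here is a minimal sketch of the saving pattern this hunk introduces. It is not the callback's actual code: the helper name save_hf_checkpoint and the string-valued precision argument are illustrative assumptions. The idea, as in the diff, is that te.onnx_export(True) makes Transformer Engine modules serialize their extra state as tensors instead of io.BytesIO objects, so save_pretrained can write a valid Hugging Face checkpoint when training in FP8.

import contextlib

try:
    import transformer_engine.pytorch as te
    is_te_imported = True
except ModuleNotFoundError:
    is_te_imported = False


def save_hf_checkpoint(model, save_dir: str, precision: str) -> None:
    """Save a transformers PreTrainedModel, handling TE extra state under FP8."""
    # Under amp_fp8, wrap the save in TE's ONNX-export mode so extra state is
    # emitted as tensors rather than io.BytesIO objects; otherwise no-op.
    ctx = te.onnx_export(True) if (
        is_te_imported and precision == 'amp_fp8'
    ) else contextlib.nullcontext()
    with ctx:
        model.save_pretrained(save_dir)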
78 changes: 54 additions & 24 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+import contextlib
 import json
 import math
 import os
@@ -468,6 +469,7 @@ def _get_model_and_tokenizer(
     model: str,
     max_seq_len: int,
     tie_word_embeddings: bool,
+    precision: str,
 ):
     if model == 'mpt':
         model_cfg = {
@@ -482,6 +484,7 @@
             'attn_config': {
                 'attn_impl': 'torch',
             },
+            'fc_type': 'te' if precision == 'amp_fp8' else 'torch',
             'loss_fn': 'torch_crossentropy',
             'tie_word_embeddings': tie_word_embeddings,
         }
@@ -783,8 +786,9 @@ def _assert_checkpoint_equivalence(
 )
 @pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize(
-    'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
-    [('1ba', '1ba', '1ba', 1, 1)],
+    'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints,trainer_precision',
+    [('1ba', '1ba', '1ba', 1, 1, 'amp_bf16'),
+     ('1ba', '1ba', '1ba', 1, 1, 'amp_fp8')],
 )
 @patch('os.cpu_count', MagicMock(return_value=1))
 @patch(
@@ -801,10 +805,30 @@ def test_huggingface_conversion_callback(
     max_duration: str,
     expected_hf_checkpoints: int,
     expected_normal_checkpoints: int,
+    trainer_precision: str,
     peft_config: Optional[dict],
 ):
     if model == 'mptmoe' and fsdp_state_dict_type is None:
         pytest.skip('mptmoe requires FSDP')
+    if trainer_precision == 'amp_fp8':
+        # Check that transformer-engine is installed for FP8.
+        try:
+            import transformer_engine.pytorch as te
+        except ImportError:
+            pytest.skip(
+                'Precision amp_fp8 requires transformer-engine to be installed',
+            )
+
+        # FP8 is only exercised with MPT models here, not HF models.
+        if (model == 'neo' or model == 'llama2'):
+            pytest.skip(
+                'Precision amp_fp8 works only for mpt models, not hf models',
+            )
+
+        # Check that the GPU supports FP8 (compute capability >= 8.9, i.e. Ada/Hopper or newer).
+        if not (torch.cuda.get_device_capability() >= (8, 9)):
+            pytest.skip('Amp FP8 requires a GPU with compute capability >= 8.9')
+
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
@@ -825,9 +849,10 @@
 
     # Get small version of each model
     model_cfg, tokenizer_name = _get_model_and_tokenizer(
-        model,
-        max_seq_len,
-        tie_word_embeddings,
+        model=model,
+        max_seq_len=max_seq_len,
+        tie_word_embeddings=tie_word_embeddings,
+        precision=trainer_precision,
     )
     assert model_cfg is not None
     assert tokenizer_name is not None
@@ -883,7 +908,7 @@ def test_huggingface_conversion_callback(
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        precision='amp_bf16',
+        precision=trainer_precision,
         fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
@@ -900,24 +925,29 @@
 
     # summon full params to check equivalence
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    with FSDP.summon_full_params(
-        trainer.state.model,
-        writeback=False,
-        recurse=True,
-    ):
-        _assert_checkpoint_equivalence(
-            tmp_path=tmp_path,
-            expected_normal_checkpoints=expected_normal_checkpoints,
-            expected_hf_checkpoints=expected_hf_checkpoints,
-            trainer=trainer,
-            batches_per_epoch=batches_per_epoch,
-            original_model=original_model,
-            precision=precision,
-            model=model,
-            tokenizer=tokenizer,
-            fsdp_state_dict_type=fsdp_state_dict_type,
-            peft_config=peft_config,
-        )
+
+    context_manager = te.onnx_export(  # type: ignore
+        True,
+    ) if trainer_precision == 'amp_fp8' else contextlib.nullcontext()
+    with context_manager:
+        with FSDP.summon_full_params(
+            trainer.state.model,
+            writeback=False,
+            recurse=True,
+        ):
+            _assert_checkpoint_equivalence(
+                tmp_path=tmp_path,
+                expected_normal_checkpoints=expected_normal_checkpoints,
+                expected_hf_checkpoints=expected_hf_checkpoints,
+                trainer=trainer,
+                batches_per_epoch=batches_per_epoch,
+                original_model=original_model,
+                precision=precision,
+                model=model,
+                tokenizer=tokenizer,
+                fsdp_state_dict_type=fsdp_state_dict_type,
+                peft_config=peft_config,
+            )
 
     dist.barrier()
     delete_transformers_cache()
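Putting the two changes together, here is a hedged sketch of the workflow this test exercises: training under amp_fp8 while the HuggingFaceCheckpointer callback writes Hugging Face checkpoints. The model and dataloader are passed in rather than constructed, and any argument not visible in the diff above is an assumption rather than the library's documented API.

from composer import Trainer
from llmfoundry.callbacks import HuggingFaceCheckpointer


def train_with_fp8_hf_checkpoints(composer_model, train_dataloader, save_root: str):
    # Callback that converts and saves Hugging Face checkpoints during training.
    hf_checkpointer = HuggingFaceCheckpointer(
        save_folder=f'{save_root}/hf',
        save_interval='1ba',
    )
    trainer = Trainer(
        model=composer_model,           # e.g. an MPT model built with fc_type='te'
        train_dataloader=train_dataloader,
        precision='amp_fp8',            # requires transformer-engine and compute capability >= 8.9
        callbacks=[hf_checkpointer],
        save_folder=f'{save_root}/checkpoints',
        save_interval='1ba',
        max_duration='1ba',
        device='gpu',
    )
    trainer.fit()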
