Fix TE HF checkpoint saving #1280

Merged · 13 commits · Jun 18, 2024
26 changes: 21 additions & 5 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
import numpy as np
import torch
import torch.nn as nn
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core import Callback, Event, Precision, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
from composer.models import HuggingFaceModel
@@ -37,6 +37,12 @@
from llmfoundry.utils.huggingface_hub_utils import \
edit_files_for_hf_compatibility

try:
import transformer_engine.pytorch as te
is_te_imported = True
except ModuleNotFoundError:
is_te_imported = False

log = logging.getLogger(__name__)

__all__ = ['HuggingFaceCheckpointer']
@@ -486,10 +492,20 @@ def dtensor_to_tensor_hook(
)

log.debug('Saving Hugging Face checkpoint to disk')
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
original_tokenizer.save_pretrained(temp_save_dir)
# This context manager casts the TE extra state from io.BytesIO format to tensor format,
# which is needed for proper HF checkpoint saving.
context_manager = te.onnx_export(
True,
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(
original_tokenizer,
PreTrainedTokenizerBase,
)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if original_model.config.model_type == 'mpt':
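For reference, a minimal standalone sketch of the saving pattern introduced above, assuming a model object exposing save_pretrained and a plain precision string; the helper name save_hf_checkpoint and its arguments are illustrative and do not appear in the PR:

import contextlib

try:
    import transformer_engine.pytorch as te
    is_te_imported = True
except ModuleNotFoundError:
    is_te_imported = False


def save_hf_checkpoint(model, save_dir: str, precision: str) -> None:
    # Under FP8 training, TE modules serialize their extra state as io.BytesIO
    # objects; entering te.onnx_export(True) makes that state serialize as
    # tensors instead, so save_pretrained writes a loadable HF checkpoint.
    ctx = (
        te.onnx_export(True)
        if is_te_imported and precision == 'amp_fp8'
        else contextlib.nullcontext()
    )
    with ctx:
        model.save_pretrained(save_dir)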
58 changes: 34 additions & 24 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -468,6 +468,7 @@ def _get_model_and_tokenizer(
model: str,
max_seq_len: int,
tie_word_embeddings: bool,
precision: str,
):
if model == 'mpt':
model_cfg = {
@@ -482,6 +483,7 @@
'attn_config': {
'attn_impl': 'torch',
},
'fc_type': 'te' if precision == 'amp_fp8' else 'torch',
'loss_fn': 'torch_crossentropy',
'tie_word_embeddings': tie_word_embeddings,
}
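The 'fc_type': 'te' entry above is what makes the amp_fp8 test build its MPT feed-forward layers with Transformer Engine rather than torch.nn.Linear, which is where the FP8 extra state handled by the checkpointer change comes from. A hypothetical sketch of such a switch, not llm-foundry's actual implementation (the registry name and layer sizes are made up for illustration):

import torch.nn as nn

try:
    import transformer_engine.pytorch as te
    FC_CLASSES = {'torch': nn.Linear, 'te': te.Linear}
except ModuleNotFoundError:
    FC_CLASSES = {'torch': nn.Linear}


def build_fc(fc_type: str, d_in: int, d_out: int) -> nn.Module:
    # 'te' selects te.Linear, whose FP8 scaling metadata is the extra state
    # that needs special handling at checkpoint-save time.
    return FC_CLASSES[fc_type](d_in, d_out)


fc = build_fc('te' if 'te' in FC_CLASSES else 'torch', d_in=512, d_out=512)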
@@ -783,8 +785,9 @@ def _assert_checkpoint_equivalence(
)
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize(
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints',
[('1ba', '1ba', '1ba', 1, 1)],
'hf_save_interval,save_interval,max_duration,expected_hf_checkpoints,expected_normal_checkpoints,trainer_precision',
[('1ba', '1ba', '1ba', 1, 1, 'amp_bf16'),
('1ba', '1ba', '1ba', 1, 1, 'amp_fp8')],
)
@patch('os.cpu_count', MagicMock(return_value=1))
@patch(
@@ -801,10 +804,13 @@ def test_huggingface_conversion_callback(
max_duration: str,
expected_hf_checkpoints: int,
expected_normal_checkpoints: int,
trainer_precision: str,
peft_config: Optional[dict],
):
if model == 'mptmoe' and fsdp_state_dict_type is None:
pytest.skip('mptmoe requires FSDP')
if (model == 'neo' or model == 'llama2') and trainer_precision == 'amp_fp8':
pytest.skip('Precision amp_fp8 requires mpt models, not hf models')
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
@@ -825,9 +831,10 @@

# Get small version of each model
model_cfg, tokenizer_name = _get_model_and_tokenizer(
model,
max_seq_len,
tie_word_embeddings,
model=model,
max_seq_len=max_seq_len,
tie_word_embeddings=tie_word_embeddings,
precision=trainer_precision,
)
assert model_cfg is not None
assert tokenizer_name is not None
@@ -883,7 +890,7 @@ def test_huggingface_conversion_callback(
trainer = Trainer(
model=original_model,
device='gpu',
precision='amp_bf16',
precision=trainer_precision,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
@@ -899,25 +906,28 @@
_assert_mlflow_logger_calls(mlflow_logger_mock, peft_config)

# summon full params to check equivalence
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with FSDP.summon_full_params(
trainer.state.model,
writeback=False,
recurse=True,
):
_assert_checkpoint_equivalence(
tmp_path=tmp_path,
expected_normal_checkpoints=expected_normal_checkpoints,
expected_hf_checkpoints=expected_hf_checkpoints,
trainer=trainer,
batches_per_epoch=batches_per_epoch,
original_model=original_model,
precision=precision,
model=model,
tokenizer=tokenizer,
fsdp_state_dict_type=fsdp_state_dict_type,
peft_config=peft_config,
)

with te.onnx_export(True): # Ensure proper ckpting of TE modules.
with FSDP.summon_full_params(
trainer.state.model,
writeback=False,
recurse=True,
):
_assert_checkpoint_equivalence(
tmp_path=tmp_path,
expected_normal_checkpoints=expected_normal_checkpoints,
expected_hf_checkpoints=expected_hf_checkpoints,
trainer=trainer,
batches_per_epoch=batches_per_epoch,
original_model=original_model,
precision=precision,
model=model,
tokenizer=tokenizer,
fsdp_state_dict_type=fsdp_state_dict_type,
peft_config=peft_config,
)

dist.barrier()
delete_transformers_cache()
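As a usage note, a checkpoint saved through this path should load like any other Hugging Face checkpoint. A minimal sketch, where the path is a placeholder and trust_remote_code=True is only needed because MPT ships custom modeling code:

from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt_path = '/path/to/hf/checkpoint'  # placeholder: wherever the callback saved to
model = AutoModelForCausalLM.from_pretrained(ckpt_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)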