support ddp
dakinggg committed Sep 30, 2023
1 parent d2f88b7 commit d5e0683
Showing 2 changed files with 34 additions and 29 deletions.
53 changes: 29 additions & 24 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PreTrainedModel

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
@@ -134,8 +134,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
@@ -150,9 +148,25 @@ def _save_checkpoint(self, state: State, logger: Logger):
str) # pyright doesn't know about enter_result

log.debug('Gathering state dict')
with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
from torch.distributed.fsdp import \
FullyShardedDataParallel as FSDP

if state.is_model_ddp:
original_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
original_model = state.model.model.module
original_tokenizer = state.model.tokenizer
else:
original_model = state.model.model
original_tokenizer = state.model.tokenizer

assert isinstance(original_model, PreTrainedModel)
assert isinstance(original_tokenizer, PreTrainedTokenizerBase)

state_dict_context = fsdp_state_dict_type_context(
    original_model, state_dict_type='full'
) if (not state.is_model_ddp and
      isinstance(state.model.model, FSDP)) else contextlib.nullcontext()
with state_dict_context:
state_dict = original_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
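Editor's note: the core of the hunk above is an unwrap-then-gather pattern: locate the underlying Hugging Face model and tokenizer regardless of whether Composer wrapped the model in DDP or FSDP, and only enter the FSDP full-state-dict context when FSDP is actually in use. The standalone sketch below restates that logic outside the callback; the helper name and the import path for fsdp_state_dict_type_context are illustrative assumptions, not part of the commit.

import contextlib

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# Import path assumed for the sketch; the callback already has
# fsdp_state_dict_type_context in scope from its existing imports.
from composer.core.state import fsdp_state_dict_type_context


def gather_full_state_dict(state):
    # Hypothetical helper mirroring the logic added in this hunk.
    if state.is_model_ddp:
        # DDP wraps the ComposerModel, so the HF objects sit under `.module`.
        original_model = state.model.module.model
        original_tokenizer = state.model.module.tokenizer
    elif isinstance(state.model.model, FSDP):
        # FSDP wraps the inner HF model directly.
        original_model = state.model.model.module
        original_tokenizer = state.model.tokenizer
    else:
        # Plain unwrapped model: nothing to peel off.
        original_model = state.model.model
        original_tokenizer = state.model.tokenizer

    # Only FSDP needs the 'full' state-dict gathering context; DDP and
    # unwrapped models can call state_dict() directly.
    if not state.is_model_ddp and isinstance(state.model.model, FSDP):
        context = fsdp_state_dict_type_context(original_model,
                                               state_dict_type='full')
    else:
        context = contextlib.nullcontext()

    with context:
        return original_model.state_dict(), original_tokenizer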
@@ -162,36 +176,27 @@ def _save_checkpoint(self, state: State, logger: Logger):
if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint to disk')

from torch.distributed.fsdp import \
FullyShardedDataParallel as FSDP
if isinstance(state.model.model, FSDP):
model_class = state.model.model.module
else:
model_class = state.model.model

copied_config = copy.deepcopy(state.model.model.config)
if state.model.model.config.model_type == 'mpt':
copied_config = copy.deepcopy(original_model.config)
if original_model.config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'

# TODO: after torch 2.1, we can load a state dict into a meta model
# and skip the extra model init
log.debug(f'Creating new model instance')
new_model_instance: PreTrainedModel = type(model_class)(
copied_config)
new_model_instance = type(original_model)(copied_config)
new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)
del state_dict

log.debug('Saving Hugging Face checkpoint to disk')
new_model_instance.save_pretrained(temp_save_dir)
if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

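Editor's note: the rank-0 branch above rebuilds a clean Hugging Face model from the gathered state dict rather than saving the wrapped training model directly: deep-copy the config, apply the MPT-specific adjustments so the exported checkpoint loads without GPU-only attention, instantiate a fresh model of the same class, cast it to the target dtype, and load the weights. A hedged, self-contained sketch of that pattern follows; the function name and default dtype are illustrative, not from the commit.

import copy

import torch


def rebuild_for_export(original_model, state_dict, dtype=torch.bfloat16):
    # Hypothetical helper showing the rank-0 rebuild pattern from the diff.
    copied_config = copy.deepcopy(original_model.config)
    if copied_config.model_type == 'mpt':
        # As in the diff: force the pure-torch attention implementation and
        # CPU init so the exported checkpoint is portable.
        copied_config.attn_config['attn_impl'] = 'torch'
        copied_config.init_device = 'cpu'

    new_model_instance = type(original_model)(copied_config)
    new_model_instance.to(dtype=dtype)
    new_model_instance.load_state_dict(state_dict)
    return new_model_instance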
@@ -212,8 +217,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
elapsed_duration = state.get_elapsed_duration()
if self.log_to_mlflow and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if state.model.tokenizer is not None:
components['tokenizer'] = state.model.tokenizer
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
registered_model_name = f'{state.run_name}_{os.path.basename(save_dir)}'
10 changes: 5 additions & 5 deletions tests/test_hf_conversion_script.py
@@ -19,7 +19,7 @@
sys.path.append(repo_dir)
import shutil
from argparse import Namespace
from typing import cast
from typing import cast, Optional

import pytest
import torch
@@ -202,10 +202,10 @@ def test_callback_inits_with_defaults():
@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: str,
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool):
delete_transformers_cache()

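Editor's note: the new None value for fsdp_state_dict_type is what exercises the DDP code path: with no FSDP config on a 2-GPU run, Composer wraps the model in DistributedDataParallel, which is exactly the case the checkpointer now has to unwrap. A small sketch of how the parameter might translate into the trainer's FSDP config (the dict contents are an assumption about the test body, which this hunk does not show):

def make_fsdp_config(fsdp_state_dict_type):
    # Assumed sketch; the real test's fsdp_config has more keys than this.
    if fsdp_state_dict_type is None:
        # No FSDP config: Composer falls back to DDP, the new code path.
        return None
    return {'state_dict_type': fsdp_state_dict_type}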
@@ -354,7 +354,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{save_interval_batches}ba',
@@ -427,7 +427,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trust_remote_code=True,
)

check_hf_model_equivalence(trainer.state.model.model.to(precision),
check_hf_model_equivalence(
    trainer.state.model.model.to(precision)
    if fsdp_state_dict_type is not None else
    trainer.state.model.module.model.to(precision),
loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
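Editor's note: the conditional expression passed to check_hf_model_equivalence encodes the same wrapper-awareness as the callback: when FSDP was used the HF model lives at trainer.state.model.model, while under DDP it sits one .module deeper. A small illustrative helper (not part of the commit) that makes the intent explicit, followed by an equivalent call using the test's own names:

def unwrap_trained_hf_model(state, used_fsdp: bool):
    # Hypothetical helper: peel off the DDP wrapper when FSDP was not used.
    return state.model.model if used_fsdp else state.model.module.model


# Equivalent to the assertion above:
check_hf_model_equivalence(
    unwrap_trained_hf_model(trainer.state,
                            fsdp_state_dict_type is not None).to(precision),
    loaded_model)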
