diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index f9d8470292..f375fb432b 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, PreTrainedModel
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -134,8 +134,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
             MPTConfig.register_for_auto_class()
             MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')
 
-        assert isinstance(state.model, HuggingFaceModel)
-
         save_dir = format_name_with_dist_and_time(
             str(
                 Path(self.save_dir_format_str) /
@@ -150,9 +148,33 @@ def _save_checkpoint(self, state: State, logger: Logger):
                               str)  # pyright doesn't know about enter_result
 
             log.debug('Gathering state dict')
-            with fsdp_state_dict_type_context(state.model.model,
-                                              state_dict_type='full'):
-                state_dict = state.model.model.state_dict()
+            import contextlib
+            from torch.distributed.fsdp import \
+                FullyShardedDataParallel as FSDP
+
+            # unwrap the DDP/FSDP wrapper to get the raw Hugging Face model
+            if state.is_model_ddp:
+                original_model = state.model.module.model
+                original_tokenizer = state.model.module.tokenizer
+            elif isinstance(state.model.model, FSDP):
+                original_model = state.model.model.module
+                original_tokenizer = state.model.tokenizer
+            else:
+                original_model = state.model.model
+                original_tokenizer = state.model.tokenizer
+
+            assert isinstance(original_model, PreTrainedModel)
+            assert original_tokenizer is None or isinstance(
+                original_tokenizer, PreTrainedTokenizerBase)
+
+            # gather a full state dict only when the model is FSDP-wrapped
+            state_dict_context = fsdp_state_dict_type_context(
+                original_model, state_dict_type='full') if (
+                    (not state.is_model_ddp) and
+                    isinstance(state.model.model, FSDP)
+                ) else contextlib.nullcontext()
+            with state_dict_context:
+                state_dict = original_model.state_dict()
 
             # convert the state dict to the requested precision
             for k, v in state_dict.items():
@@ -162,35 +184,28 @@
             if dist.get_global_rank() == 0:
                 log.debug('Saving Hugging Face checkpoint to disk')
 
-                from torch.distributed.fsdp import \
-                    FullyShardedDataParallel as FSDP
-                if isinstance(state.model.model, FSDP):
-                    model_class = state.model.model.module
-                else:
-                    model_class = state.model.model
-
-                copied_config = copy.deepcopy(state.model.model.config)
-                if state.model.model.config.model_type == 'mpt':
+                copied_config = copy.deepcopy(original_model.config)
+                if original_model.config.model_type == 'mpt':
                     copied_config.attn_config['attn_impl'] = 'torch'
                     copied_config.init_device = 'cpu'
 
                 # TODO: after torch 2.1, we can load a state dict into a meta model
                 # and skip the extra model init
                 log.debug(f'Creating new model instance')
-                new_model_instance: PreTrainedModel = type(model_class)(
-                    copied_config)
+                new_model_instance = type(original_model)(copied_config)
                 new_model_instance.to(dtype=self.dtype)
                 new_model_instance.load_state_dict(state_dict)
                 del state_dict
 
                 log.debug('Saving Hugging Face checkpoint to disk')
                 new_model_instance.save_pretrained(temp_save_dir)
-                if state.model.tokenizer is not None:
-                    assert isinstance(state.model.tokenizer,
-                                      PreTrainedTokenizerBase)
-                    state.model.tokenizer.save_pretrained(temp_save_dir)
+                if original_tokenizer is not None:
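+                    # save the tokenizer so the export loads standalone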
+                    assert isinstance(original_tokenizer,
+                                      PreTrainedTokenizerBase)
+                    original_tokenizer.save_pretrained(temp_save_dir)
 
                 # Only need to edit files for MPT because it has custom code
-                if state.model.model.config.model_type == 'mpt':
+                if original_model.config.model_type == 'mpt':
                     log.debug('Editing MPT files for HuggingFace compatibility')
                     edit_files_for_hf_compatibility(temp_save_dir)
@@ -212,8 +227,8 @@
             elapsed_duration = state.get_elapsed_duration()
             if self.log_to_mlflow and elapsed_duration is not None and elapsed_duration >= 1.0:
                 components = {'model': new_model_instance}
-                if state.model.tokenizer is not None:
-                    components['tokenizer'] = state.model.tokenizer
+                if original_tokenizer is not None:
+                    components['tokenizer'] = original_tokenizer
 
                 log.debug('Logging Hugging Face model to MLFlow')
                 registered_model_name = f'{state.run_name}_{os.path.basename(save_dir)}'
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 62922b8a36..7d31121db4 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -19,7 +19,7 @@
 sys.path.append(repo_dir)
 import shutil
 from argparse import Namespace
-from typing import cast
+from typing import Optional, cast
 
 import pytest
 import torch
@@ -202,10 +202,10 @@ def test_callback_inits_with_defaults():
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
 @pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
-@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded'])
+@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
 @pytest.mark.parametrize('log_to_mlflow', [True, False])
 def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
-                                         fsdp_state_dict_type: str,
+                                         fsdp_state_dict_type: Optional[str],
                                          log_to_mlflow: bool):
     delete_transformers_cache()
 
@@ -354,7 +354,7 @@
     trainer = Trainer(
         model=original_model,
         device='gpu',
-        fsdp_config=fsdp_config,
+        fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
         train_dataloader=train_dataloader,
         save_folder=os.path.join(tmp_path, 'checkpoints'),
         save_interval=f'{save_interval_batches}ba',
@@ -427,7 +427,10 @@
         trust_remote_code=True,
     )
 
-    check_hf_model_equivalence(trainer.state.model.model.to(precision),
+    # under DDP (fsdp_state_dict_type=None) the HF model sits one level deeper
+    trained_model = (trainer.state.model.model
+                     if fsdp_state_dict_type is not None else
+                     trainer.state.model.module.model)
+    check_hf_model_equivalence(trained_model.to(precision),
                                loaded_model)
     check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
 