From 7a6ae1d18d41d07ea43e8530a529df3e5da1170e Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 29 Sep 2023 22:57:09 +0000
Subject: [PATCH] fixes

---
 llmfoundry/callbacks/hf_checkpointer.py |  4 ++--
 tests/test_hf_conversion_script.py      | 16 +++++++++++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 8e84374828..e1aecfb3ca 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -183,10 +183,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
         del state_dict
 
         log.debug('Saving Hugging Face checkpoint to disk')
-        new_model_instance.save_pretrained('temp_save_dir')
+        new_model_instance.save_pretrained(temp_save_dir)
         if state.model.tokenizer is not None:
             assert isinstance(state.model.tokenizer, PreTrainedTokenizerBase)
-            state.model.tokenizer.save_pretrained('temp_save_dir')
+            state.model.tokenizer.save_pretrained(temp_save_dir)
 
         # Only need to edit files for MPT because it has custom code
         if state.model.model.config.model_type == 'mpt':
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index 53cf447050..30412b8844 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -12,7 +12,7 @@
 from composer.utils import dist, get_device
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
-from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM, MPTConfig, MPTForCausalLM
 
 # Add repo root to path so we can import scripts and test it
 repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
@@ -150,6 +150,20 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
     # so we remove it
     expected_model_config_dict.pop('_name_or_path')
     new_model_config_dict.pop('_name_or_path')
+
+    # Special case a couple of differences that correctly occur when saving
+    # MPT to a huggingface format checkpoint
+    architectures_1 = expected_model_config_dict.pop('architectures', None)
+    architectures_2 = new_model_config_dict.pop('architectures', None)
+    if architectures_1 != architectures_2:
+        assert architectures_1 is None and architectures_2 == ['MPTForCausalLM']
+
+    auto_map_1 = expected_model_config_dict.pop('auto_map', None)
+    auto_map_2 = new_model_config_dict.pop('auto_map', None)
+    if auto_map_1 != auto_map_2:
+        assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'}
+        assert auto_map_2 == {'AutoConfig': 'configuration_mpt.MPTConfig', 'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM'}
+
     assert expected_model_config_dict == new_model_config_dict
     assert all(
         torch.equal(p1.cpu(), p2.cpu())
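
Note (not part of the patch): a minimal, self-contained sketch of the bug the
first hunk fixes. Quoting 'temp_save_dir' made it a string literal, so the
checkpoint was written to a directory literally named temp_save_dir in the
current working directory rather than to the path held by the variable. The
save_pretrained helper below is a hypothetical stand-in for Hugging Face's
PreTrainedModel.save_pretrained, and the use of tempfile.TemporaryDirectory
is an assumption about how temp_save_dir is created in the callback:

    import os
    import tempfile

    # Hypothetical stand-in for PreTrainedModel.save_pretrained: writes the
    # model files into save_dir, creating it if needed.
    def save_pretrained(save_dir: str) -> None:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, 'config.json'), 'w') as f:
            f.write('{}')

    with tempfile.TemporaryDirectory() as temp_save_dir:
        # Bug: quoting the name saves to a literal './temp_save_dir' in the
        # current working directory, not to the managed temporary directory:
        #     save_pretrained('temp_save_dir')
        # Fix: pass the variable, so the files land inside the temporary
        # directory and are removed when the context manager exits.
        save_pretrained(temp_save_dir)
        assert os.path.isfile(os.path.join(temp_save_dir, 'config.json'))

The second hunk's special-casing in the test follows from the patch itself: an
MPT model saved through the checkpointer gains an 'architectures' entry and an
'AutoModelForCausalLM' entry in 'auto_map' in its config.json that the original
in-memory config does not carry. Popping both keys (and asserting they differ
only in the expected way) lets the remaining config dicts be compared exactly.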