Commit 7a6ae1d
fixes
dakinggg committed Sep 29, 2023
1 parent bcfb534 commit 7a6ae1d
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -183,10 +183,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
             del state_dict

             log.debug('Saving Hugging Face checkpoint to disk')
-            new_model_instance.save_pretrained('temp_save_dir')
+            new_model_instance.save_pretrained(temp_save_dir)
             if state.model.tokenizer is not None:
                 assert isinstance(state.model.tokenizer, PreTrainedTokenizerBase)
-                state.model.tokenizer.save_pretrained('temp_save_dir')
+                state.model.tokenizer.save_pretrained(temp_save_dir)

             # Only need to edit files for MPT because it has custom code
             if state.model.model.config.model_type == 'mpt':
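The bug fixed above: save_pretrained received the string literal 'temp_save_dir', so artifacts were written to a directory literally named temp_save_dir under the working directory, rather than to the temp_save_dir variable created earlier in _save_checkpoint. That setup sits outside this hunk, so the sketch below is illustrative only, assuming a tempfile-backed directory and the hypothetical helper name save_hf_artifacts:

import tempfile
from typing import Optional

from transformers import PreTrainedModel, PreTrainedTokenizerBase


def save_hf_artifacts(model: PreTrainedModel,
                      tokenizer: Optional[PreTrainedTokenizerBase]) -> None:
    # Illustrative stand-in for the callback's setup: a real temporary
    # directory bound to a variable.
    with tempfile.TemporaryDirectory() as temp_save_dir:
        # Pass the variable, not the string literal 'temp_save_dir'.
        model.save_pretrained(temp_save_dir)
        if tokenizer is not None:
            tokenizer.save_pretrained(temp_save_dir)
        # ... post-process or upload files under temp_save_dir here ...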
16 changes: 15 additions & 1 deletion tests/test_hf_conversion_script.py
@@ -12,7 +12,7 @@
 from composer.utils import dist, get_device

 from llmfoundry.callbacks import HuggingFaceCheckpointer
-from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM
+from llmfoundry.models.mpt.modeling_mpt import ComposerMPTCausalLM, MPTConfig, MPTForCausalLM

 # Add repo root to path so we can import scripts and test it
 repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
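The two new imports, MPTConfig and MPTForCausalLM, are the Hugging Face-style MPT classes, presumably used further down in the test (beyond this hunk) to construct or type-check models directly. A hedged sketch of that kind of usage, with all hyperparameter values invented for illustration:

from llmfoundry.models.mpt.modeling_mpt import MPTConfig, MPTForCausalLM

# Hypothetical: build a tiny MPT model straight from its HF-style classes.
config = MPTConfig(d_model=64, n_heads=2, n_layers=2, vocab_size=128)
model = MPTForCausalLM(config)
assert type(model).__name__ == 'MPTForCausalLM'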
@@ -150,6 +150,20 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
     # so we remove it
     expected_model_config_dict.pop('_name_or_path')
     new_model_config_dict.pop('_name_or_path')
+
+    # Special case a couple of differences that correctly occur when saving MPT to a
+    # Hugging Face format checkpoint
+    architectures_1 = expected_model_config_dict.pop('architectures', None)
+    architectures_2 = new_model_config_dict.pop('architectures', None)
+    if architectures_1 != architectures_2:
+        assert architectures_1 is None and architectures_2 == ['MPTForCausalLM']
+
+    auto_map_1 = expected_model_config_dict.pop('auto_map', None)
+    auto_map_2 = new_model_config_dict.pop('auto_map', None)
+    if auto_map_1 != auto_map_2:
+        assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'}
+        assert auto_map_2 == {'AutoConfig': 'configuration_mpt.MPTConfig', 'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM'}
+
     assert expected_model_config_dict == new_model_config_dict
     assert all(
         torch.equal(p1.cpu(), p2.cpu())
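Both special cases come from how transformers writes checkpoints: PreTrainedModel.save_pretrained stamps the concrete class name into the saved config's architectures field, and exporting custom-code classes adds entries to auto_map, so a freshly converted MPT checkpoint legitimately carries fields the in-memory Composer model's config never had. A self-contained illustration of the architectures behavior, using a tiny GPT-2 purely because it needs no custom code (the same mechanism yields ['MPTForCausalLM'] for MPT):

import tempfile

from transformers import AutoConfig, GPT2Config, GPT2LMHeadModel

# A freshly constructed config has no architectures entry.
config = GPT2Config(n_layer=1, n_head=1, n_embd=8, vocab_size=16)
assert config.architectures is None

# Saving a *model* (not a bare config) records the concrete class name.
model = GPT2LMHeadModel(config)
with tempfile.TemporaryDirectory() as save_dir:
    model.save_pretrained(save_dir)
    reloaded = AutoConfig.from_pretrained(save_dir)
assert reloaded.architectures == ['GPT2LMHeadModel']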
