From 95c93cd01baa0456c5f1ae6a465c432487c1470c Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 29 Sep 2023 17:44:27 -0700
Subject: [PATCH] precommit

---
 llmfoundry/callbacks/hf_checkpointer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index c00f6d266f..86444266ca 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,7 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedModel, PreTrainedTokenizerBase
+from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -158,7 +158,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
             original_tokenizer = state.model.tokenizer
         else:
             original_model = state.model.model
-            original_tokenizer = state.model.tokenizers
+            original_tokenizer = state.model.tokenizer
 
         assert isinstance(original_model, PreTrainedModel)
         assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
@@ -179,7 +179,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
             log.debug('Saving Hugging Face checkpoint to disk')
 
             copied_config = copy.deepcopy(original_model.config)
-            if original_model.config.model_type == 'mpt':
+            assert isinstance(copied_config, PretrainedConfig)
+            if copied_config.model_type == 'mpt':
                 copied_config.attn_config['attn_impl'] = 'torch'
                 copied_config.init_device = 'cpu'
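
A minimal standalone sketch of the pattern the final hunk enforces: deep-copy the
model config, assert it is a transformers PretrainedConfig, and pin MPT-specific
fields before writing it out, so the live training config is never mutated. The
helper name `export_config` and the `save_dir` argument are hypothetical stand-ins,
not llm-foundry code.

import copy

from transformers import PretrainedConfig, PreTrainedModel


def export_config(model: PreTrainedModel, save_dir: str) -> None:
    # Hypothetical helper. Work on a copy so the in-memory training
    # config is left untouched.
    copied_config = copy.deepcopy(model.config)
    assert isinstance(copied_config, PretrainedConfig)
    if copied_config.model_type == 'mpt':
        # MPT-specific: pin settings so the exported checkpoint loads on
        # any host (plain torch attention, weights initialized on CPU).
        copied_config.attn_config['attn_impl'] = 'torch'
        copied_config.init_device = 'cpu'
    copied_config.save_pretrained(save_dir)  # writes config.json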