From ed4eaf49b6170e2987a017d80093d319a55dd04c Mon Sep 17 00:00:00 2001 From: Daniel King Date: Fri, 29 Sep 2023 16:51:26 -0700 Subject: [PATCH] precommit --- llmfoundry/callbacks/hf_checkpointer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 8d96128c93..c3da99bbcf 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,7 +17,7 @@ from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel from composer.utils import dist, format_name_with_dist_and_time, parse_uri -from transformers import PreTrainedTokenizerBase, PreTrainedModel +from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils.huggingface_hub_utils import \ @@ -177,13 +177,13 @@ def _save_checkpoint(self, state: State, logger: Logger): # TODO: after torch 2.1, we can load a state dict into a meta model # and skip the extra model init log.debug(f'Creating new model instance') - new_model_instance = type(model_class)(copied_config) + new_model_instance: PreTrainedModel = type(model_class)( + copied_config) new_model_instance.to(dtype=self.dtype) new_model_instance.load_state_dict(state_dict) del state_dict log.debug('Saving Hugging Face checkpoint to disk') - assert isinstance(new_model_instance, PreTrainedModel) new_model_instance.save_pretrained(temp_save_dir) if state.model.tokenizer is not None: assert isinstance(state.model.tokenizer,