From 1dfbaac584d5962c12e29e561c66f5da840d1fbc Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 29 Sep 2023 17:46:14 -0700
Subject: [PATCH] precommit

---
 llmfoundry/callbacks/hf_checkpointer.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 86444266ca..b44859e15a 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,8 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
+from transformers import (PreTrainedModel,
+                          PreTrainedTokenizerBase)
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -151,18 +152,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
         if state.is_model_ddp:
-            original_model = state.model.module.model
+            original_model: PreTrainedModel = state.model.module.model
             original_tokenizer = state.model.module.tokenizer
         elif isinstance(state.model.model, FSDP):
-            original_model = state.model.model.module
+            original_model: PreTrainedModel = state.model.model.module
             original_tokenizer = state.model.tokenizer
         else:
-            original_model = state.model.model
+            original_model: PreTrainedModel = state.model.model
             original_tokenizer = state.model.tokenizer
 
-        assert isinstance(original_model, PreTrainedModel)
-        assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
-
         state_dict_context = fsdp_state_dict_type_context(
             original_model, state_dict_type='full') if (
                 (not state.is_model_ddp) and isinstance(
@@ -179,7 +177,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
         log.debug('Saving Hugging Face checkpoint to disk')
 
         copied_config = copy.deepcopy(original_model.config)
-        assert isinstance(copied_config, PretrainedConfig)
         if copied_config.model_type == 'mpt':
             copied_config.attn_config['attn_impl'] = 'torch'
             copied_config.init_device = 'cpu'