precommit
dakinggg committed Sep 30, 2023
1 parent 95c93cd commit 1dfbaac
Showing 1 changed file with 5 additions and 8 deletions: llmfoundry/callbacks/hf_checkpointer.py
@@ -17,7 +17,8 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import PreTrainedModel, PreTrainedTokenizerBase, PretrainedConfig
+from transformers import (PreTrainedModel,
+                          PreTrainedTokenizerBase)
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -151,18 +152,15 @@ def _save_checkpoint(self, state: State, logger: Logger):
             from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
             if state.is_model_ddp:
-                original_model = state.model.module.model
+                original_model: PreTrainedModel = state.model.module.model
                 original_tokenizer = state.model.module.tokenizer
             elif isinstance(state.model.model, FSDP):
-                original_model = state.model.model.module
+                original_model: PreTrainedModel = state.model.model.module
                 original_tokenizer = state.model.tokenizer
             else:
-                original_model = state.model.model
+                original_model: PreTrainedModel = state.model.model
                 original_tokenizer = state.model.tokenizer
 
-            assert isinstance(original_model, PreTrainedModel)
-            assert isinstance(original_tokenizer, PreTrainedTokenizerBase)
-
             state_dict_context = fsdp_state_dict_type_context(
                 original_model, state_dict_type='full') if (
                     (not state.is_model_ddp) and isinstance(
@@ -179,7 +177,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
                 log.debug('Saving Hugging Face checkpoint to disk')
 
                 copied_config = copy.deepcopy(original_model.config)
-                assert isinstance(copied_config, PretrainedConfig)
                 if copied_config.model_type == 'mpt':
                     copied_config.attn_config['attn_impl'] = 'torch'
                     copied_config.init_device = 'cpu'
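For context: the change swaps the post-hoc isinstance asserts for inline type annotations on original_model at each assignment, and drops the then-unused PretrainedConfig import along with its assert. A minimal sketch of that pattern, using a hypothetical unwrap_model helper that is not part of llm-foundry:

from transformers import PreTrainedModel


def unwrap_model(wrapped) -> PreTrainedModel:
    # Before: assign first, then `assert isinstance(model, PreTrainedModel)`.
    # After: annotate at the assignment site so static checkers (and readers)
    # see the intended type without a runtime assert.
    if hasattr(wrapped, 'module'):
        model: PreTrainedModel = wrapped.module
    else:
        model: PreTrainedModel = wrapped
    return model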
