Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg committed Sep 29, 2023
1 parent 33f21a9 commit ed4eaf4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
- from transformers import PreTrainedTokenizerBase, PreTrainedModel
+ from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
Expand Down Expand Up @@ -177,13 +177,13 @@ def _save_checkpoint(self, state: State, logger: Logger):
# TODO: after torch 2.1, we can load a state dict into a meta model
# and skip the extra model init
log.debug(f'Creating new model instance')
- new_model_instance = type(model_class)(copied_config)
+ new_model_instance: PreTrainedModel = type(model_class)(
+ copied_config)
new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)
del state_dict

log.debug('Saving Hugging Face checkpoint to disk')
- assert isinstance(new_model_instance, PreTrainedModel)
new_model_instance.save_pretrained(temp_save_dir)
if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
Expand Down

0 comments on commit ed4eaf4

Please sign in to comment.