From bcfb534d39060aaf5592ea9f26d6bc41bf93cb3e Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Fri, 29 Sep 2023 15:21:43 -0700
Subject: [PATCH] precommit

---
 llmfoundry/callbacks/hf_checkpointer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index ff2924203d..8e84374828 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -3,7 +3,6 @@
 
 import contextlib
 import copy
-import json
 import logging
 import os
 import tempfile
@@ -18,7 +17,7 @@
 from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
 from composer.models import HuggingFaceModel
 from composer.utils import dist, format_name_with_dist_and_time, parse_uri
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
 from llmfoundry.utils.huggingface_hub_utils import \
@@ -183,9 +182,10 @@ def _save_checkpoint(self, state: State, logger: Logger):
         new_model_instance.load_state_dict(state_dict)
         del state_dict
 
-        log.debug("Saving Hugging Face checkpoint to disk")
+        log.debug('Saving Hugging Face checkpoint to disk')
         new_model_instance.save_pretrained('temp_save_dir')
         if state.model.tokenizer is not None:
+            assert isinstance(state.model.tokenizer, PreTrainedTokenizerBase)
             state.model.tokenizer.save_pretrained('temp_save_dir')
 
         # Only need to edit files for MPT because it has custom code
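
The added assert narrows the tokenizer's declared type before `save_pretrained` is called, which both satisfies static type checkers and fails fast if an unexpected tokenizer object is attached to the model. A minimal sketch of that pattern, using a hypothetical `Checkpointer` class (the name and structure here are illustrative, not from llmfoundry):

    from typing import Optional

    from transformers import PreTrainedTokenizerBase


    class Checkpointer:
        """Hypothetical holder mirroring the patched callback's tokenizer handling."""

        def __init__(self, tokenizer: Optional[object] = None):
            # The tokenizer is loosely typed here, as it is on the wrapped model.
            self.tokenizer = tokenizer

        def save(self, save_dir: str) -> None:
            if self.tokenizer is not None:
                # Narrow the type so type checkers (and readers) know that
                # .save_pretrained exists; raises AssertionError otherwise.
                assert isinstance(self.tokenizer, PreTrainedTokenizerBase)
                self.tokenizer.save_pretrained(save_dir)

With the type narrowed this way, only the `PreTrainedTokenizerBase` base class needs importing, which is why the patch can drop `AutoTokenizer` from the transformers import.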