diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 63c9b9ca0d..464e1fd755 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -17,8 +17,7 @@ from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader from composer.models import HuggingFaceModel from composer.utils import dist, format_name_with_dist_and_time, parse_uri -from transformers import (PreTrainedModel, - PreTrainedTokenizerBase) +from transformers import PreTrainedModel, PreTrainedTokenizerBase from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM from llmfoundry.utils.huggingface_hub_utils import \ @@ -73,7 +72,9 @@ def __init__( if mlflow_logging_config is None: mlflow_logging_config = {} if 'metadata' not in mlflow_logging_config: - mlflow_logging_config['metadata'] = {'task': 'llm/v1/completions'} + mlflow_logging_config['metadata'] = { + 'task': 'llm/v1/completions' + } if 'task' not in mlflow_logging_config: mlflow_logging_config['task'] = 'text-generation' self.mlflow_logging_config = mlflow_logging_config