diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 788a8943b1..c79537c781 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -74,12 +74,13 @@ def __init__(
         if self.mlflow_registered_model_name is not None:
             # Both the metadata and the task are needed in order for mlflow
             # and databricks optimized model serving to work
-            if 'metadata' not in mlflow_logging_config:
-                mlflow_logging_config['metadata'] = {
-                    'task': 'llm/v1/completions'
-                }
-            if 'task' not in mlflow_logging_config:
-                mlflow_logging_config['task'] = 'text-generation'
+            default_metadata = {'task': 'llm/v1/completions'}
+            passed_metadata = mlflow_logging_config.get('metadata', {})
+            mlflow_logging_config['metadata'] = {
+                **default_metadata,
+                **passed_metadata
+            }
+            mlflow_logging_config.setdefault('task', 'text-generation')
         self.mlflow_logging_config = mlflow_logging_config
 
         self.huggingface_folder_name_fstr = os.path.join(
@@ -93,7 +94,6 @@ def __init__(
         self.save_interval = save_interval
         self.check_interval = create_interval_scheduler(
             save_interval, include_end_of_training=True)
-
         self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
             save_folder, loggers=[])
         if self.remote_ud is not None:
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index 142e714b55..dedf6f5434 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet(
     return icl_evaluators, logger_keys, eval_gauntlet_cb
 
 
-def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
+def build_callback(name: str, kwargs: Union[DictConfig, Dict[str,
+                                                             Any]]) -> Callback:
     if name == 'lr_monitor':
         return LRMonitor()
     elif name == 'memory_monitor':
@@ -117,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
     elif name == 'early_stopper':
         return EarlyStopper(**kwargs)
     elif name == 'hf_checkpointer':
+        if isinstance(kwargs, DictConfig):
+            kwargs = om.to_object(kwargs)  # pyright: ignore
         return HuggingFaceCheckpointer(**kwargs)
     else:
         raise ValueError(f'Not sure how to build callback: {name}')
diff --git a/tests/test_builders.py b/tests/test_builders.py
index 0d24d2154f..237e27b52b 100644
--- a/tests/test_builders.py
+++ b/tests/test_builders.py
@@ -6,8 +6,10 @@
 import pytest
 from composer.callbacks import Generate
+from omegaconf import OmegaConf as om
 from transformers import PreTrainedTokenizerBase
 
+from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
 from llmfoundry.utils.builders import build_callback, build_tokenizer
 
 
@@ -78,3 +80,33 @@ def test_build_generate_callback_unspecified_interval():
             'foo': 'bar',
             'something': 'else',
         })
+
+
+def test_build_hf_checkpointer_callback():
+    with mock.patch.object(HuggingFaceCheckpointer,
+                           '__init__') as mock_hf_checkpointer:
+        mock_hf_checkpointer.return_value = None
+        save_folder = 'path_to_save_folder'
+        save_interval = 1
+        mlflow_logging_config_dict = {
+            'metadata': {
+                'databricks_model_family': 'MptForCausalLM',
+                'databricks_model_size_parameters': '7b',
+                'databricks_model_source': 'mosaic-fine-tuning',
+                'task': 'llm/v1/completions'
+            }
+        }
+        build_callback(name='hf_checkpointer',
+                       kwargs=om.create({
+                           'save_folder': save_folder,
+                           'save_interval': save_interval,
+                           'mlflow_logging_config': mlflow_logging_config_dict
+                       }))
+
+        assert mock_hf_checkpointer.call_count == 1
+        _, _, kwargs = mock_hf_checkpointer.mock_calls[0]
+        assert kwargs['save_folder'] == save_folder
+        assert kwargs['save_interval'] == save_interval
+        assert isinstance(kwargs['mlflow_logging_config'], dict)
+        assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict)
+        assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict
diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py
index af94126225..dcb743b536 100644
--- a/tests/test_hf_conversion_script.py
+++ b/tests/test_hf_conversion_script.py
@@ -5,7 +5,7 @@
 import os
 import pathlib
 import sys
-from unittest.mock import MagicMock, patch
+from unittest.mock import ANY, MagicMock, patch
 
 from composer import Trainer
 from composer.loggers import MLFlowLogger
@@ -242,9 +242,22 @@ def get_config(
     return cast(DictConfig, test_cfg)
 
 
-def test_callback_inits_with_defaults():
+def test_callback_inits():
+    # test with defaults
     _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba')
 
+    # test default metadata when mlflow registered name is given
+    hf_checkpointer = HuggingFaceCheckpointer(
+        save_folder='test',
+        save_interval='1ba',
+        mlflow_registered_model_name='test_model_name')
+    assert hf_checkpointer.mlflow_logging_config == {
+        'task': 'text-generation',
+        'metadata': {
+            'task': 'llm/v1/completions'
+        }
+    }
+
 
 @pytest.mark.world_size(2)
 @pytest.mark.gpu
@@ -425,10 +438,18 @@ def test_huggingface_conversion_callback(
     trainer.fit()
 
     if dist.get_global_rank() == 0:
-        assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
-                                                            else 0)
-        assert mlflow_logger_mock.register_model.call_count == (
-            1 if log_to_mlflow else 0)
+        if log_to_mlflow:
+            assert mlflow_logger_mock.save_model.call_count == 1
+            mlflow_logger_mock.save_model.assert_called_with(
+                flavor='transformers',
+                transformers_model=ANY,
+                path=ANY,
+                task='text-generation',
+                metadata={'task': 'llm/v1/completions'})
+            assert mlflow_logger_mock.register_model.call_count == 1
+        else:
+            assert mlflow_logger_mock.save_model.call_count == 0
+            assert mlflow_logger_mock.register_model.call_count == 0
     else:
         assert mlflow_logger_mock.log_model.call_count == 0
         assert mlflow_logger_mock.register_model.call_count == 0
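
Note on the behavior change above: the hf_checkpointer.py hunk replaces "set defaults only when the key is absent" with a key-by-key merge, so user-passed metadata now survives alongside the serving defaults. Below is a minimal standalone sketch of those merge semantics; resolve_mlflow_logging_config is a hypothetical helper name used for illustration only and is not part of this diff.

    # Standalone sketch (illustration only, not part of the patch).
    def resolve_mlflow_logging_config(mlflow_logging_config: dict) -> dict:
        default_metadata = {'task': 'llm/v1/completions'}
        passed_metadata = mlflow_logging_config.get('metadata', {})
        # Later unpacking wins, so user-passed keys override the defaults.
        mlflow_logging_config['metadata'] = {
            **default_metadata,
            **passed_metadata,
        }
        # setdefault only fills in 'task' when the caller did not supply one.
        mlflow_logging_config.setdefault('task', 'text-generation')
        return mlflow_logging_config

    # User metadata is preserved and merged with the defaults:
    cfg = resolve_mlflow_logging_config(
        {'metadata': {'databricks_model_family': 'MptForCausalLM'}})
    assert cfg == {
        'task': 'text-generation',
        'metadata': {
            'task': 'llm/v1/completions',
            'databricks_model_family': 'MptForCausalLM',
        },
    }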