Skip to content

Commit

Permalink
Fix passed metadata to mlflow logging (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
wenfeiy-db authored Nov 15, 2023
1 parent f114dad commit db279d0
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 14 deletions.
14 changes: 7 additions & 7 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ def __init__(
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
default_metadata = {'task': 'llm/v1/completions'}
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = {
**default_metadata,
**passed_metadata
}
mlflow_logging_config.setdefault('task', 'text-generation')
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
Expand All @@ -93,7 +94,6 @@ def __init__(
self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)

self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
save_folder, loggers=[])
if self.remote_ud is not None:
Expand Down
5 changes: 4 additions & 1 deletion llmfoundry/utils/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet(
return icl_evaluators, logger_keys, eval_gauntlet_cb


def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
def build_callback(name: str, kwargs: Union[DictConfig, Dict[str,
Any]]) -> Callback:
if name == 'lr_monitor':
return LRMonitor()
elif name == 'memory_monitor':
Expand Down Expand Up @@ -117,6 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
elif name == 'early_stopper':
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
if isinstance(kwargs, DictConfig):
kwargs = om.to_object(kwargs) # pyright: ignore
return HuggingFaceCheckpointer(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')
Expand Down
32 changes: 32 additions & 0 deletions tests/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

import pytest
from composer.callbacks import Generate
from omegaconf import OmegaConf as om
from transformers import PreTrainedTokenizerBase

from llmfoundry.callbacks import HuggingFaceCheckpointer
from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper
from llmfoundry.utils.builders import build_callback, build_tokenizer

Expand Down Expand Up @@ -78,3 +80,33 @@ def test_build_generate_callback_unspecified_interval():
'foo': 'bar',
'something': 'else',
})


def test_build_hf_checkpointer_callback():
    """Check that build_callback('hf_checkpointer') converts an omegaconf
    DictConfig into plain python containers before constructing the callback.
    """
    # Patch the constructor so no real checkpointer setup runs; we only
    # inspect the kwargs that build_callback forwards to it.
    with mock.patch.object(HuggingFaceCheckpointer,
                           '__init__') as mock_hf_checkpointer:
        mock_hf_checkpointer.return_value = None
        save_folder = 'path_to_save_folder'
        save_interval = 1
        mlflow_logging_config_dict = {
            'metadata': {
                'databricks_model_family': 'MptForCausalLM',
                'databricks_model_size_parameters': '7b',
                'databricks_model_source': 'mosaic-fine-tuning',
                'task': 'llm/v1/completions'
            }
        }
        callback_config = om.create({
            'save_folder': save_folder,
            'save_interval': save_interval,
            'mlflow_logging_config': mlflow_logging_config_dict
        })
        build_callback(name='hf_checkpointer', kwargs=callback_config)

        assert mock_hf_checkpointer.call_count == 1
        _, _, forwarded_kwargs = mock_hf_checkpointer.mock_calls[0]
        # The nested DictConfig must arrive as plain dicts so downstream
        # mlflow logging receives json-serializable values.
        assert isinstance(forwarded_kwargs['mlflow_logging_config'], dict)
        assert isinstance(forwarded_kwargs['mlflow_logging_config']['metadata'],
                          dict)
        assert forwarded_kwargs['save_folder'] == save_folder
        assert forwarded_kwargs['save_interval'] == save_interval
        assert forwarded_kwargs[
            'mlflow_logging_config'] == mlflow_logging_config_dict
33 changes: 27 additions & 6 deletions tests/test_hf_conversion_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch

from composer import Trainer
from composer.loggers import MLFlowLogger
Expand Down Expand Up @@ -242,9 +242,22 @@ def get_config(
return cast(DictConfig, test_cfg)


def test_callback_inits_with_defaults():
def test_callback_inits():
    """Smoke-test HuggingFaceCheckpointer construction.

    Covers the minimal (defaults-only) case, and verifies that supplying
    ``mlflow_registered_model_name`` fills in the default mlflow task and
    metadata.
    """
    # Minimal construction: only the required save arguments.
    _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba')

    # With a registered model name, default metadata/task are populated.
    hf_checkpointer = HuggingFaceCheckpointer(
        save_folder='test',
        save_interval='1ba',
        mlflow_registered_model_name='test_model_name')
    expected_config = {
        'task': 'text-generation',
        'metadata': {
            'task': 'llm/v1/completions'
        }
    }
    assert hf_checkpointer.mlflow_logging_config == expected_config


@pytest.mark.world_size(2)
@pytest.mark.gpu
Expand Down Expand Up @@ -425,10 +438,18 @@ def test_huggingface_conversion_callback(
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (
1 if log_to_mlflow else 0)
if log_to_mlflow:
assert mlflow_logger_mock.save_model.call_count == 1
mlflow_logger_mock.save_model.assert_called_with(
flavor='transformers',
transformers_model=ANY,
path=ANY,
task='text-generation',
metadata={'task': 'llm/v1/completions'})
assert mlflow_logger_mock.register_model.call_count == 1
else:
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
else:
assert mlflow_logger_mock.log_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0
Expand Down

0 comments on commit db279d0

Please sign in to comment.