Add support for automatically registering models to UC at the end of training (#618)
dakinggg authored Oct 17, 2023
1 parent 4fa2dd8 commit cc238a3
Showing 4 changed files with 171 additions and 38 deletions.
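Before the diff, a minimal sketch of how the new option might be wired up. The callback and logger classes are the ones this commit touches; the save folder and registered-model name are placeholders, not values from the diff.

# Hypothetical usage sketch (placeholder names, not from this commit).
from composer.loggers import MLFlowLogger

from llmfoundry.callbacks import HuggingFaceCheckpointer

checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/my-run/checkpoints',  # placeholder URI
    save_interval='1ep',
    precision='bfloat16',
    # New in this commit: when set, the final checkpoint is also saved via
    # MLflow and registered under this name (a Unity Catalog model when the
    # logger's registry prefix points at a UC catalog.schema).
    mlflow_registered_model_name='my_catalog.my_schema.my_model',
)

# When mlflow_registered_model_name is set, an MLFlowLogger must be among the
# trainer's loggers; otherwise the callback raises a ValueError when it first
# runs (see run_event below).
loggers = [MLFlowLogger()]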
149 changes: 118 additions & 31 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -2,22 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
@@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
        precision (str): The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
            Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``{'task': 'llm/v1/completions'}`` and
            ``'text-generation'`` respectively.
"""

def __init__(
@@ -48,6 +53,8 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
@@ -58,6 +65,22 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
@@ -71,6 +94,7 @@ def __init__(
self.remote_ud = None

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
@@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

if self.mlflow_registered_model_name is not None:
self.mlflow_loggers = [
logger_destination
for logger_destination in logger.destinations
if isinstance(logger_destination, MLFlowLogger)
]
if len(self.mlflow_loggers) == 0:
raise ValueError(
f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+
'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
)

import mlflow
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

@@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
@@ -114,44 +153,65 @@
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result

with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
log.debug('Gathering state dict')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)

            if state.model.tokenizer is not None:
                assert isinstance(state.model.tokenizer,
                                  PreTrainedTokenizerBase)
                state.model.tokenizer.save_pretrained(temp_save_dir)

            log.debug('Saving Hugging Face checkpoint to disk')

            copied_config = copy.deepcopy(original_model.config)
            if copied_config.model_type == 'mpt':
                copied_config.attn_config['attn_impl'] = 'torch'
                copied_config.init_device = 'cpu'

            # TODO: after torch 2.1, we can load a state dict into a meta model
            # and skip the extra model init
            log.debug(f'Creating new model instance')
            new_model_instance = type(original_model)(copied_config)
            new_model_instance.to(dtype=self.dtype)
            new_model_instance.load_state_dict(state_dict)
            del state_dict

            log.debug('Saving Hugging Face checkpoint to disk')
            new_model_instance.save_pretrained(temp_save_dir)
            if original_tokenizer is not None:
                assert isinstance(original_tokenizer,
                                  PreTrainedTokenizerBase)
                original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)

if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'

edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)

if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
@@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

dist.barrier()
elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}')

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)
mlflow_logger.register_model(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
)
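As a standalone illustration of the defaulting logic in ``__init__`` above, the following sketch runs with no composer or MLflow installed; the helper name is invented here, and the callback inlines this logic rather than calling a helper.

# Pure-Python sketch of the mlflow_logging_config defaults applied above.
from typing import Optional


def resolve_mlflow_logging_config(config: Optional[dict]) -> dict:
    config = {} if config is None else dict(config)
    # Both the metadata and the task are needed for MLflow and Databricks
    # optimized model serving, so fill in defaults when they are missing.
    config.setdefault('metadata', {'task': 'llm/v1/completions'})
    config.setdefault('task', 'text-generation')
    return config


assert resolve_mlflow_logging_config(None) == {
    'metadata': {'task': 'llm/v1/completions'},
    'task': 'text-generation',
}
# User-supplied keys win over the defaults.
assert resolve_mlflow_logging_config(
    {'metadata': {'task': 'llm/v1/embeddings'}})['metadata'] == {
        'task': 'llm/v1/embeddings'
    }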
6 changes: 6 additions & 0 deletions llmfoundry/optim/scheduler.py
@@ -20,13 +20,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
time = Time.from_timestring(time)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)

assert not isinstance(time, str) and not isinstance(t_max, str)

if time.unit != t_max.unit:
raise ValueError(f'{time.unit=} does not match {t_max.unit=}.')


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
if isinstance(time, str):
time = Time.from_timestring(time)

assert not isinstance(time, str)

if time.unit == TimeUnit('dur'):
raise ValueError(f'{name} cannot be in units of "dur".')

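For context on the helpers above: the new asserts only narrow the type for the checker once the str-to-Time conversion has run; the unit check itself is pre-existing. A small sketch of the mismatch it rejects, assuming composer is installed:

# Sketch: the unit mismatch that _raise_if_units_dont_match raises on.
from composer.core import Time

t = Time.from_timestring('100ba')     # 100 batches
t_max = Time.from_timestring('10ep')  # 10 epochs

# BATCH vs EPOCH cannot be compared directly, so the scheduler helper
# raises a ValueError instead of silently comparing across units.
assert t.unit != t_max.unit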
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.3,<0.17',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'mosaicml-streaming>=0.6,<0.7',
52 changes: 46 additions & 6 deletions tests/test_hf_conversion_script.py
@@ -5,8 +5,10 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock

from composer import Trainer
from composer.loggers import MLFlowLogger
from composer.utils import dist, get_device

from llmfoundry.callbacks import HuggingFaceCheckpointer
Expand All @@ -17,7 +19,7 @@
sys.path.append(repo_dir)
import shutil
from argparse import Namespace
from typing import cast
from typing import Optional, cast

import pytest
import torch
@@ -148,6 +150,23 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
# so we remove it
expected_model_config_dict.pop('_name_or_path')
new_model_config_dict.pop('_name_or_path')

# Special case a couple of differences that correctly occur when saving MPT to huggingface format
# checkpoint
architectures_1 = expected_model_config_dict.pop('architectures', None)
architectures_2 = new_model_config_dict.pop('architectures', None)
if architectures_1 != architectures_2:
assert architectures_1 is None and architectures_2 == ['MPTForCausalLM']

auto_map_1 = expected_model_config_dict.pop('auto_map', None)
auto_map_2 = new_model_config_dict.pop('auto_map', None)
if auto_map_1 != auto_map_2:
assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'}
assert auto_map_2 == {
'AutoConfig': 'configuration_mpt.MPTConfig',
'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM'
}

assert expected_model_config_dict == new_model_config_dict
assert all(
torch.equal(p1.cpu(), p2.cpu())
@@ -183,9 +202,11 @@ def test_callback_inits_with_defaults():
@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: str):
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
@@ -203,6 +224,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{huggingface_save_interval_batches}ba',
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
if log_to_mlflow else None,
)

# get small version of each model
@@ -324,20 +347,35 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
optimizer = build_optimizer(original_model, optimizer_name,
optimizer_config)

mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
mlflow_logger_mock.save_model = MagicMock()
mlflow_logger_mock.register_model = MagicMock()
mlflow_logger_mock.model_registry_prefix = ''
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{save_interval_batches}ba',
max_duration=f'{max_duration_batches}ba',
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow else [],
optimizers=optimizer,
save_latest_filename=None,
)
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (
1 if log_to_mlflow else 0)
else:
        assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0

# summon full params to check equivalence
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with FSDP.summon_full_params(trainer.state.model,
@@ -390,8 +428,10 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trust_remote_code=True,
)

check_hf_model_equivalence(trainer.state.model.model.to(precision),
loaded_model)
check_hf_model_equivalence(
trainer.state.model.model.to(precision) if fsdp_state_dict_type
is not None else trainer.state.model.module.model.to(precision),
loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)

delete_transformers_cache()
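The mocking pattern in the new test reduces to a few lines; this sketch assumes composer is installed and uses placeholder paths and names.

# Sketch of the MagicMock(spec=MLFlowLogger) pattern from the test above.
from unittest.mock import MagicMock

from composer.loggers import MLFlowLogger

mock_logger = MagicMock(spec=MLFlowLogger)
# model_registry_prefix is an instance attribute, so the spec does not
# provide it; it is set manually, just as the test does.
mock_logger.model_registry_prefix = ''

# spec=MLFlowLogger limits attribute access to real MLFlowLogger attributes,
# so a typo such as mock_logger.sav_model raises AttributeError.
mock_logger.save_model(flavor='transformers',
                       transformers_model={},
                       path='/tmp/mlflow_save_0')  # placeholder path
mock_logger.register_model(model_uri='/tmp/mlflow_save_0',
                           name='dummy-registered-name',
                           await_registration_for=3600)

assert mock_logger.save_model.call_count == 1
assert mock_logger.register_model.call_count == 1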
