Add support for automatically registering models to UC at the end of training #618

Merged

67 commits merged from mlflow-log-model into main on Oct 17, 2023. Changes shown from all commits.

Commits (67, all by dakinggg)
c13767a  wip (Sep 16, 2023)
2c0ff34  wip (Sep 19, 2023)
63dc744  wip (Sep 19, 2023)
694eca4  small model works (Sep 19, 2023)
91c1c03  temp comment out (Sep 19, 2023)
5410557  more logs (Sep 19, 2023)
3f60498  tweaks (Sep 19, 2023)
765a599  fix fit end (Sep 19, 2023)
9095cdb  speedup attempt (Sep 19, 2023)
cfe7312  fix (Sep 19, 2023)
f354638  fix (Sep 19, 2023)
300bda5  fix meta (Sep 19, 2023)
9f270eb  fix config creation (Sep 19, 2023)
2f22c17  Merge branch 'main' into mlflow-log-model (Sep 27, 2023)
7685fbe  add mlflow log model test (Sep 27, 2023)
c52056b  fix (Sep 27, 2023)
e6bb71a  fix (Sep 27, 2023)
cfa730e  fix test (Sep 27, 2023)
d29fb57  precommit (Sep 27, 2023)
6180466  precommit (Sep 27, 2023)
62a0fd6  pyright (Sep 27, 2023)
1efc4ae  precommit (Sep 27, 2023)
bbc27d2  pyright (Sep 27, 2023)
a7132de  Merge branch 'main' into mlflow-log-model (Sep 27, 2023)
c755ee4  Merge branch 'main' into mlflow-log-model (Sep 27, 2023)
b55c71d  merge (Sep 28, 2023)
6715d88  precommit (Sep 28, 2023)
932ba7f  add logging (Sep 28, 2023)
ab6d082  no uc (Sep 28, 2023)
24f3702  update to new save and register (Sep 29, 2023)
0d1add2  monkeypatch (Sep 29, 2023)
67ee5bd  fix tests (Sep 29, 2023)
28436a9  precommit (Sep 29, 2023)
7cffafe  skip extra model load (Sep 29, 2023)
aa324c6  undo yaml changes (Sep 29, 2023)
bcfb534  precommit (Sep 29, 2023)
7a6ae1d  fixes (Sep 29, 2023)
7190848  precommit; (Sep 29, 2023)
5233e24  precommit (Sep 29, 2023)
33f21a9  precommit (Sep 29, 2023)
ed4eaf4  precommit (Sep 29, 2023)
d2f88b7  cleanup (Sep 29, 2023)
d5e0683  support ddp (Sep 30, 2023)
53ba514  precommit (Sep 30, 2023)
95c93cd  precommit (Sep 30, 2023)
1dfbaac  precommit (Sep 30, 2023)
c7161f8  Merge branch 'main' into mlflow-log-model (Oct 2, 2023)
c715832  switch to providing registered name (Oct 2, 2023)
5b5f039  precommit (Oct 2, 2023)
15455bc  Merge branch 'main' into mlflow-log-model (Oct 4, 2023)
b0878b5  Merge branch 'main' into mlflow-log-model (Oct 10, 2023)
6369b7b  Merge branch 'main' into mlflow-log-model (Oct 10, 2023)
a36f93b  Merge branch 'main' into mlflow-log-model (Oct 11, 2023)
d4f7eb3  Merge branch 'main' into mlflow-log-model (Oct 11, 2023)
499a121  bump composer pin (Oct 11, 2023)
047b320  pyright (Oct 11, 2023)
e952250  types (Oct 11, 2023)
795b9ae  fixes (Oct 12, 2023)
9af48a2  precommit (Oct 12, 2023)
4e81938  more precommit (Oct 12, 2023)
ebfacf6  Merge branch 'main' into mlflow-log-model (Oct 12, 2023)
0dae7a7  add comment (Oct 12, 2023)
881e5ff  Merge branch 'main' into mlflow-log-model (Oct 12, 2023)
bce5089  Merge branch 'main' into mlflow-log-model (Oct 12, 2023)
354bb11  update import path (Oct 17, 2023)
af5b1c0  precommit (Oct 17, 2023)
9255202  Merge branch 'main' into mlflow-log-model (Oct 17, 2023)
149 changes: 118 additions & 31 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -2,22 +2,22 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import copy
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers import Logger, MLFlowLogger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
@@ -39,6 +39,11 @@ class HuggingFaceCheckpointer(Callback):
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision (str): The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
mlflow_registered_model_name (Optional[str]): The name to register the model under in the MLflow model registry. If ``None``, the model will not
be registered. Default is ``None``.
mlflow_logging_config (Optional[dict]): A dictionary of config arguments that will get passed along to the MLflow ``save_model`` call.
Expected to contain ``metadata`` and ``task`` keys. If either is unspecified, the defaults are ``{'task': 'llm/v1/completions'}`` and
``'text-generation'``, respectively.
"""

def __init__(
@@ -48,6 +53,8 @@ def __init__(
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
overwrite: bool = False,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
@@ -58,6 +65,22 @@ def __init__(
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]

# mlflow config setup
self.mlflow_registered_model_name = mlflow_registered_model_name
if mlflow_logging_config is None:
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
if 'metadata' not in mlflow_logging_config:
mlflow_logging_config['metadata'] = {
'task': 'llm/v1/completions'
}
if 'task' not in mlflow_logging_config:
mlflow_logging_config['task'] = 'text-generation'
self.mlflow_logging_config = mlflow_logging_config

self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
@@ -71,6 +94,7 @@ def __init__(
self.remote_ud = None

self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []

def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
@@ -87,6 +111,23 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

if self.mlflow_registered_model_name is not None:
self.mlflow_loggers = [
logger_destination
for logger_destination in logger.destinations
if isinstance(logger_destination, MLFlowLogger)
]
if len(self.mlflow_loggers) == 0:
raise ValueError(
f'`mlflow_registered_model_name` was set, but no `MLFlowLogger` was found in the `logger.destinations` list. '
+
'Please add an `MLFlowLogger` or set `mlflow_registered_model_name` to `None`.'
)

import mlflow
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

@@ -99,8 +140,6 @@ def _save_checkpoint(self, state: State, logger: Logger):
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
@@ -114,44 +153,65 @@ def _save_checkpoint(self, state: State, logger: Logger):
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result

with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()
log.debug('Gathering state dict')
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if state.is_model_ddp:
original_model: PreTrainedModel = state.model.module.model
state_dict_model = state.model.module.model
original_tokenizer = state.model.module.tokenizer
elif isinstance(state.model.model, FSDP):
original_model: PreTrainedModel = state.model.model.module
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer
else:
original_model: PreTrainedModel = state.model.model
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)

if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
log.debug('Saving Hugging Face checkpoint to disk')

copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'

# TODO: after torch 2.1, we can load a state dict into a meta model
# and skip the extra model init
log.debug(f'Creating new model instance')
new_model_instance = type(original_model)(copied_config)
new_model_instance.to(dtype=self.dtype)
new_model_instance.load_state_dict(state_dict)
del state_dict

log.debug('Saving Hugging Face checkpoint to disk')
new_model_instance.save_pretrained(temp_save_dir)
if original_tokenizer is not None:
assert isinstance(original_tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)
original_tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
if original_model.config.model_type == 'mpt':
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)

with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)

if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'

edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)

if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
@@ -164,4 +224,31 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)

dist.barrier()
elapsed_duration = state.get_elapsed_duration()
if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer

log.debug('Logging Hugging Face model to MLFlow')
for i, mlflow_logger in enumerate(self.mlflow_loggers):
log.debug(
f'Registering model to UC at {mlflow_logger.model_registry_prefix}.{self.mlflow_registered_model_name}'
)
local_save_path = str(
Path(temp_save_dir) / f'mlflow_save_{i}')

# TODO: Remove after mlflow fixes the bug that makes this necessary
import mlflow
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: ''
mlflow_logger.save_model(
flavor='transformers',
transformers_model=components,
path=local_save_path,
**self.mlflow_logging_config,
)
mlflow_logger.register_model(
model_uri=local_save_path,
name=self.mlflow_registered_model_name,
await_registration_for=3600,
)
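For reference, a minimal sketch of how the new arguments might be wired into a training script. Everything outside the `HuggingFaceCheckpointer` and `MLFlowLogger` constructor arguments shown in this PR (the bucket path, experiment name, UC model name, and the model and dataloader variables) is an illustrative assumption:

```python
from composer import Trainer
from composer.loggers import MLFlowLogger

from llmfoundry.callbacks import HuggingFaceCheckpointer

# The save_folder, experiment name, and UC model name are made-up examples.
checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/my-run/checkpoints',
    save_interval='1000ba',
    precision='bfloat16',
    # A Unity Catalog registered-model name (catalog.schema.model).
    mlflow_registered_model_name='main.my_schema.my_model',
    # Left as None here: the callback then defaults task to
    # 'text-generation' and metadata to {'task': 'llm/v1/completions'}.
    mlflow_logging_config=None,
)

trainer = Trainer(
    model=composer_model,           # assumed: a HuggingFaceModel
    train_dataloader=train_loader,  # assumed dataloader
    max_duration='10000ba',
    callbacks=[checkpointer],
    # run_event raises a ValueError if mlflow_registered_model_name is set
    # but no MLFlowLogger is present in the loggers list.
    loggers=[MLFlowLogger(experiment_name='my-experiment')],
)
trainer.fit()  # the model is registered once elapsed_duration >= 1.0
```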
6 changes: 6 additions & 0 deletions llmfoundry/optim/scheduler.py
@@ -20,13 +20,19 @@ def _raise_if_units_dont_match(time: Union[str, Time], t_max: Union[str, Time],
time = Time.from_timestring(time)
if isinstance(t_max, str):
t_max = Time.from_timestring(t_max)

assert not isinstance(time, str) and not isinstance(t_max, str)

if time.unit != t_max.unit:
raise ValueError(f'{time.unit=} does not match {t_max.unit=}.')


def _raise_if_units_dur(time: Union[str, Time], name: str) -> None:
if isinstance(time, str):
time = Time.from_timestring(time)

assert not isinstance(time, str)

if time.unit == TimeUnit('dur'):
raise ValueError(f'{name} cannot be in units of "dur".')

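The scheduler change only adds asserts so the type checker narrows `time` and `t_max` from `Union[str, Time]` to `Time` after string parsing; runtime behavior is unchanged. A small sketch of the checks being guarded, assuming composer's `Time` API:

```python
from composer.core import Time

# After parsing, _raise_if_units_dont_match compares units, so mixing
# batch- and epoch-based times raises a ValueError:
t1 = Time.from_timestring('100ba')  # unit: batches
t2 = Time.from_timestring('1ep')    # unit: epochs
print(t1.unit != t2.unit)           # True -> this pair would be rejected

# _raise_if_units_dur rejects any time expressed as a fraction of
# total training duration:
t3 = Time.from_timestring('0.5dur')
print(t3.unit)                      # TimeUnit.DURATION -> rejected
```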
2 changes: 1 addition & 1 deletion setup.py
@@ -47,7 +47,7 @@
]

install_requires = [
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.3,<0.17',
'mosaicml[libcloud,wandb,mlflow,oci,gcs]>=0.16.4,<0.17',
'accelerate>=0.20,<0.21', # for HF inference `device_map`
'transformers>=4.33,<4.34',
'mosaicml-streaming>=0.6,<0.7',
52 changes: 46 additions & 6 deletions tests/test_hf_conversion_script.py
@@ -5,8 +5,10 @@
import os
import pathlib
import sys
from unittest.mock import MagicMock

from composer import Trainer
from composer.loggers import MLFlowLogger
from composer.utils import dist, get_device

from llmfoundry.callbacks import HuggingFaceCheckpointer
@@ -17,7 +19,7 @@
sys.path.append(repo_dir)
import shutil
from argparse import Namespace
from typing import cast
from typing import Optional, cast

import pytest
import torch
@@ -148,6 +150,23 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
# so we remove it
expected_model_config_dict.pop('_name_or_path')
new_model_config_dict.pop('_name_or_path')

# Special case a couple of differences that correctly occur when saving MPT to huggingface format
# checkpoint
architectures_1 = expected_model_config_dict.pop('architectures', None)
architectures_2 = new_model_config_dict.pop('architectures', None)
if architectures_1 != architectures_2:
assert architectures_1 is None and architectures_2 == ['MPTForCausalLM']

auto_map_1 = expected_model_config_dict.pop('auto_map', None)
auto_map_2 = new_model_config_dict.pop('auto_map', None)
if auto_map_1 != auto_map_2:
assert auto_map_1 == {'AutoConfig': 'configuration_mpt.MPTConfig'}
assert auto_map_2 == {
'AutoConfig': 'configuration_mpt.MPTConfig',
'AutoModelForCausalLM': 'modeling_mpt.MPTForCausalLM'
}

assert expected_model_config_dict == new_model_config_dict
assert all(
torch.equal(p1.cpu(), p2.cpu())
@@ -183,9 +202,11 @@ def test_callback_inits_with_defaults():
@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('model', ['mpt', 'neo', 'llama2'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded'])
@pytest.mark.parametrize('fsdp_state_dict_type', ['full', 'sharded', None])
@pytest.mark.parametrize('log_to_mlflow', [True, False])
def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
fsdp_state_dict_type: str):
fsdp_state_dict_type: Optional[str],
log_to_mlflow: bool):
delete_transformers_cache()

dist.initialize_dist(get_device('gpu'))
@@ -203,6 +224,8 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{huggingface_save_interval_batches}ba',
precision=precision_str,
mlflow_registered_model_name='dummy-registered-name'
if log_to_mlflow else None,
)

# get small version of each model
@@ -324,20 +347,35 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
optimizer = build_optimizer(original_model, optimizer_name,
optimizer_config)

mlflow_logger_mock = MagicMock(spec=MLFlowLogger)
mlflow_logger_mock.state_dict = lambda *args, **kwargs: {}
mlflow_logger_mock.save_model = MagicMock()
mlflow_logger_mock.register_model = MagicMock()
mlflow_logger_mock.model_registry_prefix = ''
trainer = Trainer(
model=original_model,
device='gpu',
fsdp_config=fsdp_config,
fsdp_config=fsdp_config if fsdp_state_dict_type is not None else None,
train_dataloader=train_dataloader,
save_folder=os.path.join(tmp_path, 'checkpoints'),
save_interval=f'{save_interval_batches}ba',
max_duration=f'{max_duration_batches}ba',
callbacks=[checkpointer_callback],
loggers=[mlflow_logger_mock] if log_to_mlflow else [],
optimizers=optimizer,
save_latest_filename=None,
)
trainer.fit()

if dist.get_global_rank() == 0:
assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow
else 0)
assert mlflow_logger_mock.register_model.call_count == (
1 if log_to_mlflow else 0)
else:
assert mlflow_logger_mock.save_model.call_count == 0
assert mlflow_logger_mock.register_model.call_count == 0

# summon full params to check equivalence
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
with FSDP.summon_full_params(trainer.state.model,
@@ -390,8 +428,10 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path,
trust_remote_code=True,
)

check_hf_model_equivalence(trainer.state.model.model.to(precision),
loaded_model)
check_hf_model_equivalence(
trainer.state.model.model.to(precision) if fsdp_state_dict_type
is not None else trainer.state.model.module.model.to(precision),
loaded_model)
check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)

delete_transformers_cache()
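Downstream, a registered model could then be pulled back out of the registry with stock MLflow. A sketch, assuming a Unity Catalog registry and a hypothetical three-part model name:

```python
import mlflow

# Assumption: the registry is Unity Catalog and the callback registered the
# model under the (hypothetical) name 'main.my_schema.my_model'.
mlflow.set_registry_uri('databricks-uc')

loaded = mlflow.transformers.load_model(
    model_uri='models:/main.my_schema.my_model/1',  # version 1 assumed
)
```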