Commit

Add a callback to write huggingface checkpoints during the training run (#594)

dakinggg authored Sep 14, 2023
1 parent 3f3e998 commit 30544f0
Showing 7 changed files with 596 additions and 67 deletions.
13 changes: 10 additions & 3 deletions llmfoundry/callbacks/__init__.py
@@ -5,6 +5,7 @@
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.generate_callback import Generate
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.model_gauntlet_callback import ModelGauntlet
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
@@ -18,7 +19,13 @@
) from e

__all__ = [
'FDiffMetrics', 'Generate', 'MonolithicCheckpointSaver', 'GlobalLRScaling',
'LayerFreezing', 'ScheduledGarbageCollector', 'EvalGauntlet',
'ModelGauntlet'
'FDiffMetrics',
'Generate',
'MonolithicCheckpointSaver',
'GlobalLRScaling',
'LayerFreezing',
'ScheduledGarbageCollector',
'EvalGauntlet',
'ModelGauntlet',
'HuggingFaceCheckpointer',
]
167 changes: 167 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -0,0 +1,167 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import contextlib
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Optional, Union

import torch
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger
from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
from composer.utils import dist, format_name_with_dist_and_time, parse_uri
from transformers import PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
from llmfoundry.utils.huggingface_hub_utils import \
edit_files_for_hf_compatibility

log = logging.getLogger(__name__)


class HuggingFaceCheckpointer(Callback):
"""Save a huggingface formatted checkpoint during training.
Args:
save_folder (str): Top level folder to save checkpoints to (can be a URI). It is likely that
this would be the same as your save_folder.
save_interval: Union[str, int, Time]: The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
huggingface_folder_name (str): Folder to save each checkpoint under (can be a format string). Default is ``ba{batch}``.
precision: The precision to save the model in. Default is ``float32``. Options are ``bfloat16``, ``float16``, or ``float32``.
overwrite (bool): Whether to overwrite previous checkpoints.
"""

def __init__(
self,
save_folder: str,
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
        precision: str = 'float32',
overwrite: bool = False,
):
self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
save_folder)
self.overwrite = overwrite
self.precision = precision
self.dtype = {
'float32': torch.float32,
'float16': torch.float16,
'bfloat16': torch.bfloat16,
}[precision]
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
self.upload_to_object_store = (self.backend != '')
if self.upload_to_object_store:
self.remote_ud = RemoteUploaderDownloader(
bucket_uri=f'{self.backend}://{self.bucket_name}',
num_concurrent_uploads=4)
else:
self.remote_ud = None

self.last_checkpoint_batch: Optional[Time] = None

def run_event(self, event: Event, state: State, logger: Logger) -> None:
# The interval scheduler handles only returning True for the appropriate events
if state.get_elapsed_duration() is not None and self.check_interval(
state,
event) and self.last_checkpoint_batch != state.timestamp.batch:
self._save_checkpoint(state, logger)
elif event == Event.INIT:
if not isinstance(state.model, HuggingFaceModel):
raise ValueError(
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+ f'Got {type(state.model)} instead.')
if self.upload_to_object_store and self.remote_ud is not None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)

def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused

self.last_checkpoint_batch = state.timestamp.batch

log.info('Saving HuggingFace formatted checkpoint')

from transformers.models.auto.configuration_auto import CONFIG_MAPPING
CONFIG_MAPPING._extra_content['mpt'] = MPTConfig
MPTConfig.register_for_auto_class()
MPTForCausalLM.register_for_auto_class('AutoModelForCausalLM')

assert isinstance(state.model, HuggingFaceModel)

save_dir = format_name_with_dist_and_time(
str(
Path(self.save_dir_format_str) /
self.huggingface_folder_name_fstr), state.run_name,
state.timestamp)
dir_context_mgr = tempfile.TemporaryDirectory(
) if self.upload_to_object_store else contextlib.nullcontext(
enter_result=save_dir)

with dir_context_mgr as temp_save_dir:
assert isinstance(temp_save_dir,
str) # pyright doesn't know about enter_result

with fsdp_state_dict_type_context(state.model.model,
state_dict_type='full'):
state_dict = state.model.model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

if dist.get_global_rank() == 0:
# We raise above if the model is not a HuggingFaceModel, so this assert is safe
assert hasattr(state.model.model, 'save_pretrained')
state.model.model.save_pretrained(temp_save_dir,
state_dict=state_dict)

if state.model.tokenizer is not None:
assert isinstance(state.model.tokenizer,
PreTrainedTokenizerBase)
state.model.tokenizer.save_pretrained(temp_save_dir)

# Only need to edit files for MPT because it has custom code
if state.model.model.config.model_type == 'mpt':
edit_files_for_hf_compatibility(temp_save_dir)

with open(os.path.join(temp_save_dir, 'config.json'), 'r') as f:
edited_config = json.load(f)

if state.model.model.config.model_type == 'mpt':
edited_config['attn_config']['attn_impl'] = 'torch'
edited_config['init_device'] = 'cpu'

edited_config['torch_dtype'] = self.precision
with open(os.path.join(temp_save_dir, 'config.json'), 'w') as f:
json.dump(edited_config, f, indent=4)

if self.upload_to_object_store:
assert self.remote_ud is not None
# TODO change to log after other pr
log.info(
f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
)
for filename in os.listdir(temp_save_dir):
self.remote_ud.upload_file(
state=state,
remote_file_name=os.path.join(save_dir, filename),
file_path=Path(os.path.join(temp_save_dir,
filename)),
overwrite=self.overwrite,
)

dist.barrier()
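
For orientation, a minimal usage sketch (not part of this commit) of how the callback might be attached to a Composer Trainer. The model, dataloader, and bucket URI are hypothetical; the callback itself only requires that the model is a composer HuggingFaceModel.

from composer import Trainer

from llmfoundry.callbacks import HuggingFaceCheckpointer

# Save a Hugging Face formatted checkpoint every 1000 batches, cast to bfloat16.
hf_checkpointer = HuggingFaceCheckpointer(
    save_folder='s3://my-bucket/my-run/checkpoints',  # hypothetical URI; a local path also works
    save_interval='1000ba',
    precision='bfloat16',
)

trainer = Trainer(
    model=model,  # assumed: a composer HuggingFaceModel (enforced at Event.INIT)
    train_dataloader=train_dataloader,  # assumed to be defined elsewhere
    max_duration='1ep',
    callbacks=[hf_checkpointer],
)
trainer.fit()
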
6 changes: 4 additions & 2 deletions llmfoundry/utils/builders.py
@@ -27,8 +27,8 @@
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (EvalGauntlet, FDiffMetrics, Generate,
GlobalLRScaling, LayerFreezing,
MonolithicCheckpointSaver,
GlobalLRScaling, HuggingFaceCheckpointer,
LayerFreezing, MonolithicCheckpointSaver,
ScheduledGarbageCollector)
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
DecoupledLionW, DecoupledLionW_8bit)
@@ -99,6 +99,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
return ScheduledGarbageCollector(**kwargs)
elif name == 'early_stopper':
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
return HuggingFaceCheckpointer(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')

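As a sketch of how this registry entry would be exercised (the kwargs are illustrative, not from this commit):

from llmfoundry.utils.builders import build_callback

# A 'hf_checkpointer' entry in a config's callbacks section resolves to this call.
callback = build_callback('hf_checkpointer', {
    'save_folder': 's3://my-bucket/checkpoints',  # hypothetical URI
    'save_interval': '5000ba',
    'precision': 'bfloat16',
})
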
2 changes: 2 additions & 0 deletions tests/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
67 changes: 67 additions & 0 deletions tests/data_utils.py
@@ -0,0 +1,67 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import json
import os
from typing import Optional


def make_tiny_ft_dataset(
path: str,
size: int = 4,
add_bad_data_dropped: bool = False,
add_bad_data_error: bool = False,
add_just_bos_eos_pad: bool = False,
pad_token: Optional[str] = None,
start_token: Optional[str] = None,
end_token: Optional[str] = None,
):
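    """Write a tiny prompt/response JSONL dataset to ``path`` for testing."""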
good_sample = {'prompt': 'hello', 'response': 'goodbye'}
samples = [good_sample] * size
if add_bad_data_dropped:
if pad_token is None:
raise ValueError(
                'pad_token must be specified if add_bad_data_dropped is True'
)
# empty prompt
samples.append({'prompt': '', 'response': 'goodbye'})
# empty response
samples.append({'prompt': 'hello', 'response': ''})
# response just pad
samples.append({'prompt': 'hello', 'response': pad_token})
# response just pad multiple times
samples.append({'prompt': 'hello', 'response': pad_token * 3})

if add_bad_data_error:
# prompt just None
samples.append({
'prompt': None,
'response': 'goodbye'
}) # type: ignore (intentional test)
# response just None
samples.append({
'prompt': 'hello',
'response': None
}) # type: ignore (intentional test)

if add_just_bos_eos_pad:
if pad_token is None or start_token is None or end_token is None:
raise ValueError(
                'pad_token, start_token, and end_token must be specified if add_just_bos_eos_pad is True'
)
# prompt just start
samples.append({'prompt': start_token, 'response': 'goodbye'})
# response just start
samples.append({'prompt': 'hello', 'response': start_token})
# prompt just end
samples.append({'prompt': end_token, 'response': 'goodbye'})
# response just end
samples.append({'prompt': 'hello', 'response': end_token})
# prompt just pad
samples.append({'prompt': pad_token, 'response': 'goodbye'})

os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as _f:
for sample in samples:
_f.write(json.dumps(sample))
_f.write('\n')
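
A usage sketch (the path and pad token are hypothetical): write eight good samples plus the malformed rows that the finetuning dataloader is expected to drop.

from tests.data_utils import make_tiny_ft_dataset

make_tiny_ft_dataset(
    path='/tmp/test-ft/train/data.jsonl',  # hypothetical location
    size=8,
    add_bad_data_dropped=True,
    pad_token='<pad>',  # assumed pad token string
)
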
63 changes: 1 addition & 62 deletions tests/test_dataloader.py
@@ -1,7 +1,6 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import json
import os
import pathlib
import shutil
@@ -25,6 +24,7 @@
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)
from scripts.data_prep.convert_dataset_hf import main as main_hf
from tests.data_utils import make_tiny_ft_dataset


def get_config(conf_path: str = 'yamls/mpt/125m.yaml'):
@@ -279,67 +279,6 @@ def test_finetuning_dataloader(decoder_only_format: bool,
break


def make_tiny_ft_dataset(
path: str,
size: int = 4,
add_bad_data_dropped: bool = False,
add_bad_data_error: bool = False,
add_just_bos_eos_pad: bool = False,
pad_token: Optional[str] = None,
start_token: Optional[str] = None,
end_token: Optional[str] = None,
):
good_sample = {'prompt': 'hello', 'response': 'goodbye'}
samples = [good_sample] * size
if add_bad_data_dropped:
if pad_token is None:
raise ValueError(
'pad_token, start_token, and end_token must be specified if add_bad_data is True'
)
# empty prompt
samples.append({'prompt': '', 'response': 'goodbye'})
# empty response
samples.append({'prompt': 'hello', 'response': ''})
# response just pad
samples.append({'prompt': 'hello', 'response': pad_token})
# response just pad multiple times
samples.append({'prompt': 'hello', 'response': pad_token * 3})

if add_bad_data_error:
# prompt just None
samples.append({
'prompt': None,
'response': 'goodbye'
}) # type: ignore (intentional test)
# response just None
samples.append({
'prompt': 'hello',
'response': None
}) # type: ignore (intentional test)

if add_just_bos_eos_pad:
if pad_token is None or start_token is None or end_token is None:
raise ValueError(
'pad_token, start_token, and end_token must be specified if add_just_bos_eos is True'
)
# prompt just start
samples.append({'prompt': start_token, 'response': 'goodbye'})
# response just start
samples.append({'prompt': 'hello', 'response': start_token})
# prompt just end
samples.append({'prompt': end_token, 'response': 'goodbye'})
# response just end
samples.append({'prompt': 'hello', 'response': end_token})
# prompt just pad
samples.append({'prompt': pad_token, 'response': 'goodbye'})

os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w') as _f:
for sample in samples:
_f.write(json.dumps(sample))
_f.write('\n')


@pytest.mark.world_size(2)
@pytest.mark.parametrize('dataset_size', [4, 8])
@pytest.mark.parametrize('device_batch_size', [2, 4])